本文为雷锋字幕组编译的技术博客,原标题 Step-by-step walkthrough of RNN Training - Part I,作者为 Eniola Alese。 翻译 | 赵朋飞 程思婕 整理 | 凡江
单个 RNN Cell 中的前向传播算法
在之前的文章中,我们介绍了 RNN 的基本结构并将其按时间序列展开成 Cells 循环链,称为 RNN cells。下面,我们将揭示单个 RNN Cell 的内部结构和前向传播计算过程。
将其过程分解成多个步骤:
第一步:cell 接受两个输入:x⟨t⟩ 和 a⟨t-1⟩。
第二步:接下来,计算矩阵乘积 ⨂,W_xh 乘 x⟨t⟩,W_ah 乘 a⟨t-1⟩。然后,通过将以上两项乘积相加,并加上偏置 b_h,计算得出 h⟨t⟩。
第三步:紧接着上面的步骤,将 h(t) 传给激活函数(比如 tanh 或 relu)计算 a(t)。本例中激活函数使用 tanh 函数。
第四步: cell 输出 a⟨t⟩ 并将其传给下一 cell 做进一步计算。
第五步:然后,计算 o⟨t⟩; 这是所有输出可能取值的非标准化对数概率。方法是计算矩阵乘积⨂,W_ao 乘 a⟨t⟩,并与 b_o 相加。
第六步:最后,通过将 o⟨t⟩ 传输给激活函数(例如 sigmoid 或 softmax),得到了一个实际输出的标准化概率向量 ŷ⟨t⟩。输出的激活函数的选择通常取决于期望的输出类型(sigmoid 用于二元输出,softmax 用于多类别输出)。
前向传播算法
前向传播算法在整个 RNN 网络中运行以上步骤,而不仅仅在单个 RNN cell 中运行。从隐藏层状态 a⟨0⟩的初始化开始,在所有时间序列 t = 1 to T 中共享权值和偏置向量 W_xh,W_ah, W_ao, b_h, b_o,在每个时间序列中重复上面的每一步。
例如,如果我们拥有一个 8 个序列的输入 x⟨1⟩,x⟨2⟩,......x⟨8⟩,这个网络的前向传播计算过程是步骤 1-6 在循环中重复 8 次。
单个 RNN 单元的反向传播
RNN 中反向传播的目的是计算出最终的损失值 L 分别对权值矩阵(W_xh,W_ah,W_ao)和偏置向量(b_h,b_o)的偏导数值。
推导出所需的导数值非常简单,我们只需要利用链式法则(https://en.wikipedia.org/wiki/Chain_rule)就能计算出它们。
第一步:为了计算代价,需要先定义损失函数。一般根据具体手中的任务来选择该损失函数。在这个例子里,对于多分类输出问题,我们采用交叉熵损失函数 L⟨t⟩,其具体计算过程如下:
第二步:接下来我们开始往后计算损失函数 L⟨t⟩ 对预测输出值的激活值 ŷ⟨t⟩ 的偏导数值。因为在前向传播过程中 softmax 函数以多分类的输出值作为输入,因此下面的偏导数值
的计算分为两种情况:分类 i 时和分类 k 时:
第三步:接着利用分类 i 时和分类 k 时的偏导数值
,可以计算出损失函数 L⟨t⟩ 对预测输出值 o⟨t⟩ 的偏导数值:
第四步:利用偏导数值
及链式法则,计算出损失函数 L⟨t⟩ 对输出过程中的偏置向量 b_o 的偏导数值:
第五步:利用偏导数值
及链式法则,计算出损失函数 L⟨t⟩ 对隐层至输出层中的权值矩阵 W_ao 的偏导数值:
第六步: 利用偏导数值
、
及链式法则,计算出损失函数 L⟨t⟩ 对隐状态的激活值 a⟨t⟩ 的偏导数值:
第七步: 利用偏导数值
及链式法则,计算出损失函数 L⟨t⟩ 对隐状态 h⟨t⟩ 的偏导数值:
第八步: 利用偏导数值
及链式法则,计算出损失函数 L⟨t⟩ 对隐状态的偏置向量 b_h 的偏导数值:
第九步:利用偏导数值
及链式法则,计算出损失函数 L⟨t⟩ 对输入层至隐层中的偏置矩阵 W_xh 的偏导数值:
第十步:利用偏导数值
及链式法则,计算出损失函数 L⟨t⟩ 对输入层至隐层中的偏置矩阵 W_ah 的偏导数值:
随时间反向传播(BPTT)
就像前文中提到的前向传播过程一样,将循环网络展开,BPTT 将沿此一直运行着上述步骤。
主要的区别在于我们必须将每个时间步 t 的偏导数值
累加起来,从而更新权值和偏置,这是因为这些参数在前向传播的过程中是被各个时间步所共享的。
总结
在本文的第一部分和第二部分中,我们了解了循环神经网络训练过程中所涉及到的前向传播和反向传播。接下来,我们将着眼于 RNN 中所存在的梯度消失问题,并讨论 LSTM 和 GRU 网络的进展。
博客原址:
Part I
https://medium.com/learn-love-ai/step-by-step-walkthrough-of-rnn-training-part-i-7aee5672dea3
Part II
https://medium.com/learn-love-ai/step-by-step-walkthrough-of-rnn-training-part-ii-7141084d274b
RNN 输入变长序列 padding