三次简化一张图:一招理解LSTM/GRU门控机制

机器之心专栏

作者:张皓

RNN 在处理时序数据时十分成功。但是,对 RNN 及其变种 LSTM 和 GRU 结构的理解仍然是一个困难的任务。本文介绍一种理解 LSTM 和 GRU 的简单通用的方法。通过对

LSTM

和GRU数学形式化的三次简化,最后将数据流形式画成一张图,可以简洁直观地对其中的原理进行理解与分析。此外,本文介绍的三次简化一张图的分析方法具有普适性,可广泛用于其他门控网络的分析。

1. RNN、梯度爆炸与梯度消失

1.1 RNN

近些年,深度学习模型在处理有非常复杂内部结构的数据时十分有效。例如,图像数据的像素之间的 2 维空间关系非常重要,CNN(convolution neural networks,卷积神经网络)处理这种空间关系十分有效。而时序数据(sequential data)的变长输入序列之间时序关系非常重要,RNN(recurrent neural networks,循环神经网络,注意和 recursive neural networks,递归神经网络的区别)处理这种时序关系十分有效。

我们使用下标 t 表示输入时序序列的不同位置,用 h_t 表示在时刻 t 的系统隐层状态向量,用 x_t 表示时刻 t 的输入。t 时刻的隐层状态向量 h_t 依赖于当前词 x_t 和前一时刻的隐层状态向量 h_(t-1):

其中 f 是一个非线性映射函数。一种通常的做法是计算 x_t 和h_(t-1)的线性变换后经过一个非线性激活函数,例如

其中 W_(xh) 和 W_(hh) 是可学习的参数矩阵,激活函数 tanh 独立地应用到其输入的每个元素。

为了对 RNN 的计算过程做一个可视化,我们可以画出下图:

图中左边是输入 x_t 和 h_(t-1)、右边是输出 h_t。计算从左向右进行,整个运算包括三步:输入 x_t 和 h_(t-1) 分别乘以 W_(xh) 和 W_(hh) 、相加、经过 tanh 非线性变换。

我们可以认为 h_t 储存了网络中的记忆(memory),RNN 学习的目标是使得 h_t 记录了在 t 时刻之前(含)的输入信息 x_1, x_2,..., x_t。在新词 x_t 输入到网络之后,之前的隐状态向量 h_(t-1) 就转换为和当前输入 x_t 有关的 h_t。

1.2 梯度爆炸与梯度消失

虽然理论上 RNN 可以捕获长距离依赖,但实际应用中,RNN 将会面临两个挑战:梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。

我们考虑一种简单情况,即激活函数是恒等(identity)变换,此时

在进行误差反向传播(error backpropagation)时,当我们已知损失函数对 t 时刻隐状态向量 h_t 的偏导数时,利用链式法则,我们计算损失函数对 t 时刻隐状态向量 h_0 的偏导数

我们可以利用 RNN 的依赖关系,沿时间维度展开,来计算

也就是说,在误差反向传播时我们需要反复乘以参数矩阵 W_(hh)。我们对矩阵 W_(hh) 进行奇异值分解(SVD)

其中 r 是矩阵 W_(hh) 的秩(rank)。因此,

那么我们最后要计算的目标

当 t 很大时,该偏导数取决于矩阵 W_(hh) 的最大的奇异值是大于 1 还是小于 1,要么结果太大,要么结果太小:

(1).梯度爆炸。当> 1,,那么

此时偏导数将会变得非常大,实际在训练时将会遇到 NaN 错误,会影响训练的收敛,甚至导致网络不收敛。这好比要把本国的产品卖到别的国家,结果被加了层层关税,等到了别国市场的时候,价格已经变得非常高,老百姓根本买不起。在 RNN 中,梯度(偏导数)就是价格,随着向前推移,梯度越來越大。这种现象称为梯度爆炸。

梯度爆炸相对比较好处理,可以用梯度裁剪(gradient clipping)来解决:

这好比是不管前面的关税怎么加,设置一个最高市场价格,通过这个最高市场价格保证老百姓是买的起的。在 RNN 中,不管梯度回传的时候大到什么程度,设置一个梯度的阈值,梯度最多是这么大。

(2).梯度消失。当

此时偏导数将会变得十分接近 0,从而在梯度更新前后没有什么区别,这会使得网络捕获长距离依赖(long-term dependency)的能力下降。这好比打仗的时候往前线送粮食,送粮食的队伍自己也得吃粮食。当补给点离前线太远时,还没等送到,粮食在半路上就已经被吃完了。在 RNN 中,梯度(偏导数)就是粮食,随着向前推移,梯度逐渐被消耗殆尽。这种现象称为梯度消失。

梯度消失现象解决起来困难很多,如何缓解梯度消失是 RNN 及几乎其他所有深度学习方法研究的关键所在。LSTM 和 GRU 通过门(gate)机制控制 RNN 中的信息流动,用来缓解梯度消失问题。其核心思想是有选择性的处理输入。比如我们在看到一个商品的评论时

Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!

我们会重点关注其中的一些词,对它们进行处理

Amazing!This box of cereal gave me aperfectly balanced breakfast, as all things should be. In only ate half of it but willdefinitely be buying again!

LSTM 和 GRU 的关键是会选择性地忽略其中一些词,不让其参与到隐层状态向量的

更新中,最后只保留相关的信息进行预测。

2. LSTM

2.1 LSTM 的数学形式

LSTM(Long Short-Term Memory)由 Hochreiter 和 Schmidhuber 提出,其数学上的形式化表示如下:

其中代表逐元素相乘,sigm 代表 sigmoid 函数

和 RNN 相比,LSTM 多了一个隐状态变量 c_t,称为细胞状态(cell state),用来记录信息。

这个公式看起来似乎十分复杂,为了更好的理解 LSTM 的机制,许多人用图来描述 LSTM 的计算过程。比如下面这张图:

似乎看完之后,对 LSTM 的理解仍然是一头雾水?这是因为这些图想把 LSTM 的所有细节一次性都展示出来,但是突然暴露这么多的细节会使你眼花缭乱,从而无处下手。

2.2 三次简化一张图

因此,本文提出的方法旨在简化门控机制中不重要的部分,从而更关注在 LSTM 的核心思想。整个过程是三次简化一张图,具体流程如下:

(1).第一次简化:忽略门控单元 i_t 、f_t 、o_t 的来源。3 个门控单元的计算方法完全相同,都是由输入经过线性映射得到的,区别只是计算的参数不同:

使用相同计算方式的目的是它们都扮演了门控的角色,而使用不同参数的目的是为了误差反向传播时对三个门控单元独立地进行更新。在理解 LSTM 运行机制的时候,为了对图进行简化,我们不在图中标注三个门控单元的计算过程,并假定各门控单元是给定的。

(2).第二次简化:考虑一维门控单元 i_t 、 f_t 、 o_t。LSTM 中对各维是独立进行门控的,所以为了表示和理解方便,我们只需要考虑一维情况,在理解 LSTM 原理之后,将一维推广到多维是很直接的。经过这两次简化,LSTM 的数学形式只有下面三行

由于门控单元变成了一维,所以向量和向量的逐元素相乘符号变成了数和向量相乘 · 。

(3).第三次简化:各门控单元二值输出。门控单元 i_t 、f_t 、o_t 的由于经过了 sigmoid 激活函数,输出是范围是 [0, 1]。激活函数使用 sigmoid 的目的是为了近似 0/1 阶跃函数,这样 sigmoid 实数值输出单调可微,可以基于误差反向传播进行更新。

既然 sigmoid 激活函数是为了近似 0/1 阶跃函数,那么,在进行 LSTM 理解分析的时候,为了理解方便,我们认为各门控单元 二值输出,即门控单元扮演了电路中开关的角色,用于控制信息的通断。

(4).一张图。将三次简化的结果用电路图表述出来,左边是输入,右边是输出。在 LSTM 中,有一点需要特别注意,LSTM 中的细胞状态 c_t 实质上起到了 RNN 中隐层单元 h_t 的作用,这点在其他文献资料中不常被提到,所以整个图的输入是 x_t 和 c_,而不是 x_t 和 h_(t-1)。为了方便画图,我们需要将公式做最后的调整

最终结果如下:

和 RNN 相同的是,网络接受两个输入,得到一个输出。其中使用了两个参数矩阵 W_(xc) 和 W_(hc),以及 tanh 激活函数。不同之处在于,LSTM 中通过 3 个门控单元 i_t 、f_t 、o_t 来对的信息交互进行控制。当 i_t=1(开关闭合)、f_t=0(开关打开)、o_t=1(开关闭合)时,LSTM 退化为标准的 RNN。

2.3 LSTM 各单元作用分析

根据这张图,我们可以对 LSTM 中各单元作用进行分析:

输出门 o_(t-1):输出门的目的是从细胞状态 c_(t-1) 产生隐层单元 h_(t-1)。并不是c_(t-1)中的全部信息都和隐层单元h_(t-1)有关,c_(t-1)可能包含了很多对h_(t-1)无用的信息。因此,o_t 的作用就是判断c_(t-1)中哪些部分是对h_(t-1)有用的,哪些部分是无用的。

输入门 i_t。i_t 控制当前词 x_t 的信息融入细胞状态 c_t。在理解一句话时,当前词 x_t 可能对整句话的意思很重要,也可能并不重要。输入门的目的就是判断当前词 x_t 对全局的重要性。当 i_t 开关打开的时候,网络将不考虑当前输入 x_t。

遗忘门 f_t: f_t 控制上一时刻细胞状态c_(t-1)的信息融入细胞状态 c_t。在理解一句话时,当前词 x_t 可能继续延续上文的意思继续描述,也可能从当前词 x_t 开始描述新的内容,与上文无关。和输入门 i_t 相反,f_t 不对当前词 x_t 的重要性作判断,而判断的是上一时刻的细胞状态c_(t-1)对计算当前细胞状态 c_t 的重要性。当 f_t 开关打开的时候,网络将不考虑上一时刻的细胞状态c_(t-1)。

细胞状态c_t:c_t 综合了当前词 x_t 和前一时刻细胞状态c_(t-1)的信息。这和 ResNet 中的残差逼近思想十分相似,通过从c_(t-1)到 c_t 的「短路连接」,梯度得已有效地反向传播。当 f_t 处于闭合状态时,c_t 的梯度可以直接沿着最下面这条短路线传递到c_(t-1),不受参数 W_(xh) 和 W_(hh) 的影响,这是 LSTM 能有效地缓解梯度消失现象的关键所在。

3. GRU

3.1 GRU 的数学形式

GRU 是另一种十分主流的 RNN 衍生物。RNN 和 LSTM 都是在设计网络结构用于缓解梯度消失问题,只不过是网络结构有所不同。GRU 在数学上的形式化表示如下:

3.2 三次简化一张图

为了理解 GRU 的设计思想,我们再一次运用三次简化一张图的方法来进行分析:

(1).第一次简化:忽略门控单元 z_t 和 r_t 的来源。

(2).考虑一维门控单元 z_t 和 r_t。经过这两次简化,GRU 的数学形式是以下两行

(3).第三次简化:各门控单元二值输出。这里和 LSTM 略有不同的地方在于,当 z_t=1 时h_t = h_(t-1) ;而当 z_t = 0 时,h_t =。因此,z_t 扮演的角色是一个个单刀双掷开关。

(4).一张图。将三次简化的结果用电路图表述出来,左边是输入,右边是输出。

与 LSTM 相比,GRU 将输入门 i_t 和遗忘门 f_t 融合成单一的更新门 z_t,并且融合了细胞状态 c_t 和隐层单元 h_t。当 r_t=1(开关闭合)、 z_t=0(开关连通上面)GRU 退化为标准的 RNN。

3.3 GRU 各单元作用分析

根据这张图, 我们可以对 GRU 的各单元作用进行分析:

重置门 r_t: r_t 用于控制前一时刻隐层单元 h_(t-1) 对当前词 x_t 的影响。如果h_(t-1)对 x_t 不重要,即从当前词 x_t 开始表述了新的意思,与上文无关。那么开关 r_t 可以打开,使得h_(t-1)对 x_t 不产生影响。

更新门 z_t: z_t 用于决定是否忽略当前词 x_t。类似于 LSTM 中的输入门 i_t,z_t 可以判断当前词 x_t 对整体意思的表达是否重要。当 z_t 开关接通下面的支路时,我们将忽略当前词 x_t,同时构成了从h_(t-1)到 h_t 的短路连接,这使得梯度得已有效地反向传播。和 LSTM 相同,这种短路机制有效地缓解了梯度消失现象,这个机制于 highway networks 十分相似。

4. 小结

尽管 RNN、LSTM、和 GRU 的网络结构差别很大,但是他们的基本计算单元是一致的,都是对 x_t 和 h_t 做一个线性映射加 tanh 激活函数,见三个图的红色框部分。他们的区别在于如何设计额外的门控机制控制梯度信息传播用以缓解梯度消失现象。LSTM 用了 3 个门、GRU 用了 2 个,那能不能再少呢?MGU(minimal gate unit)尝试对这个问题做出回答,它只有一个门控单元。最后留个小练习,参考 LSTM 和 GRU 的例子,你能不能用三次简化一张图的方法来分析一下 MGU 呢?

参考文献

Yoshua Bengio, Patrice Y. Simard, and Paolo Frasconi. Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks 5(2): 157-166, 1994.

Kyunghyun Cho, Bart van Merrienboer, Çaglar Gülçehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP, pages 1724-1734, 2014.

Junyoung Chung, Çaglar Gülçehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. In NIPS Workshop, pages 1-9, 2014.

Felix Gers. Long short-term memory in recurrent neural networks. PhD Dissertation, Ecole Polytechnique Fédérale de Lausanne, 2001.

Ian J. Goodfellow, Yoshua Bengio, and Aaron C. Courville. Deep learning. Adaptive Computation and Machine Learning, MIT Press, ISBN 978-0-262-03561-3, 2016.

Alex Graves. Supervised sequence labelling with recurrent neural networks. Studies in Computational Intelligence 385, Springer, ISBN 978-3-642-24796-5, 2012.

Klaus Greff, Rupesh Kumar Srivastava, Jan Koutník, Bas R. Steunebrink, and Jürgen Schmidhuber. LSTM: A search space odyssey. IEEE Transactions on Neural Networks and Learning Systems. 28(10): 2222-2232, 2017.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770-778, 2016.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residual networks. In ECCV, pages 630-645, 2016.

Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation 9(8): 1735-1780, 1997.

Rafal Józefowicz, Wojciech Zaremba, and Ilya Sutskever. An empirical exploration of recurrent network architectures. In ICML, pages 2342-2350, 2015.

Zachary Chase Lipton. A critical review of recurrent neural networks for sequence learning. CoRR abs/1506.00019, 2015.

Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In ICML, pages 1310-1318, 2013.

Rupesh Kumar Srivastava, Klaus Greff, and Jürgen Schmidhuber. Highway networks. In ICML Workshop, pages 1-6, 2015.

Guo-Bing Zhou, Jianxin Wu, Chen-Lin Zhang, and Zhi-Hua Zhou. Minimal gated unit for recurrent neural networks. International Journal of Automation and Computing, 13(3): 226-234, 2016.

本文为机器之心专栏,转载请联系本公众号获得授权。

------------------------------------------------

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181218A0H9OG00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券