机器之心专栏
作者:张皓
RNN 在处理时序数据时十分成功。但是,对 RNN 及其变种 LSTM 和 GRU 结构的理解仍然是一个困难的任务。本文介绍一种理解 LSTM 和 GRU 的简单通用的方法。通过对 LSTM 和 GRU 数学形式化的三次简化,最后将数据流形式画成一张图,可以简洁直观地对其中的原理进行理解与分析。此外,本文介绍的三次简化一张图的分析方法具有普适性,可广泛用于其他门控网络的分析。
1. RNN、梯度爆炸与梯度消失
1.1 RNN
近些年,深度学习模型在处理有非常复杂内部结构的数据时十分有效。例如,图像数据的像素之间的 2 维空间关系非常重要,CNN(convolution neural networks,卷积神经网络)处理这种空间关系十分有效。而时序数据(sequential data)的变长输入序列之间时序关系非常重要,RNN(recurrent neural networks,循环神经网络,注意和 recursive neural networks,递归神经网络的区别)处理这种时序关系十分有效。
我们使用下标 t 表示输入时序序列的不同位置,用 h_t 表示在时刻 t 的系统隐层状态向量,用 x_t 表示时刻 t 的输入。t 时刻的隐层状态向量 h_t 依赖于当前词 x_t 和前一时刻的隐层状态向量 h_(t-1):
其中 f 是一个非线性映射函数。一种通常的做法是计算 x_t 和 h_(t-1) 的线性变换后经过一个非线性激活函数,例如
其中 W_(xh) 和 W_(hh) 是可学习的参数矩阵,激活函数 tanh 独立地应用到其输入的每个元素。
为了对 RNN 的计算过程做一个可视化,我们可以画出下图:
图中左边是输入 x_t 和 h_(t-1)、右边是输出 h_t。计算从左向右进行,整个运算包括三步:输入 x_t 和 h_(t-1) 分别乘以 W_(xh) 和 W_(hh) 、相加、经过 tanh 非线性变换。
我们可以认为 h_t 储存了网络中的记忆(memory),RNN 学习的目标是使得 h_t 记录了在 t 时刻之前(含)的输入信息 x_1, x_2,..., x_t。在新词 x_t 输入到网络之后,之前的隐状态向量 h_(t-1) 就转换为和当前输入 x_t 有关的 h_t。
1.2 梯度爆炸与梯度消失
虽然理论上 RNN 可以捕获长距离依赖,但实际应用中,RNN 将会面临两个挑战:梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。
我们考虑一种简单情况,即激活函数是恒等(identity)变换,此时
在进行误差反向传播(error backpropagation)时,当我们已知损失函数
对 t 时刻隐状态向量 h_t 的偏导数
时,利用链式法则,我们计算损失函数
对 t 时刻隐状态向量 h_0 的偏导数
我们可以利用 RNN 的依赖关系,沿时间维度展开,来计算
也就是说,在误差反向传播时我们需要反复乘以参数矩阵 W_(hh)。我们对矩阵 W_(hh) 进行奇异值分解(SVD)
其中 r 是矩阵 W_(hh) 的秩(rank)。因此,
那么我们最后要计算的目标
当 t 很大时,该偏导数取决于矩阵 W_(hh) 的最大的奇异值
是大于 1 还是小于 1,要么结果太大,要么结果太小:
(1). 梯度爆炸。当
> 1,
,那么
此时偏导数
将会变得非常大,实际在训练时将会遇到 NaN 错误,会影响训练的收敛,甚至导致网络不收敛。这好比要把本国的产品卖到别的国家,结果被加了层层关税,等到了别国市场的时候,价格已经变得非常高,老百姓根本买不起。在 RNN 中,梯度(偏导数)就是价格,随着向前推移,梯度越來越大。这种现象称为梯度爆炸。
梯度爆炸相对比较好处理,可以用梯度裁剪(gradient clipping)来解决:
这好比是不管前面的关税怎么加,设置一个最高市场价格,通过这个最高市场价格保证老百姓是买的起的。在 RNN 中,不管梯度回传的时候大到什么程度,设置一个梯度的阈值,梯度最多是这么大。
(2). 梯度消失。当
< 1,
,那么
此时偏导数
将会变得十分接近 0,从而在梯度更新前后没有什么区别,这会使得网络捕获长距离依赖(long-term dependency)的能力下降。这好比打仗的时候往前线送粮食,送粮食的队伍自己也得吃粮食。当补给点离前线太远时,还没等送到,粮食在半路上就已经被吃完了。在 RNN 中,梯度(偏导数)就是粮食,随着向前推移,梯度逐渐被消耗殆尽。这种现象称为梯度消失。
梯度消失现象解决起来困难很多,如何缓解梯度消失是 RNN 及几乎其他所有深度学习方法研究的关键所在。LSTM 和 GRU 通过门(gate)机制控制 RNN 中的信息流动,用来缓解梯度消失问题。其核心思想是有选择性的处理输入。比如我们在看到一个商品的评论时
Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!
我们会重点关注其中的一些词,对它们进行处理
Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!
LSTM 和 GRU 的关键是会选择性地忽略其中一些词,不让其参与到隐层状态向量的
更新中,最后只保留相关的信息进行预测。
2. LSTM
2.1 LSTM 的数学形式
LSTM(Long Short-Term Memory)由 Hochreiter 和 Schmidhuber 提出,其数学上的形式化表示如下:
其中
代表逐元素相乘,sigm 代表 sigmoid 函数
和 RNN 相比,LSTM 多了一个隐状态变量 c_t,称为细胞状态(cell state),用来记录信息。
这个公式看起来似乎十分复杂,为了更好的理解 LSTM 的机制,许多人用图来描述 LSTM 的计算过程。比如下面这张图:
似乎看完之后,对 LSTM 的理解仍然是一头雾水?这是因为这些图想把 LSTM 的所有细节一次性都展示出来,但是突然暴露这么多的细节会使你眼花缭乱,从而无处下手。
2.2 三次简化一张图
因此,本文提出的方法旨在简化门控机制中不重要的部分,从而更关注在 LSTM 的核心思想。整个过程是三次简化一张图,具体流程如下:
(1). 第一次简化:忽略门控单元 i_t 、f_t 、o_t 的来源。3 个门控单元的计算方法完全相同,都是由输入经过线性映射得到的,区别只是计算的参数不同:
使用相同计算方式的目的是它们都扮演了门控的角色,而使用不同参数的目的是为了误差反向传播时对三个门控单元独立地进行更新。在理解 LSTM 运行机制的时候,为了对图进行简化,我们不在图中标注三个门控单元的计算过程,并假定各门控单元是给定的。
(2). 第二次简化:考虑一维门控单元 i_t 、 f_t 、 o_t。LSTM 中对各维是独立进行门控的,所以为了表示和理解方便,我们只需要考虑一维情况,在理解 LSTM 原理之后,将一维推广到多维是很直接的。经过这两次简化,LSTM 的数学形式只有下面三行
由于门控单元变成了一维,所以向量和向量的逐元素相乘符号
变成了数和向量相乘 · 。
(3). 第三次简化:各门控单元二值输出。门控单元 i_t 、f_t 、o_t 的由于经过了 sigmoid 激活函数,输出是范围是 [0, 1]。激活函数使用 sigmoid 的目的是为了近似 0/1 阶跃函数,这样 sigmoid 实数值输出单调可微,可以基于误差反向传播进行更新。
既然 sigmoid 激活函数是为了近似 0/1 阶跃函数,那么,在进行 LSTM 理解分析的时候,为了理解方便,我们认为各门控单元 {0, 1} 二值输出,即门控单元扮演了电路中开关的角色,用于控制信息的通断。
(4). 一张图。将三次简化的结果用电路图表述出来,左边是输入,右边是输出。在 LSTM 中,有一点需要特别注意,LSTM 中的细胞状态 c_t 实质上起到了 RNN 中隐层单元 h_t 的作用,这点在其他文献资料中不常被提到,所以整个图的输入是 x_t 和 c_{t-1},而不是 x_t 和 h_(t-1)。为了方便画图,我们需要将公式做最后的调整
最终结果如下:
和 RNN 相同的是,网络接受两个输入,得到一个输出。其中使用了两个参数矩阵 W_(xc) 和 W_(hc),以及 tanh 激活函数。不同之处在于,LSTM 中通过 3 个门控单元 i_t 、f_t 、o_t 来对的信息交互进行控制。当 i_t=1(开关闭合)、f_t=0(开关打开)、o_t=1(开关闭合)时,LSTM 退化为标准的 RNN。
2.3 LSTM 各单元作用分析
根据这张图,我们可以对 LSTM 中各单元作用进行分析:
3. GRU
3.1 GRU 的数学形式
GRU 是另一种十分主流的 RNN 衍生物。RNN 和 LSTM 都是在设计网络结构用于缓解梯度消失问题,只不过是网络结构有所不同。GRU 在数学上的形式化表示如下:
3.2 三次简化一张图
为了理解 GRU 的设计思想,我们再一次运用三次简化一张图的方法来进行分析:
(1). 第一次简化:忽略门控单元 z_t 和 r_t 的来源。
(2). 考虑一维门控单元 z_t 和 r_t。经过这两次简化,GRU 的数学形式是以下两行
(3). 第三次简化:各门控单元二值输出。这里和 LSTM 略有不同的地方在于,当 z_t=1 时h_t = h_(t-1) ;而当 z_t = 0 时,h_t =
。因此,z_t 扮演的角色是一个个单刀双掷开关。
(4). 一张图。将三次简化的结果用电路图表述出来,左边是输入,右边是输出。
与 LSTM 相比,GRU 将输入门 i_t 和遗忘门 f_t 融合成单一的更新门 z_t,并且融合了细胞状态 c_t 和隐层单元 h_t。当 r_t=1(开关闭合)、 z_t=0(开关连通上面)GRU 退化为标准的 RNN。
3.3 GRU 各单元作用分析
根据这张图, 我们可以对 GRU 的各单元作用进行分析:
4. 小结
尽管 RNN、LSTM、和 GRU 的网络结构差别很大,但是他们的基本计算单元是一致的,都是对 x_t 和 h_t 做一个线性映射加 tanh 激活函数,见三个图的红色框部分。他们的区别在于如何设计额外的门控机制控制梯度信息传播用以缓解梯度消失现象。LSTM 用了 3 个门、GRU 用了 2 个,那能不能再少呢?MGU(minimal gate unit)尝试对这个问题做出回答,它只有一个门控单元。最后留个小练习,参考 LSTM 和 GRU 的例子,你能不能用三次简化一张图的方法来分析一下 MGU 呢?
参考文献
本文为机器之心专栏,转载请联系本公众号获得授权。