专栏首页机器学习养成记为什么相比于RNN,LSTM在梯度消失上表现更好

为什么相比于RNN,LSTM在梯度消失上表现更好

Background

对于深度学习模型,在train参数的时候,需要采用随机梯度下降方法(SGD,Stochastic Gradient Descent):

是参数学习率;

是损失函数在

位置的偏导数,即梯度。

在LSTM被推出之前,RNN一直是主流的深度学习语言模型,但是其存在两个十分明显的技术缺陷,一个是梯度消失,另一个是梯度爆炸。

  • 梯度消失:RNN模型在某些

的取值上,梯度特别小(几乎为0);

  • 梯度爆炸:RNN在某些

的取值上,梯度特别大。(注:这里的大小指的是绝对值)

因此,若控制学习率

不变,那么参数要么几乎不变(几乎学不到信息,得不到更新),要么就变化特别大。换句话说,RNN在理论上是一个优秀的模型,如果能够找到合适的参数,可以很好地描述语言模型。但由于梯度消失(梯度爆炸)的存在,这组合适的参数并不好找,因此,大多数时候对RNN只能找到一组凑合差不多的参数解,从而影响到最终model的表现!

已经有了RNN,为什么还要整出来一个LSTM的重要原因,即LSTM解决了梯度消失的问题。

Introduction of RNN

首先了解一下RNN的梯度消失是如何产生的。先来借用一下李宏毅老师课件的图,展示一下RNN的结构:

每个时间阶段

,RNN读入一个变量

和上一个阶段产生的隐含层的信息

一起被吃进计算模块

中,产生新的隐含层信息

,新的隐含层信息一方面继续传播更新,另一方面产生该时刻的一个输出

。隐含层信息

和所有截止到时间

的历史输入都有关,并会依赖于RNN的信息路径一直传播下去。

其中,

是旧的隐含层信息对新的隐含层信息的影响,

是输入信息对新的隐含层信息的影响,

是隐含层信息对输出结果的影响。不考虑bias,

都是要学习的参数。

采用梯度下降法来求参数,首先定义Loss Function,有:

对上述公式两边求梯度,有:

分别对

进行求导,有:

................(1)

................(2)

.................(3)

梯度消失主要就是针对上面(2)和(3)两个式子来讲的。可以看到,上面公式里面有依赖于时间

的连乘符号;在修正某个位置

的误差时,计算出的梯度需要考虑

之前的所有时间

的隐含层信息对时间

的隐含层信息的影响。当

越远时,这个影响被迭代的次数就越多,对应着隐含层之间的连乘次数就越多。

就是这个连乘的结构产生了梯度消失,梯度爆炸也是它导致的!那么,这个连乘结构是怎么发生作用的呢?

首先,根据RNN的定义,把隐含层之间的函数关系表示出来,具体有:

其中,

表示Sigmoid激活函数;

于是有:

根据Sigmoid的特性,存在关系:

显而易见,

存在上界,为

,那么:

(1)当

时,始终有

,因此无论参数如何取值,当

距离很大时,连乘项都会趋向于无穷,在这种情况下就会导致梯度爆炸;

(2)当

时,始终有

,因此无论参数如何取值,当

距离很大时,连乘项都会趋于0,在这种情况下就会导致梯度消失;

注:或许有读者会疑问,根据梯度的表达式,梯度爆炸看似是会发生的,但是梯度消失应该不会啊;因为那种很长的连乘的项,只是求和内容中的子项,公式中还存在大量的"短距"连乘项,这些项不是仍然可以组成梯度方向用于参数更新么?

我的理解是:在随机梯度下降的最开始,公式中的"短距"连乘项或许会产生一些梯度方向,但是随着随着参数的动态更新,这些"短距"连乘项构成的方向会引导Loss Function到局部最优的位置上去。假如在这个过程中梯度爆炸没有发生,那么当Loss Function越接近于局部最优,这些"短距"连乘项就会越接近于0。毕竟,对于RNN模型,在整个迭代过程中,如果不考虑"长距"连乘项,局部最优的地方就是梯度为0的地方。

为了解决梯度爆炸的问题,可以用一些调参技巧,如可以限制梯度必须小于某个阈值,如果大于这个阈值,就取这个阈值进行更新。为了解决梯度消失的问题,也可以在不改变模型主题结构的情况下,克服调参的困难,如采用Relu或者LeakyRelu代替Sigmoid作为激活函数。

Introduction of LSTM (Long Short-term Memory)

LSTM(长短期记忆网络)的结构天然的可以克服梯度消失的问题,先介绍LSTM的基本结构:

LSTM中有两个隐含层,

。由LSTM的结构可知,在每个迭代周期,

需要不断地被更新。和RNN的

类似,

演化而来。

的内容保留一部分,同时也加入一些新东西,就构成了

中保留的比例由遗忘阀门

来控制,

具有相同维度,阀门变量的作用通过元素之间的对应乘积来体现。

上新加入的东西是

,即输入向量。

也有相同的维度,但是

也需要按照一定的比例来加入,这个比例由输入阀门

控制,输入向量和输入阀门的对应乘积就是真正需要新加入的内容。

当获得

后,这个隐含层的内容就会通过某种方式产生输出。

的内容通过一个阀门变量来控制是否从神经网络中的对应时间节点

予以输出,这个阀门变量为

具有相同的维度,乘积后的向量记为

,

通过一个网络层的转化就可以得到输出结果

和当前阶段的输入变量

以及上一阶段的待输出的隐含层信息

相关,具体表达式如上图所示。其中,

是需要估计的参数。

如果根据BPTT把梯度的结构展开,会包含连乘项:

为了便于分析,如果考虑bias,同时忽略输入变量

的作用,那么隐含层之间的关系可以表示为:

于是,需要连乘的项可表示为:

该值范围在0~1之间,但是在实际参数更新中,可以通过控制bias比较大,使得该值接近于1;在这种情况下,即使通过很多次连乘的操作,梯度也不会消失,仍然可以保留"长距"连乘项的存在。即总可以通过选择合适的参数,在不发生梯度爆炸的情况下,找到合理的梯度方向来更新参数,而且这个方向可以充分地考虑远距离的隐含层信息的传播影响。

另外需要强调的是,LSTM除了在结构上天然地克服了梯度消失的问题,更重要的是具有更多的参数来控制模型;通过四倍于RNN的参数量,可以更加精细地预测时间序列变量。

本文分享自微信公众号 - 机器学习养成记(chenchenwings)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-04-26

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • LeetCode刷题DAY 28:移动零

    给定一个数组nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序。如:输入[0,1,0,3,12]返回[1,3,12,0,0],注意其...

    三猫
  • LeetCode刷题DAY 21:最长公共前缀

    找到字符串数组中的最长公共前缀并返回,无则返回""。如:输入["flower","flow","flight"],返回"fl"。

    三猫
  • LeetCode刷题DAY 7:只出现一次的数字III

    给定一个整数数组 nums,其中恰好有两个元素只出现一次,其余所有元素均出现两次。找出只出现一次的那两个元素。如输入[1,2,3,2,4,1],输出[3,4]。

    三猫
  • 设计模式—单例模式

    对于常用的23种设计模式,这里笔者会根据自己学习和出现频率、重要程度进行学习记录吧。并且每种设计模式可能会根据暂时需求侧重学习深浅。

    bigsai
  • Asp.NetCore 部署

    一年之间
  • 梯度下降法

    本文主要是为了讲解 梯度下降法 的原理和实践, 至于什么是梯度下降法, 他能做什么, 相信百度一下你就都知道了, 所以下面进入正题

    solve
  • nginx编译升级1.14.1 stable 或1.15.6 主线版本,亲试可行

    HTTP/2 漏洞影响 1.9.5 和 1.15.5 之间的所有 nginx 版本,MP4 模块安全问题影响运行 nginx 1.0.7, 1.1.3 及更高版...

    zhangdd
  • 线上服务CPU100%问题快速定位实战

    功能问题,通过日志,单步调试相对比较好定位。 性能问题,例如线上服务器CPU100%,如何找到相关服务,如何定位问题代码,更考验技术人的功底。 58到家架构部,...

    架构师之路
  • Linux CPU负载过高排查方法

    问:如何定位是哪个服务进程导致CPU过载,哪个线程导致CPU过载,哪段代码导致CPU过载?

    用户5760343
  • 它的梯度去哪儿了?

    斯坦福 NLP 第 9 课介绍了 RNN 语言模型的一个问题:梯度消失现象。那么什么是梯度消失? 为什么 RNN 会出现梯度消失呢? 梯度消失问题需要引起重视吗...

    double

扫码关注云+社区

领取腾讯云代金券