探索LSTM:基本概念到内部结构

安妮 若朴 编译整理 量子位 出品 | 公众号 QbitAI

作者简介:Edwin Chen,在MIT研究数学/语言学,在微软研究院研究语音识别,在Clarium研究量化交易,在Twitter研究广告,在Google研究机器学习。

在这篇文章中,作者首先介绍了一下神经网络、RNN和LSTM的基本概念,然后举例对比了三种网络的性能,并且进一步讲解了LSTM。

LSTM是神经网络一个相当简单的延伸扩展,而且在过去几年里取得了很多惊人成就。我第一次了解到LSTM时,简直有点目瞪口呆。不知道你能不能从下图中发现LSTM之美。

OK,咱们这就开始切入正题。先简单介绍一下神经网络和LSTM。

神经网络

假设我们有一个来自电影的图像序列,然后想用一个活动来标记每张图像。例如这是一场打斗么?角色在交谈?角色在吃东西?

应该怎么做?

一种方法是忽略图像的顺序属性,把每张图片单独考虑,构建一个单张图片的分类器。例如,给出足够的图片和标签:

  • 算法首先学习检测形状、边缘等低级模式
  • 在更多数据的驱动下,算法学会将低级模式组合成复杂形态,例如一个椭圆形上面有两个圆形加一个三角形可以被认为是人脸
  • 如果还有更多数据,算法会学到将这些高级模式映射到活动本身,例如有嘴、有牛排、有叉子的场景可能是吃饭

这就是一个深度神经网络:得到一个图像输入,返回一个活动输出。

神经网络的数学原理如图所示:

用RNN记住信息

忽略图像的顺序可以算是初步的机器学习。更进一步,如果是一幕海滩的场景,我们应该在后续帧中强化海滩相关的标记:如果有人在水中,大概可以标记为游泳;而闭眼的场景,可能是在晒太阳。

同样,如果场景是一个超市,有人手拿培根,应该被标记为购物,而不是做饭。

我们想做的事情,是让模型追踪世界的状态。

  1. 看到每个图像后,模型输出一个标签,并更新其对世界的知识。例如,模型能学会自动发现和追踪信息,例如位置、时间和电影进度等。重要的是,模型应该能自动发现有用的信息。
  2. 对于给定的新图像,模型应该融合收集而来的知识,从而更好的工作。

这样就成了一个循环神经网络RNN。除了简单的接收一张图片返回一个活动标记之外,RNN会通过给信息分配不同的权重,从而在内部保留了对世界的记忆,以便更好的执行分类任务。

RNN的数学原理如图所示:

通过LSTM实现长期记忆

模型如何更新对世界的认知?到目前为止,还没有任何规则限制,所以模型的认知可能非常混乱。这一帧模型认为人物身处美国,下一帧如果出现了寿司,模型可能认为人物身处日本……

这种混乱的背后,是信息的快速变换和消失,模型难以保持长期记忆。所以我们需要让网络学习如何更新信息。方法如下:

  1. 增加遗忘机制。例如当一个场景结束是,模型应该重置场景的相关信息,例如位置、时间等。而一个角色死亡,模型也应该记住这一点。所以,我们希望模型学会一个独立的忘记/记忆机制,当有新的输入时,模型应该知道哪些信息应该丢掉。
  2. 增加保存机制。当模型看到一副新图的时候,需要学会其中是否有值得使用和保存的信息。
  3. 所以当有一个新的输入时,模型首先忘掉哪些用不上的长期记忆信息,然后学习新输入有什么值得使用的信息,然后存入长期记忆中。
  4. 把长期记忆聚焦到工作记忆中。最后,模型需要学会长期记忆的哪些部分立即能派上用场。不要一直使用完整的长期记忆,而要知道哪些部分是重点。

这样就成了一个长短期记忆网络(LSTM)。

RNN会以相当不受控制的方式在每个时间步长内重写自己的记忆。而LSTM则会以非常精确的方式改变记忆,应用专门的学习机制来记住、更新、聚焦于信息。这有助于在更长的时期内跟踪信息。

LSTM的数学原理如图所示:

卡比兽

神奇宝贝中的卡比兽

我们不妨拿《神奇宝贝》中的卡比兽对比下不同类别的神经网络。

神经网络

当我们输入一张卡比兽被喷水的图片时,神经网络会认出卡比兽和水,推断出卡比兽有60%的概率在洗澡,30%的概率在喝水,10%的概率被攻击。

循环神经网络(RNN)

在隐藏状态(Hidden State)为“战斗场景开始”的情况下输入神奇宝贝喷水进攻图,RNN能够根据“嘴中喷水”的场景推测图一神奇宝贝是在进攻的概率为85%。之后我们在记忆为“在战斗、敌人在攻击和敌人是水性攻击”三个条件下输入图片二,RNN就会分析出“卡比兽被攻击”是概率最大的情况。

LSTM

在长期记忆(Long-Term Memory)为“卡比兽喜欢吃竹子”、“每个战斗回合为一分钟”和“敌人在屏幕中央”,工作记忆(Working Memory)为“每个战斗回合为一分钟”“敌人在屏幕中央”的情况下,输入卡比兽被喷水的图片,LSTM会选择性处理一些信息。它选择性记忆了卡比兽的痛苦的表情,忘掉了“屏幕中央的是敌人”这条信息,得出卡比兽被攻击的可能性最大。

学会编码

有一种字符级的LSTM模型,可以通过输入的字符级序列来预测下一个可能出现的字符。我将用这种模型向大家展示LSTM的用法。

虽然这个方法看起来不成熟,但不得不说字符级的模型使非常实用,个人觉得比单词级模型还要实用一些。比如下面这两个例子:

1. 假设有一种代码自动填充器足够智能,允许手机端编程

理论上讲,LSTM可以跟踪当前所用方法的返回类型,更好地建议返回哪个变量;也可以通过返回错误类型告诉你程序是否有bug。

2. 像机器翻译这种自然语言处理程序通常很难处理生僻术语

怎样才能把之前从未见过的形容词转换成相应的副词?即使知道一条推文是什么意思,但怎样为它生成标签?字符级模型就可以帮你处理这些新出现的术语,不过这也是另外一个领域研究的事情了。

所以在一开始,我用亚马逊AWS弹性计算云EC2的p2.xlarge在Apache Commons Lang代码库训练了三层LSTM,几个小时后生成了这个程序:

虽然这段代码并不完美,但已经比我认识的很多数据专家编写的代码优秀了。从这里我们可以看出LSTM已经学会很多有趣且正确的编码行为:

可以构建class:优先放许可证,之后是程序包和导入包,再之后放注释和类别定义,最后是变量和方法。它也知道如何创造方法:需要遵循正确的描述顺序,查看装饰器是否处于正确的位置,以适当的语句返回无类型指针。重要的是,这些行为还跨越了大量的代码。

可以跟踪子程序和嵌套级别:如果语句循环总是被关闭的话,缩进处理是一个不错的选择。

它甚至知道如何创建测试。

此模型到底是怎样做到上面这些功能的呢?我们可以看几个隐藏状态。

这是一个神经元,看起来它似乎在追踪代码的缩进级别。在神经元以字符为输入进行读取,例如试图生成下一个字符的时候,每个字符都根据神经元的状态被标记了颜色,红色表示负值,蓝色表示正值。

这里有一个可以计算两个标签距离的神经元:

还有一个在TensorFlow代码库中生成的不同3层LSTM的有趣输出结果:

研究LSTM的内部结构

上面我们了解了几个隐藏状态的例子,不妨再聊得深入一些。我在考虑LSTM cell和它们的其他记忆机制。或许它们之间也存在令人惊叹的关系。

计数

为了探究这个问题,我们需要先教LSTM学会计数,所以我生成了下面这个序列:

aaaaaXbbbbb

这串序列中,在N个a之后跟着一个定界符X,之后又跟着N个字符b。在这里,1<=N<=10。我们用此序列训练带有10个隐藏神经元的单层LSTM。

正如预期的那样,LSTM在其训练范围内表现良好,甚至可以生成一些超过了训练范围的东西。

aaaaaaaaaaaaaaaXbbbbbbbbbbbbbbb aaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbb aaaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbbb aaaaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbbbb aaaaaaaaaaaaaaaaaaaXbbbbbbbbbbbbbbbbbb

我们期望找到一个隐藏的状态神经元计算a的数量:

为此我还专门构建了一个小型网页应用,它不仅可以在刚刚的基础上计算a的数量,还能计算b的数量。

此时cell表现很相似:

还有一件有趣的事情,工作记忆看起来像一个长期记忆的增强版,这在一般情况下是正常的吗?

答案是肯定的,这也和我们期望的完全一样。因为长期记忆被双曲正切激活函数限制了输出内容。下面是10个cell状态节点】的总览,我们可以看到很多代表接近0值的浅色cell。

相比之下,这10个工作记忆的神经元看起来更专注,1号、3号、5号和7号在序列的前半部分看起来都是0。

那我们再看看2号神经元,这里给大家展示一些备用记忆和输入门(Input Gate)。它们在神经元的每半部分都是稳定的——就像神经元在每个步骤都在计算a+=1或b+=1一样。

最终,我们得到了所有神经元的内部结构:

如果你也想为不同的神经元计数,可以看看这个可视化工具。

可视化工具链接: http://blog.echen.me/lstm-explorer/#/network?file=counter

伯爵

还记得美国公共广播协会(PBS)制作播出的儿童教育电视节目《芝麻街》,里面有一个魔方吸血鬼样子设计的玩偶,叫伯爵。他喜欢数数,对计算机的兴趣没有边界。我就把这一部分的标题命名为伯爵好了。

《芝麻街》里的计算机爱好者“数数伯爵”

现在我们看一个稍微复杂一点的计数器,这一次我生成了序列化的表单是:

aaXaXaaYbbbbb

上面这串序列的特点是N个a和X任意交叉排列,再加入一个定界符Y,最后加入N个b。LSTM仍需计数a的数目,这次也同样需要忽略X。

完整的LSTM链接: http://blog.echen.me/lstm-explorer/#/network?file=selective_counter

我们希望得到遇到X时输入门是0的计数神经元

上面就是20号神经元的cell状态。在还没有到达定界符Y时它会一直增大,之后一直递减直到序列的末尾——就像它计算的是num_bs_left_to_print变量,根据a的增量和b的递减不断变化。

它的输入门确实忽略了X:

有趣的是,备用存储器完全激活了不相关的定界符X,所以我们还是需要一个输入门。(如果输入门不是架构的一部分,想必神经网络将学会用其他方式学会忽略X)

那我们继续看10号神经元。

这个神经元很有意思,因为它只有在读取定界符Y时才可以激活,但它仍然试图编码目前在序列中看到的a。这很难从图中看出,但当读取到Y属于有相同数量的a的序列时,所有的元胞状态都是几乎相同的。可以看到,序列中a越少,Y的颜色越浅。

记住状态

接下来,我想看看LSTM是如何记住元胞状态的。我在再次生成一些序列:

AxxxxxxYa BxxxxxxYb

在这个序列中,A或B后面可以接1-10个x,之后接入定界符Y,最后以开头字母的小写结尾。这种神经网络需要记住序列是否为一个A或B序列状态。

我们希望找到一个神经元,当记住这个序列以A开始时触发,另一个神经元会在记住它以B开始时触发。我们也确实找到了。

例如,这里有一个神经元A,当它读到一个A时,它会激活,并记住它,直到需要生成最终的字符为止。注意,输入门忽略了中间所有的x字符。

这里还有一个副本B:

有趣的是,直到网络读取Y定界符前,A和B的状态信息都是不需要的,但隐藏的状态还是会在所有中间输入中触发。这看起来有点“低效”,但也许是因为这个神经元还承担了计算x数量的任务。

复制任务

最后,我们看看如何让LSTM复制信息。

对于这个训练任务就,我训练了两层的LSTM序列:

baaXbaa abcXabc

也就是说,这是一个由a、b、c组成的三个字符的子序列,后接定界符X,最后接一个相同的序列。

我并不确定复制神经元是什么样,所以为了弄清存储最初子序列部分的神经元,我查看了它们在读取定界符X时的隐藏态。因为网络需要对初始子序列进行编码,因此它的状态应该根据所学内容的不同显示不同的格式。

以下面这张图表为例,它绘制了5号神经元在读取定界符X时的隐藏状态。神经元显然能区分序列是不是以“c”开始。

再举一个例子,下图是20号神经元在读取X时的隐藏状态,它好像是把从b开始的序列都挑了出来。

如果观察20号神经元的cell状态,你会发现它几乎可以自己捕捉全部3个字符的子序列。在只有一个维度的情况下能做到这样,非常了不起。

这是20号神经元的cell在整个序列中的隐藏状态。注意它的隐藏状态已经在全部初始子序列中被切断。

然而,如果我们仔细观察会发现,每当下一个字符是b时,神经元就会被触发。因此与其说这个神经元代表的是一个从b开始的序列,不如说是下一个字符是b的预测器。

据我所知,这种模式似乎在整个网络中都保持着——所有的神经元都在预测下一个字符,而不是记住特定位置的字符。例如,5号神经元似乎是“下一个字符是c”的预测器。

我不确定这是否是LSTM在复制信息时学会的默认行为,或者是其他的复制机制也会这样做。

状态和门

为了真正了解不同LSTM中不同的状态和门的作用,我们换个角度重新讲讲上一部分。

cell状态和隐藏状态

我们最初把cell状态描述为一种长期记忆,而隐藏状态是一种提取和聚焦这些记忆的方法。

所以当记忆与当前无关时,我们期望隐藏状态可以关闭——前面讲的序列复制神经元就是这样。

遗忘门

遗忘门(Forget Gate)会丢弃cell状态的信息(0意味着完全忘记,1意味着完全记住),所以我们期望当它需要记住一些确切的东西时,它可以被完全激活;当不再需要这些信息时,它可以被再次关闭。

这就是我们看到的这个A记忆神经元:遗忘门在被触发时记住当它经过x时的A状态,当它准备生成最后的a时将关闭。

输入门(保存门)

输入门(Input Gate,我以前叫它“保存门”)决定了是否从新输入中保存信息。因此它需要在遇到无用信息时关闭。

这也就是选择性计数神经元所做的:计算a和b的值,但忽略不相关的x。

令人惊讶的是,我们的LSTM方程中没有任何地方指定了输入、遗忘和输出门的工作方式,神经网络会自学什么是最好的。

扩展

现在我们探讨一下LSTM是如何出现的。

首先,很多我们要解决的问题都是按时序或者一定顺序排列的,我们可以将过去的经验融合到我们的模型中来。但是我们已经知道神经网络的隐藏层编码着重要的信息,所以我们为何不把这些隐藏层作为从一个时间步传递到另一个时间步的记忆呢?因此,出现了循环神经网络(RNN)

人类不会轻易地相信一些信息——当我们读到一篇关于政治的文章时,我们不会立刻相信我们读到的并把它当成自己的人生信仰。我们会选择哪些信息值得记住,哪些信息需要丢弃,哪些信息可以在下次阅读时提供一些决策依据。也就是说,我们想要学习如何收集、修正和应用信息。那我们为什么不让神经网络来学习这些呢?所以就有了LSTM

任务都完成了,现在我们可以对它进行修改。

你可能会觉得LSTM区分长期记忆和工作记忆时傻傻的:为什么不合二为一?又或许你发现了独立记忆门和存储门有点多余。现在又有人提出一种新的LSTM变种,我们称之为GRU(Gated Recurrent Units)

想了解更多GRU专业知识?这里有一份论文: https://arxiv.org/abs/1412.3555

当决定记忆存储和集中你哪些信息时,我们不能单独依赖工作记忆,那为什么不能用长期记忆呢?所以我们又发现了Peephole LSTM。

Peephole LSTM Paper(PDF): http://machinelearning.wustl.edu/mlpapers/paper_files/GersSS02.pdf

实现神经网络伟大复兴

让我们看看最后这个例子,我双层LSTM在特朗普的推特上训练。尽管这个数据集很“大”,它仍然学会了很多的模式。

比如,这是一个记录话题标签、URL和@的神经元。

这里还有一个专有名词检测器(注意:它不是只在遇到大写字母时激活):

这是一个助动词+“to be”检测器(包括“will be”、“I’ve always been”、“has never been”等形式):

这是一个引用属性:

还有一个MAGA(多智能体遗传算法)和大写的神经元:

下面是LSTM发布特朗普的推文声明——其实只有一个是真的,不妨猜猜看:

川普数据集在这里: https://www.kaggle.com/benhamner/clinton-trump-tweets

总结

这篇文章写到这里,已经基本接近尾声了。在结束前,我们再回顾一下你收获了什么:

你需要记住的是:

虽然本文的篇幅较长,且讲解内容较多,但LSTM仍然是一个没有止境的研究领域。如果你恰巧对LSTM感兴趣,那就钻进去好好研究吧。

原文发布于微信公众号 - 量子位(QbitAI)

原文发表时间:2017-06-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI科技评论

业界 | 如期而至!谷歌开源 BERT 模型源代码

AI 科技评论按:自上个月谷歌公开 BERT 模型以来,BERT 模型以其双向、深层等特点,成功在 11 项 NLP 任务中取得 state of the ar...

17440
来自专栏人工智能头条

史上最全TensorFlow学习资源汇总

1.9K50
来自专栏AI科技评论

循环神经网络的增强方法:注意力机制以及更多

AI 科技评论按:一提起 RNN,我们最容易想到的实现就是 LSTM + attention。LSTM 作为经典的网络结构可谓是老当益壮,而另一方面注意力机制过...

20910
来自专栏ATYUN订阅号

使用LSTM预测比特币价格

本文以“时间序列预测的LSTM神经网络”这篇文章为基础。如果没有阅读,我强烈建议你读一读。 考虑到近期对比特币货币的泡沫的讨论,我写了这篇文章,主要是为了预测比...

39970
来自专栏AI研习社

博客 | 如期而至!谷歌开源 BERT 模型源代码

雷锋网 AI 科技评论按:自上个月谷歌公开 BERT 模型以来,BERT 模型以其双向、深层等特点,成功在 11 项 NLP 任务中取得 state of th...

15630
来自专栏AI研习社

深度学习下的医学图像分析(三)

本文将从卷积神经网络的角度讨论深度学习。在本文中,我们将使用Keras和Theano,重点关注深度学习的基本原理。本文将展示两个例子——其中一个例子使用Kera...

44040
来自专栏机器之心

教程 | 如何为时间序列数据优化K-均值聚类速度?

355100
来自专栏CDA数据分析师

数据挖掘150道试题 敢不敢来自测!

单选题 1. 某超市研究销售纪录数据后发现,买啤酒的人很大概率也会购买尿布,这种属于数据挖掘的哪类问题?(A) A. 关联规则发现 B. 聚类 C. 分类 ...

27990
来自专栏华章科技

数据挖掘150道试题,测测你的专业能力过关吗

2.以下两种描述分别对应哪两种对分类算法的评价标准?(A) (a)警察抓小偷,描述警察抓的人中有多少个是小偷的标准。 (b)描述有多少比例的小偷给警察抓了的标准...

10510
来自专栏YoungGy

MMD_2b_NearestNeighborLearning

main question: how to efficiently train (build model and find model parameters...

19170

扫码关注云+社区

领取腾讯云代金券