计算图的微积分:反向传播

介绍

后向传播是训练深度模型在计算上易于处理的关键算法。对于现代神经网络,相对于单纯的实现,它可以使梯度下降的训练速度提高一千万倍。这相当于模型训练时间是需要一个星期还是20万年的差距。

除了用于深度学习外,反向传播在许多其他领域是一个强大的计算工具,从天气预报到数值计算 - 只是在不同领域名称不同而已。事实上,这个算法在不同领域至少被重新改进了几十次(见Griewank(2010))。一般在独立应用程序中名称是“反向模式求导”。

从根本上讲,这是一种快速计算导数的技术。不仅在深度学习,而且在各种数值计算的情况下,反向传播是一个必不可少的窍门。

计算图

计算图是思考数学表达式的好方法。例如,考虑表达式e = (a + b )* (b + 1 )。这里包含三个操作:两个加法和一个乘法。为了帮助我们更好理解,我们引入两个中间变量c和d,这样每个函数的输出都有一个参数,我们现在有

c=a+b

d=b+1

e=c*d

为了生成一个计算图,我们把这些操作和输入变量一起做成节点。当一个节点的值是另一个节点的输入时,箭头从一个节点到另一个节点。

这些图形总是出现在计算机科学领域,特别是在谈论功能性程序时。它们与依赖图和调用图的概念很相似。它们也是深受欢迎的深度学习框架Theano的核心抽象。

我们可以通过将输入变量设置为特定值并通过图形计算节点来评估表达式。例如,我们设置a = 2和b = 1:

表达式评估为6。

计算图上的导数

如果想了解计算图中的导数,关键是了解边上的导数。如果a直接影响c,那么我们想知道它是如何影响c。如果a改变一点,c如何更改?我们称之为c的偏导数关于a。

为了评估这个图中的导数,我们需要加法律和乘法律:

下面的计算图,每条边都标注有导数。

如果我们想了解那些不直接连接的节点是如何相互影响的呢?我们来考虑一下e是怎样受到a的影响的。如果a以1的速度变化,c也以1的速度变化。反过来,c以1的变速导致e以2的速度改变。所以e以2的速度变化随着a的变化。

一般规则是将从一个节点到另一个节点的所有可能路径相加,将路径的每个边上的偏导数相乘。例如,要得到e关于b的导数,我们有:

这就是b如何影响e通过c、d。

这个一般的“求和路径”规则只是对链式法则不同的思考方式。

路径分解

“所有路径相加”的问题在于,可能的路径数目多到“爆炸”的。

在上图中,有三条路径X到Y,有三条路径Y到Z。如果我们想得到导数∂Z/∂X

通过所有路径相加,我们需要计算 9条路径的加法:

以上只有九条路径,但随着图形变得越来越复杂,路径的数量会呈指数增长。

他们不是单纯地把所有路径求和,而是要将它们分解:

这就是“正向模式求导”和“反向模式求导”的思想。它们是通过分解路径高效地计算求和。不是分开求解的方式,而是对于通过每个节点的路径求和,更有效地计算相同的和。实际上,这两种算法都只计算一次边!

正向模式的差异从图的输入开始,并向最后移动。在每个节点上,它汇总了所有输入路径。这些路径中的每一个路径代表了输入影响该节点的一种因素。把它们相加,就得到了节点受输入影响的总体方式,这就是导数。

虽然你可能没有用图的方式来思考,但是如果你对微积分类进行介绍的话,前向模式求导与你隐含的学习要做的非常相似。

另一方面,反向模式微积分开始于曲线图的输出并像起点移动。在每个节点上,它合并该节点发起的所有路径。

正向模式求导跟踪一个输入如何影响到每个节点。反向模式求导跟踪每个节点如何影响一个输出。也就是说,正向模式求导适用于操作符∂/∂X到每个节点,而反向模式求导则应用算子∂Z/∂到每个节点。

计算成功

在这一点上,你可能惊讶为什么有人会关心反向模式求导。相对于正向模式,它看起来像做一个奇怪的方式做同样的事情。这样做有什么优势吗?

我们再次看看最初的例子:

我们可以使用b的正向模式由b向上求导,这样就得到了每个点关于b的导数。

我们计算了∂e/∂b,即计算了输出对于某个输入的导数。

如果我们从e一步步向下做反向模式求导呢?这就得到了e对每个点的导数:

当我说反向模式导数给了我们关于每个节点e的导数时,实际指的是每个节点。我们得到两个∂e/∂a和∂e/∂b,e是关于这两个输入的演化。正向模式导数给我们的输出是单一输入的导数,但是反向模式导数给我们的是所有输入模式。

对于这个图,这只是两个变化因素,但想象一个具有一百万个输入和一个输出的函数,正向模式导数要求我们通过一百万次这样的图来获得演化。反向模式导数可以一举将它们全部拿下!一百万的加速变化是相当不错的!

当训练神经网络时,我们将成本(一个描述神经网络性能好坏的值)看作参数(描述网络行为的值)的函数。我们要计算相对于所有参数的成本导数,以便在梯度下降中使用。现在,神经网络中通常有数百万甚至上千万的参数。所以,反向模式分化在神经网络的背景下被称为反向传播,使我们大大加快了运算速度!

(是否在任何情况下,正向模式求导都更有意义呢?是的,有!在反向模式给出一个输出的分化对于所有输入的情况下,正向模式给出了所有输出对于一个输入,如果一个函数具有很多输出,则正向模式分化可以更快,更快)。

这不平凡吗?

当我第一次明白什么是反向传播时,我的反应是:“噢,这只是链式法则!我们怎么花了这么长时间才弄明白呢?“我不是唯一一个有这种反应的人。确实,如果你问“在前馈神经网络中是否有一个完美的方法来计算导数呢?”的答案并不困难。

但我认为这比看起来要困难得多。你看,在反向传播发明的时候,人们并不十分关注我们研究的前馈神经网络。求导是训练模型的正确方法,但这也不是很明显。一旦你意识到你可以快速计算导数,这些只是一个循环依赖。

更糟糕的是,将任何一个循环依赖关系作为部分事物来处理是非常容易的。用求导工具训练神经网络吗?当然,你只是陷在了本地极小值。显然,计算所有这些导数将会存在很高的代价。在我们没有立即开始列出可能不会的原因的情况下,只是因为我们知道这个方法有效而已。

事后看来这样做确实有好处。一旦你提出这个问题,最困难的工作已经完成了。

结论

求导比你想象的要廉价。这是从这篇文章中吸取的主要教训。事实上,它们并不廉价,我们头脑不灵活的人不得不一再重新发现这个事实。在深度学习中理解这一点非常重要。在其他领域也是非常有用的东西,如果不是常识的话也更是如此。

还有其他的获得的经验吗?我觉得有。

反向传播也是理解导数如何流经模型的有用透镜。这对于推理某些模型难以优化的原因非常有帮助。一个经典的例子就是在递归神经网络中消除梯度的问题。

最后,我说明一个广泛的算法教训,用以摆脱这些技术。反向传播和正向模式求导都使用了一对强大的技巧(线性化和动态规划)来更有效地计算导数。如果你真的了解这些技术,你可以使用它们有效地计算其他涉及导数的有趣表达式。我们将在下一篇博客文章中对此进行探讨。

这篇文章给出了一个非常抽象的反向传播的处理过程。我强烈建议阅读迈克尔·尼尔森有关于这部分的章节,以便于进一步精彩的讨论,更具体地集中在神经网络上。

致谢

感谢Greg CorradoJon ShlensSamy BengioAnelia Angelova花时间来校对这篇文章。

同时感谢Dario AmodeiMichael NielsenYoshua Bengio讨论解释反向传播的方法。最后感谢所有那些容忍我在讲座和研讨会系列中解释反向传播的人!

  1. 这可能有点像动态编程。那是因为它!

更多的帖子

本文的版权归 AlexanderTan 所有,如需转载请联系作者。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏智能算法

Yoshua Bengio等大神传授:26条深度学习经验

原文地址:http://www.marekrei.com/blog/26-things-i-learned-in-the-deep-learning-summe...

3776
来自专栏ATYUN订阅号

【业界】Facebook发布开源“Detectron”深度学习库,用于对象检测

AiTechYun 编辑:nanan 在刚刚过去的一月份(2018年1月),Facebook的研究机构Facebook AI Research(FAIR)发布了...

3154
来自专栏新智元

解决3D重建难题,伯克利大学根据单张平面彩图重建高精度3D结构

【新智元导读】3D重建是计算机视觉中的一个核心问题,应用于电影制作、视频游戏的内容制作、虚拟现实和增强现实、3D打印,等等。伯克利大学的研究团队提出根据单张彩色...

3486
来自专栏AI2ML人工智能to机器学习

一步一步走向锥规划 - QP

一般来说凸优化(Convex Optimization, CO)中最一般的是锥规划 (Cone Programming, CP) 问题, 前面我们介绍了点 最...

1271
来自专栏月色的自留地

从锅炉工到AI专家(4)

2177
来自专栏CSDN技术头条

Yoshua Bengio等大神传授:26条深度学习经验

【编者按】8月初的蒙特利尔深度学习暑期班,由Yoshua Bengio、 Leon Bottou等大神组成的讲师团奉献了10天精彩的讲座,剑桥大学自然语言处理与...

2116
来自专栏鸿的学习笔记

重新思考深度学习里的泛化

2017 ICLR提交的“UnderstandingDeep Learning required Rethinking Generalization”必然会打乱...

963
来自专栏AI研习社

通过简单代码回顾卷积块

我会努力定期去阅读机器学习和人工智能相关的论文。这也是能够持续跟进最新进展的唯一途径。作为一名计算机科学家,我经常在翻阅科学描述文本或者公式的数学符号时碰壁。我...

1454
来自专栏量子位

以为GAN只能“炮制假图”?它还有这7种另类用途

最近,AI方案设计师Alexandor Honchar在Medium网站上分享一篇文章。他认为生成对抗网络(GAN)目前在生成图像取得了巨大进展,生成的图像几乎...

1292
来自专栏Python中文社区

Python机器学习工具:Scikit-Learn介绍与实践

Scikit-learn 简介 官方的解释很简单: Machine Learning in Python, 用python来玩机器学习。 什么是机器学习 机...

2147

扫码关注云+社区

领取腾讯云代金券