首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

深度学习笔记6:神经网络优化算法之从SGD到Adam

从前面的学习中,笔者带大家一起学会了如何手动搭建神经网络,以及神经网络的正则化等实用层面的内容。这些都使得我们能够更深入的理解神经网络的机制,而并不是初次接触深度学习就上手框架,虽然对外宣称神经网络是个黑箱机制,但是作为学习者我们极度有必要搞清楚算法在每个环节到底都干了些什么。

今天笔者需要讲的是深度学习的一个大的主题——优化算法。采用何种方式对损失函数进行迭代优化,这是机器学习的一大主题之一,当一个机器学习问题有了具体的模型和评估策略,所有的机器学习问题都可以形式化为一个最优化问题。这也是为什么我们说优化理论和凸优化算法等学科是机器学习一大支柱的原因所在。从纯数学的角度来看,所有的数学模型尽管形式不一,各有头面,但到最后几乎到可以归约为最优化问题。所以,有志于奋战在机器学习和深度学习领域的各位,学好最优化,责无旁贷啊。

要说机器学习和深度学习的优化算法,梯度下降必然是核心所在。神经网络发展至今,优化算法层出不穷,但大底是出不了梯度下降的框框架架。这一篇笔记,笔者就和大家一起学习和回顾深度学习中常用的优化算法。在前面手动搭建神经网络的代码实践中,我们对于损失函数的优化采用了一般的梯度下降法,所以本篇总结就从梯度下降法开始。

梯度下降法 Gradient Descent

想必大家对于梯度下降是很熟悉了,选择负梯度方向进行参数更新算是常规操作了。话不多说,对于多层神经网络如何执行梯度下降:

在上述代码中,我们传入含有权值和偏置的字典、梯度字段和更新的学习率作为参数,按照开头的公式编写权值更新代码,一个简单的多层网络的梯度下降算法就写出来了。

小批量梯度下降法 mini-batch Gradient Descent

在工业数据环境下,直接对大数据执行梯度下降法训练往往处理速度缓慢,这时候将训练集分割成小一点的子集进行训练就非常重要了。这个被分割成的小的子集就叫做 mini-batch,意为小批量。对每一个小批量同时执行梯度下降会大大提高训练效率。在实际利用代码实现的时候,小批量梯度下降算法通常包括两个步骤:充分打乱数据(shuffle)和分组组合数据(partition)。如下图所示。

shuffle

partition

具体代码实现为:

小批量梯度下降的实现思路非常清晰,先打乱数据在分组数据,需要注意的细节在于最后一个小批量所含的训练样本数,通常而言最后一个小批量会少于前面批量所含样本数。

随机梯度下降 Stochastic Gradient Descent

当小批量所含的训练样本数为 1 的时候,小批量梯度下降法就变成了随机梯度下降法(SGD)。SGD虽然以单个样本为训练单元训练速度会很快,但牺牲了向量化运算所带来的便利性,在较大数据集上效率并不高。

我们可以看一下梯度下降和随机梯度下降在实现上的差异:

所以,从本质上看,梯度下降法、小批量梯度下降法和随机梯度下降法,并没有区别。唯一的区别就在于它们执行一次训练过程所需要用到的训练样本数。梯度下降法用到的是全集训练数据,随机梯度下降则是单个样本数据,而小批量则是介于二者之间。

带动量的梯度下降法(momentum)

正如上图中看到的一样,我们假设梯度下降的横向为参数 W 的下降方向,而偏置 b 的下降方向为纵轴,我们总是希望在纵轴上的震荡幅度小一点,学习速度慢一点,而在横轴上学习速度快一点,无论是小批量梯度下降还是随机梯度下降,好像都不能避免这个问题。为了解决这个问题,带动量的梯度下降法来了。带动量的梯度下降考虑历史梯度的加权平均值作为速率进行优化。执行公式如下:

根据上述公式编写带动量的梯度下降法实现代码:

实现带动量的梯度下降的关键点有两个:一是动量是考虑历史梯度进行梯度下降的,二是这里的需要指定的超参数变成了两个:一个是学习率,一个是梯度加权参数。

Adam算法

Adam 全称为 Adaptive Moment Estimation,是在带动量的梯度下降法的基础上融合了一种称为 RMSprop(加速梯度下降)的算法而成的。相较于带动量的梯度下降法,无论是RMSprop 还是 Adam,其中的改进思路都在于如何让横轴上的学习更快以及让纵轴上的学习更慢。RMSprop 和 Adam 在带动量的梯度下降法的基础上,引入了平方梯度,并对速率进行了偏差纠正。具体计算公式如下:

实现代码如下:

除了以上这些算法,还有一些像 Adadelta 之类的算法我们没有提到,有需要了解的同学可以自行查找相关资料。最后用一个图来展示各种优化算法的效果:

参考资料:

https://www.coursera.org/learn/machine-learning

https://www.deeplearning.ai/

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180705B1P0RJ00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券