深层神经网络参数调优(三) ——mini-batch梯度下降与指数加权平均

深层神经网络参数调优(三)——mini-batch梯度下降与指数加权平均

(原创内容,转载请注明来源,谢谢)

一、mini-batch梯度下降

1、概述

之前提到的梯度下降,每优化一次的w和b,都要用到全部的样本集,把其称为批量梯度下降(batch),这里提出一个与其相对应的概念,叫做mini梯度下降。

mini-batch的目的,也是为了获取最优化代价函数的情况下的w和b,其主要改进的问题在于:当样本集数量太大,如果每次遍历整个样本集才完成一次的更新w和b,那运行时间太长。

2、主要做法

mini-batch只遍历一次样本集,就会对w和b进行多次的优化。首先其将样本集拆成m’等分,这里假设样本总数5000万,每份样本子集1000个数据,即m‘=1000,则可以分成5000份的样本子集。

则mini-batch的思想是,把这5000份的子集,当作5000个训练集,每次遍历完5000条数据,就更新一次w和b。这样,遍历一次样本集,可以对w和b优化5000次,速度大大提升。

3、具体做法

具体做法,类似批量梯度下降,也是用类似梯度下降的方式,只不过其不是一次遍历整个数据集才优化w和b,而是每1000个数据优化1次w和b。

具体过程,是先取出1000个数据,前向计算出结果,再反向传导计算出代价函数对w和b的偏导数;接着计算出代价函数的和,然后取这1000次的平均值,进行优化;然后再拿出1000个数据,再次计算代价函数与导数,再次优化,重复进行直到全部数据集取完即可。

4、mini-batch优化原理

画出迭代次数和代价函数的图,就可以看到,批量梯度下降,必然是每次迭代的代价函数都会下降。

而mini-batch,由于其一次只是取了一部分数据的平均值,故相当于是局部最优值,因此优化的时候有可能会有误差,是迂回的逼近最优值点的。

5、进一步讨论mini-batch

1)随机梯度下降

当样本子集的数量取1时,则相当于每计算一个样本,就进行一次优化w和b,这样的做法称为随机梯度下降。因为单个样本的最优值到底是否是最优值具有随机性,且一般这种做法会先随机打散整个数据集,故这样做就称为随机梯度下降。

2)三种梯度下降的比较

批量梯度下降,具有直接逼近最优值的特点,在图像上也是一路逼近最优值。但是每次要遍历整个样本集,数据量很大的时候速度太慢。

随机梯度下降,是另外一个极端,每个样本的计算都要修改w和b,在图像上是非常曲折迂回的逼近最优值。这样做在数据量大的时候,效果也不错,但是缺点在于,其需要逐个迭代整个样本集,没法利用到numpy提供的向量化计算方式,其实效率也不是非常高。

mini-batch,相当于上面两种算法的折中算法,因此同时克服了上面两个问题。

6、注意事项

1)当样本数量太小,如少于2000,则不需要考虑mini-batch,直接用批量梯度下降即可,用其他的算法反而可能因为训练数量不足导致误差更大。

2)当需要用到mini-batch,可以设定每个样本子集的数量是2的n次幂,如64、128、256、512、1024等,这是考虑到计算机的内存,因此这样利用率更高,效果更好。即需要考虑到CPU/GPU的内存,才能最大化的利用mini-batch的优点。

二、指数加权平均

1、概述

指数加权平均(exponentiallyweighted averages),是一种计算平均值的方式,这个本身不是用来做神经网络的优化,但是其思想在后面学到其他的梯度下降方式的时候,具有重要作用,故单独拎出来学习。

2、基本公式

假设现在需要计算一年某些天温度的均值。假设第t天(1月1日算第1天)的实际温度是θt。

现在假设一个值0<=β<=1,则要计算1/(1-β)天的温度的平均值,可以使用下面公式: 假设t=0时v0=0;t不等于0时,vt=βvt-1+(1-β)θt,这里的vt即在t天内温度的均值。

3、性质

且该函数具有如下性质:

当β比较小,则1/(1-β)也很小,则相当于计算很短时间内的温度均值,则这个均值会有很大的波动;

当β很大,则1/(1-β)也很大,相当于计算很多天内的均值,曲线会很平缓,但是其由于是计算很多天的均值,故其对某一天的温度不那么敏感,相当于有个延迟,曲线会整体偏右。

4、理解公式

现假设t=100,即计算100天内的均值,且设β=0.9,则公式如下:

v100=0.9v99+0.1θ100 =0.1θ100+0.9v99 =0.1θ100+0.9(0.1 v98+0.9θ99) =…=0.1θ100 +0.1*0.9θ99+0.1*0.92θ98+…0.1*0.999θ1

根据上式,可以认为是两个函数的相乘:一个是每一天的温度θt,另一个是一个衰减曲线,越往前的天数,对于当天的影响,也越是逐步降低,如下图所示:

下面解释这个计算结果是1/(1-β)天的温度的平均值的原因:

这个是一个近似的计算。当β=0.9,根据上面的公式,相当于计算10天的均值。而根据上面v的一串的公式,0.910≈1/e≈1/3,即超过10天的数据已经比相当于原值的1/3还要小,这里近似忽略不计。

5、偏差纠正

上面vt=βvt-1+(1-β)θt的公式,存在一个问题。因为这个计算,t=0时是从0开始的。而考虑到t=1时,则第一项为0 了,而第二项,由于β通常会设置一个接近1的数,则第二项的值会非常小。因此,相当于第一个数值,会比其真实的值,小非常多(如β设置成0.9,则用上面公式计算出来的v1,相当于实际v1的0.1倍)。当计算的平均的数多了,由于慢慢的数值加回去了,故会逐渐消除了这个影响。

因此,偏差纠正的目的,在于纠正刚开始一些平均值计算不准确的问题。

为了解决这个问题,对公式进行一些小改动,如下:

vt=(βvt-1+(1-β)θt )/(1-βt)。

当t非常大,即计算的总数很大时,1-βt≈1,则还是原来的公式,上面已经说了计算范围大了以后用原公式没有影响;当t很小的时候,由于1-βt也很小,故除以这个数,可以起到放大的作用,起到偏差纠正的效果。

只有需要使用到前面的一些均值的时候,才会这么做。通常如果数据量很大,而又没那么关心前面的一些数值,也可以不用进行偏差纠正,减少计算量。

6、优点

指数加权平均,最大的有点在于节约存储,且速度较快,因为计算前t个值的平均值,其只关心t-1个值的平均值,以及第t个数的数值。这样,计算机的内存并不需要存太多的东西,且代码量大大减少,结构清晰。

三、总结

在样本量特别大的时候,mini-batch确实有其明显的优势,收敛速度快得多。另外指数加权平均,我还没学到后面的课程,但是我个人认为,这个快速计算平均值的特性,应该可以用到后面计算代价函数上。因为代价函数也是要计算m个数的损失函数的均值。

——written by linhxx 2018.02.06

原文发布于微信公众号 - 决胜机器学习(phpthinker)

原文发表时间:2018-02-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏java闲聊

JDK1.8 ArrayList 源码解析

当运行 ArrayList<Integer> list = new ArrayList<>() ; ,因为它没有指定初始容量,所以它调用的是它的无参构造

1192
来自专栏ml

朴素贝叶斯分类器(离散型)算法实现(一)

1. 贝叶斯定理:        (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A)   由(1)得    P(A|B) = P(B|...

3427
来自专栏学海无涯

Android开发之奇怪的Fragment

说起Android中的Fragment,在使用的时候稍加注意,就会发现存在以下两种: v4包中的兼容Fragment,android.support.v4.ap...

3155
来自专栏计算机视觉与深度学习基础

Leetcode 114 Flatten Binary Tree to Linked List

Given a binary tree, flatten it to a linked list in-place. For example, Given...

1938
来自专栏拭心的安卓进阶之路

Java 集合深入理解(12):古老的 Vector

今天刮台风,躲屋里看看 Vector ! 都说 Vector 是线程安全的 ArrayList,今天来根据源码看看是不是这么相...

2437
来自专栏开发与安全

算法:AOV网(Activity on Vextex Network)与拓扑排序

在一个表示工程的有向图中,用顶点表示活动,用弧表示活动之间的优先关系,这样的有向图为顶点表示活动的网,我们称之为AOV网(Activity on Vextex ...

2517
来自专栏聊聊技术

原 初学图论-Kahn拓扑排序算法(Kah

2878
来自专栏Hongten

ArrayList VS Vector(ArrayList和Vector的区别)_面试的时候经常出现

1682
来自专栏刘君君

JDK8的HashMap源码学习笔记

3008
来自专栏MelonTeam专栏

ArrayList源码完全分析

导语: 这里分析的ArrayList是使用的JDK1.8里面的类,AndroidSDK里面的ArrayList基本和这个一样。 分析的方式是逐个API进行解析 ...

4479

扫码关注云+社区