Research on Batch Normalization

摘要:本文主要针对Batch Normalization技术,探究其对神经网络的作用,总结BN能够加速神经网络训练的原因,并对Internal covariate shift的情况进行探讨,同时探讨BN在Tensorflow中的实现。最后,简单介绍近年来对BN的改进,如Layer Norm和Group Norm。

一、什么是批归一化(Batch Normalization)?

Batch Normalization是一个为了克服神经网络层数增多导致难以训练的一项技术。它的主要优势在于可以缓解在深度神经网络中常见的梯度消失问题,从而使得深度神经网络的训练更加容易。Batch Norm通过对每一批(即一个batch)的训练数据进行归一化处理,来减少数据偏差对网络学习的影响。和传统意义上仅对输入数据做归一化处理不同的是,BN可以对每一层的输入进行归一化处理,以保证数据变为均值为0、标准差为1的分布。由于BN技术简单有效,在学术界和工业界多种应用中均取得了非常好的效果。

二、BN在神经网络的作用

2.1 Internal Covariate Shift的影响

要说明BN的原理,则不得不说Internal Covariate Shift对神经网络训练的影响,而BN则是为了降低它的影响,加速了神经网络的训练过程。

在统计机器学习中有一个常见的假设,即训练数据和测试数据是独立同分布的。如果这一假设不成立,就会在很大程度上影响模型的效果。一般的,Covariate Shift是指当训练数据和测试数据的分布不一致时,训练获得的模型泛化能力不好。

在Covariate Shift的基础上,对于神经网络的各层输出,在经过了层内操作后,各层输出分布就会与对应的输入信号分布不同,而且差异会随着网络深度增大而加大,造成深度神经网络难以训练的问题。由于输入分布的差异是在网路参数内部形成的,因此在最初提出BN的论文中,作者提出了一个Internal Covariate Shift的概念,表示在网络内部的参数和激活函数造成的各层输入分布不一致。进而造成梯度传播中的梯度弥散问题。

例如,考虑这样一个情况:网络输入x来着于一个均值为0、方差为1的标准分布;接着,输入数据经过一个Wx+b的仿射变换,于是输入数据均值变为b、方差变为D(W);如果再经过一个ReLU激活函数,则在这个基础上又会截断<0部分的数据,变成一个截断分布。而这些变化的结果又会进一步转化为下一层的输入,下一层的输出分布的表达也会更加复杂。因此,神经网络层数越多,收敛速度就会越慢,乃至无法训练。

2.2 BN如何解决Internal Covariate Shift

上述问题在BN提出之前,已成为神经网络精度的一个瓶颈。而BN之所以能够获得那么好的效果,主要原因也是它重点弱化了Internal Covariate Shift对网络的影响。

Batch Normalization变换的操作具体如下:

可以看到,BN主要分为两步,首先对输入数据按mini-batch进行一个归一化,变为均值为0、方差为1的分布。注意这里互联网上有些博客表示变换之后服从标准分布,这是不正确的;它应该依然服从输入数据的分布,只是均值和方差改变了,因此不一定是一个标准分布。

第二步则是在此基础上增加一个仿射变换,学习其超参数γ和β。这里主要原因是转化过后可能改变了输入的取值范围,因此需要在此基础上进行放缩和平移。我之前在DSSM模型上进行增加BN层实验的时候,曾将γ初值设为1、β初值设为0,但在之后的训练中这两个可学习参数几乎没变。由于这两个参数会直接改变下一层的输入,因此对网络的学习可能也是有影响的。

这两个操作进行结束后,输入变为一个均值为E(β)、方差为D(γ)的分布,而这两个超参数是根据梯度下降学习得到的,理论上经过足够多的训练,网络会学习到一个合适的分布,层之间的internal covariate shift将不再存在。

可以看到,这些操作都是可微分的,因此梯度的反向传播算法可以直接使用,在常见的支持自动求导的框架均可以方便的实现BN算法。

这里补充一点,均值和方差的统计量不只有按batch计算一种方式。在一些业务场景下,分布不均的训练数据按batch求出的E(x)和D(x)与全局分布相差甚远,典型例子就是计算广告的CTR预估系统。这时可以对一段时间全局的所有训练实例计算出均值和方差,只是这种方法一般这个计算量太大,所以一般才用按batch计算的简化方式。

2.3 Batch Norm的主要效果

BN虽然只用了一个简单的变换,没有太多的理论依据,却依然被广大用户欢迎。究其原因,主要还是效果好,笔者简单整理BN带来的主要效果如下:

  1. 极大提升了训练速度,收敛过程大大加快;
  2. 增加分类效果,一种解释是这是类似于Dropout的一种防止过拟合的正则化表达方式,所以不用Dropout也能达到相当的效果;
  3. 简化调参过程,对于初始化要求没那么高,而且可以使用大的学习率等。

三、Batch Norm的实现

3.1 BN训练与预测阶段的异同

上一节已经描述了BN在训练阶段的步骤,但预测(inference)阶段和训练阶段,尤其在线上预测等情况下,每个输入只有一个实例,显然没有办法直接求得均值和方差。为了解决这个问题,作者的解决方法是预测时使用的均值和方差,其实也是根据训练集计算得到的。在训练过程中,我们可以记录每一个batch的均值和方差,对这N个均值和方差求其数学期望即可得到全局的方差。具体计算过程为:

接下来,可以在预测时采用下面的式子进行计算:

3.2 Batch Norm在Tensorflow中的实现

TF官方定义了进行BN的API,根据其官方文档,分别为:tf.nn.batch_ normalization、tf.keras.layers.BatchNormalization(未来将替换掉tf.layers.batch_ normalization)和tf.contrib.layers.batch_norm。这三种实现都是基于BN的论文,其中主要的区别在于tf.nn.batch_normalization只进行了一个BN的计算,需要传入均值、方差以及γ和β两个超参数,而另外两者则通过一个类定义了网络的一个“Batch Norm层”,封装程度更高。尽管在程序里实例化出的BN层也是调用了tf.nn.batch_normalization进行计算,但是不再需要用户管理超参数初始化等繁杂事情,对用户更加友好。

一般来说,layers.batch_normalization可以直接被调用,但具体的BN计算还是交给了tf.nn.batch_normalization:

可以看到,这个操作只是进行一个典型的BN运算,所有变量都是被外部传进来,没有进行特殊的处理。

封装更好的tf.keras.layers.BatchNormalization定义如下:

可以看到,传入的主要参数有trainable(是否训练)、epsilon和各参数的初始化类型。后面各个函数的工程化比较复杂,篇幅限制不再赘述。

四、Batch Norm的兄弟姐妹

归一化层,除了Batch Normalization(2015年),目前主要有这几个方法:Layer Normalization(2016年)、Instance Normalization(2017年)、Group Normalization(2018年)、Switchable Normalization(2018年)。若将输入的图像shape记为[N, C, H, W],这几个方法主要的区别就是在:

  1. BatchNorm是在batch上,对NHW做归一化,对小batch size效果不好;
  2. LayerNorm在通道方向上,对CHW归一化,主要对RNN作用明显;
  3. InstanceNorm在图像像素上,对HW做归一化,用在风格化迁移;
  4. GroupNorm将channel分组,然后再做归一化;
  5. SwitchableNorm是将BN、LN、IN结合,赋予权重,让网络自己去学习归一化层应该使用什么方法。

(adsbygoogle = window.adsbygoogle || []).push({});

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

扫码关注云+社区

领取腾讯云代金券