首先,简短介绍一下Batch Normalization,通常Batch Normalization更为大家所知,所以在此简要介绍BN来引入Instance Normalization。
引入BN层主要是为了解决"Internal Covariate Shift"问题,关于这个问题李宏毅老师有个视频讲解比较形象[4],可以参考。Batch Normalization主要是作用在batch上,对NHW做归一化,对小batchsize效果不好,添加了BN层能加快模型收敛,一定程度上还有的dropout的作用。
BN的基本思想其实相当直观:因为深层神经网络在做非线性变换前的激活输入值(就是那个x=WU+B,U是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
在BN论文中有下面这样一幅图,比较清楚的表示了BN具体是怎么操作的:
前三步就是对一个batch内的数据进行归一化,使得数据分布一致:沿着通道计算每个batch的均值,计算每个batch的方差,对 X X X做归一化。重点在第四步,**加入缩放和平移参数 γ , β \gamma,\beta γ,β **。这两个参数可以通过学习得到,增加这两个参数的主要目的是完成归一化之余,还要保留原来学习到的特征。
总结如下:
IN和BN最大的区别是,IN作用于单张图片,BN作用于一个batch。IN多适用于生成模型中,例如风格迁移。像风格迁移这类任务,每个像素点的信息都非常重要,BN就不适合这类任务。BN归一化考虑了一个batch中所有图片,这样会令每张图片中特有的细节丢失。IN对HW做归一化,同时保证了每个图像实例之间的独立。
论文中所给的公式如下:
总结如下:
def Instancenorm(x, gamma, beta):
# x_shape:[B, C, H, W]
results = 0.
eps = 1e-5
x_mean = np.mean(x, axis=(2, 3), keepdims=True)
x_var = np.var(x, axis=(2, 3), keepdims=True0)
x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
results = gamma * x_normalized + beta
return results
pytorch中使用BN和IN:
class IBNorm(nn.Module):
""" Combine Instance Norm and Batch Norm into One Layer
"""
def __init__(self, in_channels):
super(IBNorm, self).__init__()
in_channels = in_channels
self.bnorm_channels = int(in_channels / 2)
self.inorm_channels = in_channels - self.bnorm_channels
self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) # IN,多用于风格迁移
def forward(self, x):
bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
return torch.cat((bn_x, in_x), 1)
下图来自何凯明大神2018年的论文Group Normalization[3],可以说很直观了。
[4] 【深度学习李宏毅 】 Batch Normalization (中文)
[4] *深入理解Batch Normalization批标准化
[6] *BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm总结
GroupNorm、SwitchableNorm总结](https://blog.csdn.net/liuxiao214/article/details/81037416)