前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >推荐一个神级工具:能缓解梯度消失问题&提升训练速度

推荐一个神级工具:能缓解梯度消失问题&提升训练速度

作者头像
double
发布2019-09-18 15:40:24
8150
发布2019-09-18 15:40:24
举报
文章被收录于专栏:算法channel算法channel

推荐阅读:

完整图解:特征工程提速40倍的4个方法

深度学习的一个本质问题

  • Internal Covariate Shift
  • 什么是BN

深度学习的一个本质问题

深度神经网络一直以来就有一个特点:随着网络加深,模型会越来越难以训练。所以深度学习有一个非常本质性的问题:为什么随着网络加深,训练会越来越困难?为了解决这个问题,学界业界也一直在尝试各种方法。

sigmoid作为激活函数一个最大的问题会引起梯度消失现象,这使得神经网络难以更新权重。使用ReLu激活函数可以有效的缓解这一问题。

对神经网络使用正则化方法也能对这个问题有所帮助,使用dropout来对神经网络进行简化,可以有效缓解神经网络的过拟合问题,对于深度网络的训练也有一定的帮助。ResNet使用残差块和skip connection来解决这个问题,使得深度加深时网络仍有较好的表现力。

BN本质上也是一种解决深度神经网络难以训练问题的方法。

Internal Covariate Shift

机器学习的一个重要假设就是IID(Independent Identically Distributed)假设,即独立同分布假设。所谓独立同分布,就是指训练数据和测试数据是近似于同分布的,如若不然,机器学习模型就会很难有较好的泛化性能。

一个重要的观点就是深度神经网络在训练过程中每一层的输入并不满足独立同分布假设,当叠加的网络层每一层的输入分布都发生变化时,这使得神经网络训练难以收敛。这种神经网络隐藏层输入分布的不断变化的现象就叫Internal Covariate Shift(ICS)。ICS问题正是导致深度神经网络难以训练的重要原因之一。

什么是BN

一直在做铺垫,还没说到底什么是BN。Batch Normalization,简称BN,翻译过来就是批标准化,因为这个Normalization是建立在Mini-Batch SGD的基础之上的。BN是针对ICS问题而提出的一种解决方案。一句话来说,BN就是使得深度神经网络训练过程中每一层网络输入都保持相同分布。

既然ICS问题表明神经网络隐藏层输入分布老是不断变化,我们能否让每个隐藏层输入分布稳定下来?通常来说,数据标准化是将数据喂给机器学习模型之前一项重要的数据预处理技术,数据标准化也即将数据分布变换成均值为0,方差为1的标准正态分布,所以也叫0-1标准化。图像处理领域的数据标准化也叫白化(whiten),当然,白化方法除了0-1标准化之外,还包括极大极小标准化方法。

所以一个很关键的联想就是能否将这种白化操作推广到神经网络的每一个隐藏层?答案当然是可以的。

ICS问题导致深度神经网络训练难以收敛,隐藏层输入分布逐渐向非线性激活函数取值区间的两端靠近,比如说sigmoid函数的两端就是最大正值或者最小负值。这里说一下梯度饱和和梯度敏感的概念。当取值位于sigmoid函数的两端时,即sigmoid取值接近0或1时,梯度接近于0,这时候就位于梯度饱和区,也就是容易产生梯度消失的区域,相应的梯度敏感就是梯度计算远大于0,神经网络反向传播时每次都能使权重得到很好的更新。

当梯度逐渐向这两个区域靠近时,就会产生梯度爆炸或者梯度消失问题,这也是深度神经网络难以训练的根本原因。BN将白化操作应用到每一个隐藏层,对每个隐藏层输入分布进行标准化变换,把每层的输入分布都强行拉回到均值为0方差为1的标准正态分布。这样一来,上一层的激活输出值(即当前层的激活输入值)就会落在非线性函数对输入的梯度敏感区,远离了原先的梯度饱和区,神经网络权重易于更新,训练速度相应加快。

那么具体到实际应用时,BN操作应该放在哪里?以一个全连接网络为例:

可以看到,BN操作是对每一个隐藏层的激活输出做标准化,即BN层位于隐藏层之后。对于Mini-Batch SGD来说,一次训练包含了m个样本,具体的BN变换就是执行以下公式的过程:

这里有个问题,就是在标准化之后为什么又做了个scale and shift的变换。从作者在论文中的表述看,认为每一层都做BN之后可能会导致网络的表征能力下降,所以这里增加两个调节参数(scale和shift),对变换之后的结果进行反变换,弥补网络的表征能力。

BN不仅原理上说的通,但关键还是效果好。BN大大缓解了梯度消失问题,提升了训练速度,模型准确率也得到提升,另外BN还有轻微的正则化效果。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-09-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 程序员郭震zhenguo 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档