前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch BatchNorm参数详解,计算过程

pytorch BatchNorm参数详解,计算过程

作者头像
全栈程序员站长
发布2022-09-01 15:17:22
1.1K0
发布2022-09-01 15:17:22
举报
文章被收录于专栏:全栈程序员必看

大家好,又见面了,我是你们的朋友全栈君。

目录

说明

BatchNorm1d参数

num_features

eps

momentum

affine

track_running_stats

BatchNorm1d训练时前向传播

BatchNorm1d评估时前向传播

总结


说明

网络训练时和网络评估时,BatchNorm模块的计算方式不同。如果一个网络里包含了BatchNorm,则在训练时需要先调用train(),使网络里的BatchNorm模块的training=True(默认是True),在网络评估时,需要先调用eval(),使网络里的BatchNorm模块的training=False。

BatchNorm1d参数

代码语言:javascript
复制
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

num_features

输入维度是(N, C, L)时,num_features应该取C;这里N是batch size,C是数据的channel,L是数据长度。

输入维度是(N, L)时,num_features应该取L;这里N是batch size,L是数据长度,这时可以认为每条数据只有一个channel,省略了C

eps

对输入数据进行归一化时加在分母上,防止除零,详情见下文。

momentum

更新全局均值running_mean和方差running_var时使用该值进行平滑,详情见下文。

affine

设为True时,BatchNorm层才会学习参数

,否则不包含这两个变量,变量名是weightbias,详情见下文。

track_running_stats

设为True时,BatchNorm层会统计全局均值running_mean和方差running_var,详情见下文。

BatchNorm1d训练时前向传播

首先对输入batch求和,并用这两个结果把batch归一化,使其均值为0,方差为1。归一化公式用到了eps(),即。如下输入内容,shape是(3, 4),即batch_size=3,此时num_features需要传入4。

如果==True,则使用momentum更新模块内部的(初值是[0., 0., 0., 0.])和(初值是[1., 1., 1., 1.]),更新公式是,其中代表更新后的和,表示更新前的和,表示当前batch的均值和无偏样本方差。

如果==False,则BatchNorm中不含有和两个变量。

如果==True,则对归一化后的batch进行仿射变换,即乘以模块内部的(初值是[1., 1., 1., 1.])然后加上模块内部的(初值是[0., 0., 0., 0.]),这两个变量会在反向传播时得到更新。

如果==False,则BatchNorm中不含有和两个变量,什么都都不做。

BatchNorm1d评估时前向传播

  1. 如果track_running_stats==True,则对batch进行归一化,公式为

,注意这里的均值和方差是running_meanrunning_var,在网络训练时统计出来的全局均值和无偏样本方差。

  1. 如果track_running_stats==False,则对batch进行归一化,公式为

,注意这里的均值和方差是batch自己的mean和var,此时BatchNorm里不含有running_meanrunning_var。注意此时使用的是无偏样本方差(和训练时不同),因此如果batch_size=1,会使分母为0,就报错了。

  1. 如果affine==True,则对归一化后的batch进行放射变换,即乘以模块内部的weight然后加上模块内部的bias,这两个变量都是网络训练时学习到的。
  2. 如果affine==False,则BatchNorm中不含有weightbias两个变量,什么都不做。

总结

在使用batchNorm时,通常只需要指定num_features就可以了。网络训练前调用train(),训练时BatchNorm模块会统计全局running_meanrunning_var,学习weightbias,即文献中的

。网络评估前调用eval(),评估时,对传入的batch,使用统计的全局running_meanrunning_var对batch进行归一化,然后使用学习到的weightbias进行仿射变换。

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/141980.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022年5月2,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 说明
  • BatchNorm1d参数
    • num_features
      • eps
        • momentum
          • affine
            • track_running_stats
            • BatchNorm1d训练时前向传播
            • BatchNorm1d评估时前向传播
            • 总结
            相关产品与服务
            批量计算
            批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档