前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现

作者头像
朴素人工智能
发布2020-06-15 18:27:37
5.6K0
发布2020-06-15 18:27:37
举报
文章被收录于专栏:朴素人工智能

来自 | 知乎 地址 | https://zhuanlan.zhihu.com/p/68748778

作者 | Nicolas

编辑 | 朴素人工智能

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。

今天瓦砾准备介绍一下EMA以及它的Pytorch实现代码。

EMA的定义

在深度学习的优化中的EMA

EMA的偏差修正

EMA为什么有效

PyTorch实现

瓦砾看了网上的一些实现,使用起来都不是特别方便,所以自己写了一个。

代码语言:javascript
复制
class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

# 初始化
ema = EMA(model, 0.999)
ema.register()

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    ema.update()

# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():
    ema.apply_shadow()
    # evaluate
    ema.restore()

References

  1. 机器学习模型性能提升技巧: 指数加权平均(EMA), https://blog.csdn.net/mikelkl/article/details/85227053
  2. Exponential Weighted Average for Deep Neutal Networks, https://www.ashukumar27.io/exponentially-weighted-average/
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-06-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 朴素人工智能 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • EMA的定义
  • 在深度学习的优化中的EMA
  • EMA的偏差修正
  • EMA为什么有效
  • PyTorch实现
  • References
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档