optimizer.zero_grad()

传统的训练函数,一个batch是这么训练的:

for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)

    # 2. backward
    optimizer.zero_grad()   # reset gradient
    loss.backward()
    optimizer.step()            
  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. optimizer.zero_grad() 清空过往梯度;
  3. loss.backward() 反向传播,计算当前梯度;
  4. optimizer.step() 根据梯度更新网络参数

简单的说就是进来一个batch的数据,计算一次梯度,更新一次网络,使用梯度累加是这么写的:

for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)

    # 2.1 loss regularization
    loss = loss/accumulation_steps   
    # 2.2 back propagation
    loss.backward()
    # 3. update parameters of net
    if((i+1)%accumulation_steps)==0:
        # optimizer the net
        optimizer.step()        # update parameters of net
        optimizer.zero_grad()   # reset gradient
  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. loss.backward() 反向传播,计算当前梯度;
  3. 多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
  4. 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;

总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。

一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize '变相' 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。

更新1:关于BN是否有影响,之前有人是这么说的:

As far as I know, batch norm statistics get updated on each forward pass, so no problem if you don't do .backward() every time.

BN的估算是在forward阶段就已经完成的,并不冲突,只是accumulation_steps=8和真实的batchsize放大八倍相比,效果自然是差一些,毕竟八倍Batchsize的BN估算出来的均值和方差肯定更精准一些。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • YOLO v2

    相对于YOLOv1,改进后的v2版使用一种新的、多尺度的训练方法,相同的YOLOv2模型可以在不同的尺寸运行,在速度和准确性之间达到简单的折中。这个模型由于可以...

    于小勇
  • 反向传播和其他微分算法

    时,信息通过网络前向流动。输入x并提供初始信息,然后传播到每一层的隐藏单元,最终产生输出

    于小勇
  • tf.stop_gradient

    停止梯度计算。当在一个图中执行时,这个op按原样输出它的输入张量。当构建ops来计算梯度时,该op会阻止将其输入的贡献考虑在内。通常情况下,梯度发生器通过递归找...

    于小勇
  • 大数据为飞机制造业带来新的机遇与挑战

    导读:民机主制造商是天然的数据生产者和集成者,相关的数据类型和领域众多。在新一轮技术大变革背景下,这些数据无疑也为民机制造业带来了新的机遇与挑战。理解、认识、用...

    钱塘数据
  • hashMap

    https://www.cnblogs.com/skywang12345/category/455711.html

    大学里的混子
  • JavaScript的工作原理:解析、抽象语法树(AST)+ 提升编译速度5个技巧

    我们都知道运行一大段 JavaScript 代码性能会变得很糟糕。这段代码不仅需要通过网络传输,而且还需要解析、编译成字节码,最后执行。在之前的文章中,我们讨论...

    Fundebug
  • Win7下修改Hosts文件

    WIN7或者VISTAWIN7或者VISTA系统的需要提升用户对Hosts文件的操作权限,否则无效。 具体方法如下: 方法一:按着Shift键,然后Hosts文...

    跟着阿笨一起玩NET
  • log4net.SignalR - 日志即时发送客户端页面

    在log4net的配置中,appender是最重要的部分,一般来说,每一种appender都表示一种日志的输出介质,如日志文件、EvengLog、数据库、控制台...

    张善友
  • 大数据24小时 | 星环科技发布大数据一体机产品TxData,猎豹移动欲投5000万美元建机器人公司

    “南方大数据创新联盟”在粤宣布正式成立 ? 在“粤治—治理现代化”的经验交流会上,“南方大数据创新联盟”宣布正式成立,发起方为南方报业舆情数据研究院。据悉,该联...

    数据猿
  • JavaScript “袁华”飞雪特效

    马上就要到了传统节日“春节”?,网站添加了飞雪❄特效,从网上找了源代码,先要感谢张戈博客分享?,现计划将网站在今天上线至过年回来下线,看看可以么,你可以复制到自...

    Debug客栈

扫码关注云+社区

领取腾讯云代金券