前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch的余弦退火学习率

pytorch的余弦退火学习率

作者头像
zenRRan
发布2020-11-19 14:19:33
3.4K0
发布2020-11-19 14:19:33
举报

作者:limzero

地址:https://www.zhihu.com/people/lim0-34

编辑:人工智能前沿讲习

最近深入了解了下pytorch下面余弦退火学习率的使用.网络上大部分教程都是翻译的pytorch官方文档,并未给出一个很详细的介绍,由于官方文档也只是给了一个数学公式,对参数虽然有解释,但是解释得不够明了,这样一来导致我们在调参过程中不能合理的根据自己的数据设置合适的参数.这里作一个笔记,并且给出一些定性和定量的解释和结论.说到pytorch自带的余弦学习率调整方法,通常指下面这两个

CosineAnnealingLR

CosineAnnealingWarmRestarts

CosineAnnealingLR

这个比较简单,只对其中的最关键的Tmax参数作一个说明,这个可以理解为余弦函数的半周期.如果max_epoch=50次,那么设置T_max=5则会让学习率余弦周期性变化5次.

max_opoch=50, T_max=5

CosineAnnealingWarmRestarts

这个最主要的参数有两个:

  • T_0:学习率第一次回到初始值的epoch位置
  • T_mult:这个控制了学习率变化的速度
    • 如果T_mult=1,则学习率在T_0,2T_0,3T_0,....,i*T_0,....处回到最大值(初始学习率)
      • 5,10,15,20,25,.......处回到最大值
    • 如果T_mult>1,则学习率在T_0,(1+T_mult)T_0,(1+T_mult+T_mult**2)T_0,.....,(1+T_mult+T_mult2+...+T_0i)*T0,处回到最大值
      • 5,15,35,75,155,.......处回到最大值

T_0=5, T_mult=1

T_0=5, T_mult=2

所以可以看到,在调节参数的时候,一定要根据自己总的epoch合理的设置参数,不然很可能达不到预期的效果,经过我自己的试验发现,如果是用那种等间隔的退火策略(CosineAnnealingLR和Tmult=1的CosineAnnealingWarmRestarts),验证准确率总是会在学习率的最低点达到一个很好的效果,而随着学习率回升,验证精度会有所下降.所以为了能最终得到一个更好的收敛点,设置T_mult>1是很有必要的,这样到了训练后期,学习率不会再有一个回升的过程,而且一直下降直到训练结束。

下面是使用示例和画图的代码:

代码语言:javascript
复制
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR,CosineAnnealingWarmRestarts,StepLR
import torch.nn as nn
from torchvision.models import resnet18
import matplotlib.pyplot as plt
#
model=resnet18(pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mode='cosineAnnWarm'
if mode=='cosineAnn':
    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0)
elif mode=='cosineAnnWarm':
    scheduler = CosineAnnealingWarmRestarts(optimizer,T_0=5,T_mult=1)
    '''
    以T_0=5, T_mult=1为例:
    T_0:学习率第一次回到初始值的epoch位置.
    T_mult:这个控制了学习率回升的速度
        - 如果T_mult=1,则学习率在T_0,2*T_0,3*T_0,....,i*T_0,....处回到最大值(初始学习率)
            - 5,10,15,20,25,.......处回到最大值
        - 如果T_mult>1,则学习率在T_0,(1+T_mult)*T_0,(1+T_mult+T_mult**2)*T_0,.....,(1+T_mult+T_mult**2+...+T_0**i)*T0,处回到最大值
            - 5,15,35,75,155,.......处回到最大值
    example:
        T_0=5, T_mult=1
    '''
plt.figure()
max_epoch=50
iters=200
cur_lr_list = []
for epoch in range(max_epoch):
    for batch in range(iters):
        '''
        这里scheduler.step(epoch + batch / iters)的理解如下,如果是一个epoch结束后再.step
        那么一个epoch内所有batch使用的都是同一个学习率,为了使得不同batch也使用不同的学习率
        则可以在这里进行.step
        '''
        #scheduler.step(epoch + batch / iters)
        optimizer.step()
    scheduler.step()
    cur_lr=optimizer.param_groups[-1]['lr']
    cur_lr_list.append(cur_lr)
    print('cur_lr:',cur_lr)
x_list = list(range(len(cur_lr_list)))
plt.plot(x_list, cur_lr_list)
plt.show()

最后,对 scheduler.step(epoch + batch / iters)的一个说明,这里的个人理解:一个epoch结束后再.step, 那么一个epoch内所有batch使用的都是同一个学习率,为了使得不同batch也使用不同的学习率 ,则可以在这里进行.step(将离散连续化,或者说使得采样得更加的密集),下图是以20个epoch,每个epoch5个batch,T0=2,Tmul=2画的学习率变化图

代码:

代码语言:javascript
复制
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR,CosineAnnealingWarmRestarts,StepLR
import torch.nn as nn
from torchvision.models import resnet18
import matplotlib.pyplot as plt
#
model=resnet18(pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mode='cosineAnnWarm'
if mode=='cosineAnn':
    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0)
elif mode=='cosineAnnWarm':
    scheduler = CosineAnnealingWarmRestarts(optimizer,T_0=2,T_mult=2)
    '''
    以T_0=5, T_mult=1为例:
    T_0:学习率第一次回到初始值的epoch位置.
    T_mult:这个控制了学习率回升的速度
        - 如果T_mult=1,则学习率在T_0,2*T_0,3*T_0,....,i*T_0,....处回到最大值(初始学习率)
            - 5,10,15,20,25,.......处回到最大值
        - 如果T_mult>1,则学习率在T_0,(1+T_mult)*T_0,(1+T_mult+T_mult**2)*T_0,.....,(1+T_mult+T_mult**2+...+T_0**i)*T0,处回到最大值
            - 5,15,35,75,155,.......处回到最大值
    example:
        T_0=5, T_mult=1
    '''
plt.figure()
max_epoch=20
iters=5
cur_lr_list = []
for epoch in range(max_epoch):
    print('epoch_{}'.format(epoch))
    for batch in range(iters):
        scheduler.step(epoch + batch / iters)
        optimizer.step()
        #scheduler.step()
        cur_lr=optimizer.param_groups[-1]['lr']
        cur_lr_list.append(cur_lr)
        print('cur_lr:',cur_lr)
    print('epoch_{}_end'.format(epoch))
x_list = list(range(len(cur_lr_list)))
plt.plot(x_list, cur_lr_list)
plt.show()

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-11-10,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 深度学习自然语言处理 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档