pytorch学习笔记(十一):fine-tune 预训练的模型

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim

resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。

# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)

# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。

# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了

# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 

optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)

...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

  • 预训练的模型中 有个 名字叫fc 的 Module。
  • 在类定义外,我们 将另一个 Module 重新 赋值给了 fc
  • 类定义内的 fc 对应的 Module 就会从 模型中 删除。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据结构与算法

一种递推组合数前缀和的Trick

\(S(n, m)\)可以看做是杨辉三角上的一行,而\(S(n+1, m)\)是他的下一行

613
来自专栏用户2442861的专栏

用python简单处理图片(4):图像中的像素访问

前面的一些例子中,我们都是利用Image.open()来打开一幅图像,然后直接对这个PIL对象进行操作。如果只是简单的操作还可以,但是如果操作稍微复杂一些,就...

492
来自专栏区块链

15分钟破解网站验证码

概述   很多开发者都讨厌网站的验证码,特别是写网络爬虫的程序员,而网站之所以设置验证码,是为了防止机器人访问网站,造成不必要的损失。现在好了,随着机器学习技术...

1667
来自专栏Python小屋

三种Fibonacci数列第n项计算方法及其优劣分析

感谢国防科技大学刘万伟老师和中国传媒大学胡凤国两位老师提供的思路,文章作者不能超过8个字符,我的名字就写个姓吧,名字不写了^_^。另外,除了本文讨论的三种方法,...

3207
来自专栏ACM算法日常

流问题Flow Problem(网络最大流)- HDU 3549

网络最大流问题属于算法 里面较难的问题,因为牵涉的概念比较多,这一篇可能需要你花比较多的时间去理解,除了看这个,最好能多参考别的书籍或者文章进行...

1001
来自专栏FreeBuf

中文点选验证码之自动识别

某次测试中遇到了汉字点选的验证码,看着很简单,尝试了一下发现有两种简单的识别方法,终于有空给重新整理一下,分享出来。

1164
来自专栏人人都是极客

TensorFlow极简入门教程

随着 TensorFlow 在研究及产品中的应用日益广泛,很多开发者及研究者都希望能深入学习这一深度学习框架。本文介绍了TensorFlow 基础,包括静态计算...

1104
来自专栏海说

12、借助Jacob实现Java打印报表(Excel、Word)

12、使用Jacob来处理文档   Word或Excel程序是以一种COM组件形式存在的。如果能够在Java中调用相应组件,便能使用它的方法来获取文档中的文本信...

2080
来自专栏简书专栏

基于逻辑回归的鸢尾花分类

Iris(鸢尾花)数据集是多重变量分析的数据集。 数据集包含150行数据,分为3类,每类50行数据。 每行数据包括4个属性:Sepal Length(花萼长...

961
来自专栏琦小虾的Binary

OpenCV像素点邻域遍历效率比较,以及访问像素点的几种方法

OpenCV像素点邻域遍历效率比较,以及访问像素点的几种方法 前言: 以前笔者在项目中经常使用到OpenCV的算法,而大部分OpenCV的算法都需要进行遍历操作...

33210

扫码关注云+社区