[Tech Share] Fine-tuning with PyTorch (ResNet18 on CIFAR10)

Original · by ascehuang · last modified 2019-12-01 16:16:39 · Column: Tencent Cloud TI Platform

This article trains a ResNet18 model with PyTorch to classify CIFAR10, then relabels the CIFAR10 data, loads the trained model, and fine-tunes it on the relabeled data. The official PyTorch fine-tuning tutorial covers similar ground and is worth consulting.

The ResNet18 model

The ResNet18 implementation comes from https://github.com/kuangliu/pytorch-cifar

See models/resnet.py in that repository for the model details, which are not repeated here. The README reports an accuracy of 93.02%, but my local run did not reach that figure: after 200 epochs it stopped at 87.40%.

Importing the required packages

import os

import numpy as np
import torch                      # needed below for torch.manual_seed, torch.load, torch.utils.data
import torch.backends.cudnn as cudnn
import torch.nn as nn             # needed below for nn.Linear and nn.CrossEntropyLoss
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from models import *              # ResNet18 from kuangliu/pytorch-cifar
from utils import progress_bar    # progress bar helper from the same repo

Setting the random seed for reproducibility

Getting this right took a while. On CPU, torch.manual_seed(SEED) alone is enough to reproduce results, but on GPU some randomness always remained. With a friend's help and the official documentation, the settings below finally fixed it, thanks. TensorFlow, by contrast, does not seem able to guarantee reproducible results on GPU; if anyone knows how, please share.

SEED = 0
torch.manual_seed(SEED)                     # seed the CPU RNG
torch.cuda.manual_seed(SEED)                # seed the GPU RNG
torch.backends.cudnn.deterministic = True   # force deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False      # disable non-deterministic autotuning
np.random.seed(SEED)                        # seed the NumPy RNG

Choosing between CPU and GPU

The device is selected according to whether a GPU is available. Watch the driver installation and version compatibility; the drivers tormented me for quite a while. Because I run inside Docker, a mismatched driver version kept the GPU from being detected at all.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0       # best test accuracy seen so far
start_epoch = 0    # epoch to start counting from

Loading and preprocessing the data

The data lives in the data folder next to the .py file; with download=True, it is downloaded automatically through PyTorch if it is missing. The training transforms apply random crops and flips to increase data diversity.

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
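One reproducibility caveat worth adding here (my note, not part of the original run): with num_workers=2, each DataLoader worker process keeps its own RNG state, so the random augmentations can still vary between runs. A minimal sketch of pinning the workers to the global SEED via worker_init_fn:

import random

# Hypothetical helper: derive each worker's seed from the global SEED so the
# random augmentations are drawn identically run to run.
def seed_worker(worker_id):
    random.seed(SEED + worker_id)
    np.random.seed(SEED + worker_id)
    torch.manual_seed(SEED + worker_id)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
                                          num_workers=2, worker_init_fn=seed_worker)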

Modifying the dataset

The original CIFAR10 dataset contains 10 classes:

['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

To exercise fine-tuning, the dataset is reshaped from 10 classes into 2: animals and means of transport. (Does a horse count as transport? ^.^)

clz_idx = trainset.class_to_idx                       # original name -> index mapping
clz_to_idx = {'animal': 0, 'transport': 1}
clz = ['animal', 'transport']
animal_name = ["bird", "cat", "deer", "dog", "frog", "horse"]
animal = [clz_idx[x] for x in animal_name]            # indices of the six animal classes

# Relabel every target: animal -> 0, everything else (transport) -> 1
trainset.targets = [0 if x in animal else 1 for x in trainset.targets]
trainset.class_to_idx = clz_to_idx
trainset.classes = clz
testset.targets = [0 if x in animal else 1 for x in testset.targets]
testset.class_to_idx = clz_to_idx
testset.classes = clz
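As a quick sanity check (my addition, not in the original post): CIFAR10 has six animal and four transport classes at 5000 training images each, so the relabeled training set should split 30000 / 20000:

# Count the relabeled targets
from collections import Counter
print(Counter(trainset.targets))   # expected: Counter({0: 30000, 1: 20000})
print(Counter(testset.targets))    # expected: Counter({0: 6000, 1: 4000})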

Loading the pretrained model

The model sits in the checkpoint directory and is the ResNet18 trained above. If it was trained on GPU, pay particular attention to the order of the statements inside the if branch.

  • Wrap net in DataParallel for parallel training. The original ResNet18 was trained on GPU inside DataParallel, so its state dict keys carry a module. prefix and the network must be wrapped the same way before loading.
  • Fine-tuning: replace the final 10-way output layer with a 2-way one. Note the GPU spelling, net.module.linear.
  • net = net.to(device): after modifying the model, push it to the GPU. This step must not come earlier, or the new parameters end up off the GPU and training fails with a "parameters not on GPU" error.
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.pth')

net = ResNet18()
if device == 'cuda':
    net = torch.nn.DataParallel(net)   # the checkpoint was saved from a DataParallel model
    net.load_state_dict(checkpoint['net'])
    net.module.linear = nn.Linear(net.module.linear.in_features, 2)   # 10-way head -> 2-way head
else:
    net.load_state_dict(checkpoint['net'])
    net.linear = nn.Linear(net.linear.in_features, 2)

net = net.to(device)   # move to the device only after the head has been replaced
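A quick way to confirm the new head and the device placement (my addition, a hypothetical snippet): a dummy CIFAR-sized batch should come back as two-way logits on the chosen device.

# Sanity check: a random 32x32 RGB batch should yield logits of shape (4, 2)
dummy = torch.randn(4, 3, 32, 32).to(device)
with torch.no_grad():
    print(net(dummy).shape)   # expected: torch.Size([4, 2])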

Specifying the layers to freeze

Freeze the parameters of the first 40 layers (out of 62 parameter tensors) so they are no longer updated; the later layers, including the new 2-way head, stay trainable.

for idx, (name, param) in enumerate(net.named_parameters()):
    if idx < 40:  # freeze the first 40 of the 62 parameter tensors
        param.requires_grad = False

    if param.requires_grad:
        print("\t", idx, name)   # list the layers that remain trainable

Loss function and optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
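An optional refinement (my addition): PyTorch optimizers skip parameters whose gradients remain None, so passing the frozen parameters is harmless, but handing over only the trainable ones makes the intent explicit:

# Variant: register only the parameters that still require gradients
optimizer = optim.SGD((p for p in net.parameters() if p.requires_grad),
                      lr=0.1, momentum=0.9, weight_decay=5e-4)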

Training and test functions

These follow main.py from the ResNet18 repository. During testing, the model is saved into a separate folder so training can be resumed later, and a checkpoint is written only when the test accuracy improves.

def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        # print('%d/%d, [Loss: %.03f | Acc: %.3f%% (%d/%d)]'
        #       % (batch_idx+1, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

best_acc = 0
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint_ft'):
            os.mkdir('checkpoint_ft')
        torch.save(state, './checkpoint_ft/ckpt.pth')
        best_acc = acc

Starting training

Because training starts from an already-trained model, only a few epochs are needed to reach high accuracy.

for epoch in range(start_epoch, start_epoch + 20):
    train(epoch)
    test(epoch)
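The checkpoints written by test() also make it easy to pick training back up later; a minimal resume sketch (my addition, assuming the checkpoint layout saved above):

# Hypothetical resume: restore the weights, best accuracy, and epoch counter
if os.path.isfile('./checkpoint_ft/ckpt.pth'):
    ckpt = torch.load('./checkpoint_ft/ckpt.pth')
    net.load_state_dict(ckpt['net'])
    best_acc = ckpt['acc']
    start_epoch = ckpt['epoch'] + 1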

Results

Per-epoch output, condensed from the progress bar (391 training batches of 128 images and 100 test batches of 100 images per epoch); "saved" marks epochs where the test accuracy set a new best and a checkpoint was written.

Epoch | Train loss | Train acc | Test loss | Test acc | Checkpoint
    0 |      0.520 |   88.662% |     0.449 |  95.090% | saved
    1 |      0.430 |   95.342% |     0.411 |  95.590% | saved
    2 |      0.394 |   95.816% |     0.373 |  96.110% | saved
    3 |      0.376 |   96.002% |     0.386 |  94.560% |
    4 |      0.368 |   96.160% |     0.365 |  96.350% | saved
    5 |      0.362 |   96.214% |     0.381 |  93.430% |
    6 |      0.360 |   96.070% |     0.362 |  95.400% |
    7 |      0.358 |   96.062% |     0.400 |  90.730% |
    8 |      0.356 |   96.214% |     0.362 |  96.280% |
    9 |      0.353 |   96.242% |     0.376 |  94.590% |
   10 |      0.352 |   96.348% |     0.384 |  93.080% |
   11 |      0.351 |   96.236% |     0.356 |  95.480% |
   12 |      0.350 |   96.348% |     0.383 |  93.170% |
   13 |      0.348 |   96.358% |     0.373 |  93.330% |
   14 |      0.347 |   96.446% |     0.391 |  91.670% |
   15 |      0.346 |   96.324% |     0.347 |  95.880% |
   16 |      0.344 |   96.488% |     0.343 |  95.980% |
   17 |      0.344 |   96.416% |     0.344 |  95.890% |
   18 |      0.344 |   96.370% |     0.354 |  95.060% |
   19 |      0.344 |   96.338% |     0.399 |  89.760% |

Fine-tuning the 87.4%-accurate ResNet18 into a binary classifier reaches 95.09% test accuracy after the very first epoch, so convergence is fast and the classification quality is good.

Over the 20 epochs, the best test accuracy was 96.35%, reached at epoch 4.
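To close the loop, a minimal inference sketch (my addition, assuming the fine-tuned checkpoint saved above) that maps the logits back to the two class names:

# Hypothetical inference: classify one test image with the fine-tuned model
ckpt = torch.load('./checkpoint_ft/ckpt.pth')
net.load_state_dict(ckpt['net'])
net.eval()

img, label = testset[0]                        # already normalized by transform_test
with torch.no_grad():
    logits = net(img.unsqueeze(0).to(device))
pred = logits.argmax(1).item()
print('predicted:', clz[pred], '| actual:', clz[label])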

Closing remarks

Building models in PyTorch is simple, the code stays readable, and the documentation is quite comprehensive.

Original statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.

For infringement concerns, please contact cloudcommunity@tencent.com for removal.

