首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【小白学习PyTorch教程】十二、迁移学习:微调VGG19实现图像分类

【小白学习PyTorch教程】十二、迁移学习:微调VGG19实现图像分类

作者头像
润森
发布2022-08-18 09:29:50
8850
发布2022-08-18 09:29:50
举报

「@Author:Runsen」

前言:迁移学习就是利用数据、任务或模型之间的相似性,将在旧的领域学习过或训练好的模型,应用于新的领域这样的一个过程。从这段定义里面,我们可以窥见迁移学习的关键点所在,即新的任务与旧的任务在数据、任务和模型之间的相似性。

假设有两个任务系统A和B,任务A拥有海量的数据资源且已训练好,但并不是我们的目标任务,任务B是我们的目标任务,但数据量少且极为珍贵,这种场景便是典型的迁移学习的应用场景

接下来在博客中,我们将学习如何将迁移学习与 PyTorch 结合使用。

在这个迁移学习 PyTorch 图像二分类Vgg19 示例中,数据来源:https://www.kaggle.com/pmigdal/alien-vs-predator-images/home

这是我在kaggle找到的关于迁移学习的入门案例

1) 加载数据

第一步是加载数据并对图像进行一些转换,使其符合网络要求。

使用 torchvision.dataset ,在文件夹中加载数据。该模块将在文件夹中迭代以拆分数据以进行训练和验证。

转换过程进行基本的图片处理操作。

将从中心裁剪图像,执行水平翻转,归一化,最后使用将其转换为张量。

import os
import time
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

data_dir = "alien_pred"
input_shape = 224
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

#data transformation
data_transforms = {
   'train': transforms.Compose([
       transforms.CenterCrop(input_shape),
       transforms.ToTensor(),
       transforms.Normalize(mean, std)
   ]),
   'validation': transforms.Compose([
       transforms.CenterCrop(input_shape),
       transforms.ToTensor(),
       transforms.Normalize(mean, std)
   ]),
}

image_datasets = {
   x: datasets.ImageFolder(
       os.path.join(data_dir, x),
       transform=data_transforms[x]
   )
   for x in ['train', 'validation']
}

dataloaders = {
   x: torch.utils.data.DataLoader(
       image_datasets[x], batch_size=32,
       shuffle=True, num_workers=4
   )
   for x in ['train', 'validation']
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'validation']}

print(dataset_sizes)
# {'train': 694, 'validation': 200}

class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

可视化 PyTorch 迁移学习数据集。可视化过程将从训练数据加载器和标签中获取下一批图像,并用 matplot 显示它。

images, labels = next(iter(dataloaders['train']))



rows = 4
columns = 4
fig=plt.figure(figsize=(15,15))


for i in range(16):
   fig.add_subplot(rows, columns, i+1)
   plt.title(class_names[labels[i]])
   img = images[i].numpy().transpose((1, 2, 0))
   img = std * img + mean
   plt.imshow(img)
plt.show()

2) 定义模型

VGG19有两个部分,分别是VGG19.features和VGG19.classifier。

  • vgg19.features有卷积层和池化层
  • vgg19.features有三个线性层,最后是softmax分类器

下面将使用 torchvision.models 加载 VGG19,并将预训练权重设置为 True之后,将冻结层,使这些层不可训练。

对 Linear 层修改最后一层,以满足我们 2 个类的需求。

也可以将 CrossEntropyLoss 用于多类损失函数,对于优化器,使用学习率为 0.0001 和动量为 0.9 的 SGD,如下面的 PyTorch 迁移学习示例所示。

##加载基于VGG19的模型
vgg_based = torchvision.models.vgg19(pretrained=True) 

for param in vgg_based.parameters():
   param.requires_grad = False

#修改最后一层
number_features = vgg_based.classifier[6].in_features 
features = list(vgg_based.classifier.children())[:-1] # 移除最后一层
features.extend([torch.nn.Linear(number_features, len(class_names))]) 
vgg_based.classifier = torch.nn.Sequential(*features) 

vgg_based = vgg_based.to(device) 

print(vgg_based)

criterion = torch.nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(vgg_based.parameters(), lr=0.001, momentum=0.9)

vgg_based输出如下

VGG(
  (features): Sequential(
 (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (17): ReLU(inplace)
 (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (24): ReLU(inplace)
 (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (26): ReLU(inplace)
 (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (31): ReLU(inplace)
 (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (33): ReLU(inplace)
 (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (35): ReLU(inplace)
 (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096, bias=True)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096, bias=True)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=2, bias=True)
  )
)

3) 训练模型

下面使用 PyTorch 中的一些功能来帮助我们训练和评估我们的模型。

def train_model(model, criterion, optimizer, num_epochs=25):
   since = time.time()

   for epoch in range(num_epochs):
       print('Epoch {}/{}'.format(epoch, num_epochs - 1))
       print('-' * 10)

        # 迭代数据
       train_loss = 0

       # Iterate over data.
       for i, data in enumerate(dataloaders['train']):
           inputs , labels = data
           inputs = inputs.to(device)
           labels = labels.to(device)

           optimizer.zero_grad()
          
           with torch.set_grad_enabled(True):
               outputs  = model(inputs)
               loss = criterion(outputs, labels)

           loss.backward()
           optimizer.step()

           train_loss += loss.item() * inputs.size(0)

           print('{} Loss: {:.4f}'.format(
               'train', train_loss / dataset_sizes['train']))
          
   time_elapsed = time.time() - since
   print('Training complete in {:.0f}m {:.0f}s'.format(
       time_elapsed // 60, time_elapsed % 60))

   return model


最后, epoch 数设置为 25 开始我们的模型训练过程,并在训练过程结束后进行评估。

在每个训练步骤中,模型接受输入并预测输出。之后预测输出将传递给计算损失。然后损失将执行反向传播来计算得到梯度,最后计算权重并使用 autograd 不断的优化参数。

vgg_based = train_model(vgg_based, criterion, optimizer_ft, num_epochs=25)

4) 测试模型

在可视化模型中,将训练好的模型,使用一批图像进行测试和预测标签

def visualize_model(model, num_images=6):
   was_training = model.training
   model.eval()
   images_so_far = 0
   fig=plt.figure(figsize=(15,15))


   with torch.no_grad():
       for i, (inputs, labels) in enumerate(dataloaders['validation']):
           inputs = inputs.to(device)
           labels = labels.to(device)

           outputs = model(inputs)
           _, preds = torch.max(outputs, 1)

           for j in range(inputs.size()[0]):
               images_so_far += 1
               ax = plt.subplot(num_images//2, 2, images_so_far)
               ax.axis('off')
               ax.set_title('predicted: {} truth: {}'.format(class_names[preds[j]], class_names[labels[j]]))
               img = inputs.cpu().data[j].numpy().transpose((1, 2, 0))
               img = std * img + mean
               ax.imshow(img)

               if images_so_far == num_images:
                   model.train(mode=was_training)
                   return
       model.train(mode=was_training)
visualize_model(vgg_based)
plt.show()
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-07-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小刘IT教程 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1) 加载数据
  • 2) 定义模型
  • 3) 训练模型
  • 4) 测试模型
相关产品与服务
图片处理
图片处理(Image Processing,IP)是由腾讯云数据万象提供的丰富的图片处理服务,广泛应用于腾讯内部各产品。支持对腾讯云对象存储 COS 或第三方源的图片进行处理,提供基础处理能力(图片裁剪、转格式、缩放、打水印等)、图片瘦身能力(Guetzli 压缩、AVIF 转码压缩)、盲水印版权保护能力,同时支持先进的图像 AI 功能(图像增强、图像标签、图像评分、图像修复、商品抠图等),满足多种业务场景下的图片处理需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档