
[PyTorch Tutorial for Beginners] Part 8: Improving Accuracy on the CIFAR-10 Dataset with Image Data Augmentation

Author: 润森
Published: 2022-08-18 09:26:35
Included in the column: 毛利学Python

「@Author:Runsen」

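In the previous post, the image classification model built with PyTorch on the CIFAR-10 dataset reached an accuracy of 60%. To improve on that, this post applies common image data augmentation techniques from torchvision's transforms module.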

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

import numpy as np
import os
import warnings
from matplotlib import pyplot as plt
warnings.filterwarnings('ignore')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Loading the dataset

# number of images in one forward and backward pass
batch_size = 128

# number of subprocesses used for data loading
# (set to 0 if multi-process loading causes problems, e.g. on Windows)
num_workers = 2

train_dataset = datasets.CIFAR10('./data/CIFAR10/', 
                                 train = True, 
                                 download = True, 
                                 transform = transform_train)

train_loader = DataLoader(train_dataset, 
                          batch_size = batch_size, 
                          shuffle = True, 
                          num_workers = num_workers)

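# NOTE: the validation set reuses the training split (train = True), only without
# augmentation, so the validation loss is not measured on held-out images
val_dataset = datasets.CIFAR10('./data/CIFAR10', 
                                train = True, 
                                transform = transform_test)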

val_loader = DataLoader(val_dataset, 
                        batch_size = batch_size, 
                        shuffle = False, 
                        num_workers = num_workers)

test_dataset = datasets.CIFAR10('./data/CIFAR10', 
                                train = False, 
                                transform = transform_test)

test_loader = DataLoader(test_dataset, 
                         batch_size = batch_size, 
                         shuffle = False, 
                         num_workers = num_workers)

# declare classes in CIFAR10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
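
The validation loader above reuses the full training split. As a minimal sketch of one alternative, assuming a 90/10 split is acceptable, torch.utils.data.random_split can hold out part of the training data. The TensorDataset below is only a small stand-in for CIFAR-10, and the names full_train, train_part and val_part are hypothetical.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the CIFAR-10 training set: 1000 fake 3x32x32 images with labels 0..9
full_train = TensorDataset(torch.randn(1000, 3, 32, 32),
                           torch.randint(0, 10, (1000,)))

# Hold out 10% of the samples for validation; the seeded generator makes the split reproducible
n_val = len(full_train) // 10
n_train = len(full_train) - n_val
generator = torch.Generator().manual_seed(0)
train_part, val_part = random_split(full_train, [n_train, n_val], generator=generator)

print(len(train_part), len(val_part))
```

Both subsets share the transform of the underlying dataset, so with CIFAR-10 the held-out part would also be augmented unless two dataset objects (one per transform) are indexed with the same split.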

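The previous transform only converted the images to tensors and normalized them. Here we add RandomCrop and RandomHorizontalFlip to the training transform. Because the dataset-loading code above references transform_train and transform_test, these definitions need to run first.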

# define a transform to normalize the data

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # converting images to tensor
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)) 
    # if the image dataset is black and white image, there can be just one number. 
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])

Visualizing sample images

# function that will be used for visualizing the data

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(np.transpose(img, (1, 2, 0)))  # convert from Tensor image

# obtain one batch of images from the train dataset
dataiter = iter(train_loader)
images, labels = next(dataiter)  # the built-in next() works across PyTorch versions
images = images.numpy() # convert images to numpy for display

# plot the images in one batch with the corresponding labels
fig = plt.figure(figsize = (25, 4))

# display images
for idx in np.arange(10):
    ax = fig.add_subplot(1, 10, idx+1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx]])
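
The `img / 2 + 0.5` step in imshow undoes the normalization with mean 0.5 and standard deviation 0.5. A quick check of that round trip on a random tensor (a stand-in for a real image):

```python
import torch

torch.manual_seed(0)
x = torch.rand(3, 32, 32)          # pixel values in [0, 1), as produced by ToTensor
normalized = (x - 0.5) / 0.5       # what Normalize(mean=0.5, std=0.5) computes: range [-1, 1)
restored = normalized / 2 + 0.5    # the unnormalize step used in imshow

print(torch.allclose(restored, x, atol=1e-6))
```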

Building a typical CNN model

# define the CNN architecture

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.main = nn.Sequential(
            # 3x32x32
            nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = 3, padding = 1), # 32x32x32, O = (N + 2P - F) / S + 1
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size = 2, stride = 2), # 32x16x16
            nn.BatchNorm2d(32),
            
            nn.Conv2d(32, 64, kernel_size = 3, padding = 1), # 64x16x16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2), # 64x8x8
            nn.BatchNorm2d(64),
            
            nn.Conv2d(64, 128, 3, padding = 1), # 128x8x8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2), # 128x4x4
            nn.BatchNorm2d(128),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128*4*4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            
            nn.Linear(256, 10)
        )
        
    def forward(self, x):
        # Conv and Pooling layers
        x = self.main(x)
        
        # Flatten before Fully Connected layers
        x = x.view(-1, 128*4*4) 
        
        # Fully Connected Layer
        x = self.fc(x)
        return x

cnn = CNN().to(device)
cnn
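
The shape comments in the model follow the output-size formula O = (N + 2P - F) / S + 1. The short sketch below traces the spatial size through the three Conv2d + MaxPool2d stages; the helper name conv_out is hypothetical.

```python
def conv_out(n, kernel, padding=0, stride=1):
    """Spatial output size of a convolution or pooling layer: (N + 2P - F) // S + 1."""
    return (n + 2 * padding - kernel) // stride + 1

size = 32
sizes = [size]
for _ in range(3):
    size = conv_out(size, kernel=3, padding=1)   # Conv2d(kernel 3, padding 1) keeps the size
    size = conv_out(size, kernel=2, stride=2)    # MaxPool2d(2, 2) halves it
    sizes.append(size)

flat_features = 128 * size * size
print(sizes)           # [32, 16, 8, 4]
print(flat_features)   # 2048
```

This is where the `128*4*4` input size of the first Linear layer comes from.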

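torch.nn.CrossEntropyLoss is the loss function used for this classification model. It takes the raw, unnormalized scores (logits) of the network, applies log-softmax internally to turn them into class probabilities between 0 and 1, and compares them with the target labels. This is why the model ends with a plain Linear layer and no softmax.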
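
A small check of how nn.CrossEntropyLoss relates to probabilities, using random logits as a stand-in for model outputs: the loss equals the negative log-likelihood of the log-softmax of the logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)             # raw scores for 4 samples and 10 classes
targets = torch.tensor([3, 0, 9, 1])    # ground-truth class indices

# CrossEntropyLoss works directly on logits
loss_ce = nn.CrossEntropyLoss()(logits, targets)

# Equivalent two-step computation: log-softmax followed by negative log-likelihood
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Softmax turns each row of logits into probabilities that sum to 1
probs = F.softmax(logits, dim=1)

print(torch.allclose(loss_ce, loss_manual))
print(probs.sum(dim=1))
```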

Training the model

# Hyperparameters
learning_rate = 0.001
train_losses = []
val_losses = []

# Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr = learning_rate)

# define train function that trains the model using a CIFAR10 dataset

def train(model, epoch, num_epochs):
    model.train()
    
    total_batch = len(train_dataset) // batch_size
    
    for i, (images, labels) in enumerate(train_loader):
            
        X = images.to(device)
        Y = labels.to(device)

        ### forward pass and loss calculation
        # forward pass
        pred = model(X)
        # calculation of loss value
        cost = criterion(pred, Y)

        ### backward pass and optimization
        # gradient initialization
        optimizer.zero_grad()
        # backward pass
        cost.backward()
        # parameter update
        optimizer.step()
            
        # training stats (logged every 100 iterations)
        if (i+1) % 100 == 0:
            # record the loss before averaging so train_losses is never empty
            train_losses.append(cost.item())
            print('Train, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f' 
                  % (epoch+1, num_epochs, i+1, total_batch, np.average(train_losses)))

# define the validation function that evaluates the model on the CIFAR10 dataset

def validation(model, epoch, num_epochs):
    model.eval()
    
    total_batch = len(val_dataset) // batch_size
    
    for i, (images, labels) in enumerate(val_loader):
        
        X = images.to(device)
        Y = labels.to(device)
        
        with torch.no_grad():
            pred = model(X)
            cost = criterion(pred, Y)
        
        if (i+1) % 100 == 0:
            # record the loss before averaging so val_losses is never empty
            val_losses.append(cost.item())
            print("Validation, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f"
                  % (epoch+1, num_epochs, i+1, total_batch, np.average(val_losses)))
            
def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(5, 5))
    plt.plot(train_losses, label='Train', alpha=0.5)
    plt.plot(val_losses, label='Validation', alpha=0.5)
    plt.xlabel('Logged steps (one per 100 iterations)')
    plt.ylabel('Losses')
    plt.legend()
    plt.grid(True)  # the `b` keyword is not accepted by newer matplotlib versions
    plt.title('CIFAR 10 Train/Val Losses Over Epoch')
    plt.show()

num_epochs = 20
for epoch in range(num_epochs):
    train(cnn, epoch, num_epochs)
    validation(cnn, epoch, num_epochs)
    torch.save(cnn.state_dict(), './data/Tutorial_3_CNN_Epoch_{}.pkl'.format(epoch+1))

plot_losses(train_losses, val_losses)

Testing the model

def test(model):
    
    # declare that the model is about to evaluate
    model.eval()

    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_dataset:
            images = images.unsqueeze(0).to(device)
            
            # forward pass
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += 1
            correct += (predicted == labels).sum().item()

    print("Accuracy of Test Images: %f %%" % (100 * float(correct) / total))

# run the evaluation on the trained model
test(cnn)
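
The test function above feeds the images to the model one at a time. As a sketch of a batched variant, the hypothetical helper evaluate_accuracy below iterates over a DataLoader instead; the tiny linear model and random data are stand-ins so that it runs without CIFAR-10.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, loader, device):
    """Return the accuracy in percent over all batches of `loader`."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)   # index of the highest score per sample
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total

# Tiny stand-in model and data (not the CNN or the dataset from this post)
torch.manual_seed(0)
device = torch.device('cpu')
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
toy_loader = DataLoader(TensorDataset(torch.randn(64, 3, 32, 32),
                                      torch.randint(0, 10, (64,))),
                        batch_size=16)

acc = evaluate_accuracy(toy_model, toy_loader, device)
print('Accuracy: %.2f %%' % acc)
```

With the objects from this post, the call would presumably be `evaluate_accuracy(cnn, test_loader, device)`.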

With image data augmentation, the model's accuracy improved from 60% to 84%.

Next, check which classes the model performs well on:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

cnn.eval()  # disable dropout and use the running BatchNorm statistics

with torch.no_grad():
    
    for data in test_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = cnn(images)
        
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        
        # count every sample in the batch (the last batch may be smaller than batch_size)
        for i in range(labels.size(0)):
            label = labels[i].item()
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
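
The per-sample loop above can also be written without a Python loop. The following is a minimal sketch using torch.bincount on a handful of made-up predictions; the helper name per_class_accuracy is hypothetical.

```python
import torch

def per_class_accuracy(predicted, labels, num_classes):
    """Return per-class accuracies in percent (NaN for classes that never occur)."""
    # number of samples per class
    total = torch.bincount(labels, minlength=num_classes).float()
    # number of correctly classified samples per class
    correct = torch.bincount(labels[predicted == labels], minlength=num_classes).float()
    return 100.0 * correct / total

labels = torch.tensor([0, 0, 1, 1, 1, 2])
predicted = torch.tensor([0, 1, 1, 1, 0, 2])
acc = per_class_accuracy(predicted, labels, num_classes=3)
print(acc)
```

In the evaluation loop, the predictions and labels of all batches would first need to be collected (for example with torch.cat) before calling the helper.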
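
Originally published: 2021-07-15. This article was shared from the WeChat official account 小刘IT教程 as part of the Tencent Cloud self-media sharing program. For copyright concerns, contact cloudcommunity@tencent.com.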
