前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-手写数字实战01

PyTorch入门笔记-手写数字实战01

作者头像
触摸壹缕阳光
修改2021-03-28 22:37:50
9520
修改2021-03-28 22:37:50
举报

下面来简单回顾上一小节的嵌套非线性模型:

  • H_1 = relu(XW_1 + b_1)
  • H_2 = relu(H1W_2 + b_2)
  • H_3 = f(H_2W_3 + b_3), 模型最后一层的激活函数不会是 relu 激活函数,需要根据你的具体任务来选择合适的激活函数。比如使用二分类的 Sigmoid 或多分类的 SoftMax(当然多个二分类也可以用于处理多分类)。由于这里只是简单的演示整个训练流程,所以为了简单本小节最后一层不添加任何激活函数。

对 MNIST 手写数字识别进行分类大致分为四个步骤,这四个步骤也是训练大多数深度学习模型的基本步骤:

  • 加载数据集(Load data)
  • 构建模型(Build Model)
  • 训练(Train)
  • 测试(Test)

不过在这之前我们需要构建一个 utils.py 文件,其中包含着三个工具方法:

  • plot_curve(loss_list) 方法绘制损失函数曲线;
  • plot_image(x, label, name)方法显示 6 张手写数字图片以及对应的数字标签;
  • one_hot(label, depth = 10)方法将 0~9 的数字编码标签转换为 one-hot 编码的标签。比如将数字编码 5 转换为 one-hot 编码为 [0,0,0,0,1,0,0,0,0,0](由于此时假设为十个类别,因此 one-hot 编码后的向量维度为 10 维)。
代码语言:txt
复制
import torch
from matplotlib import pyplot as plt

def plot_curve(loss_list):
    """
    根据存放loss值的列表绘制曲线
    """
    plt.plot(range(len(loss_list)), loss_list, color = 'blue')
    # 添加图例并放置在右上角
    plt.legend(['train_loss'], loc = 'upper right')
    plt.xlabel('step') # 设置横坐标轴名称
    plt.ylabel('train_loss') # 设置纵坐标轴名称
    plt.show()

def plot_image(x, label, name):
    """
    显示6张手写数字图片以及对应的数字标签
    """
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(x[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth = 10):
    '''
    将数字编码标签label转换为one-hot编码y
    '''
    y = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    y.scatter_(dim = 1, index = idx, value = 1)
    return y

加载数据集

MNIST 是比较重要和经典的数据集,目前常用的机器学习和深度学习框架都内置了 MNIST 数据集,通过几行代码就可以自动下载、管理以及加载 MNIST 数据集。基于 PyTorch 有很多工具集,比如:处理自然语言的 torchtext,处理音频的 torchaudio 和 处理图像视频的 torchvision,这些工具集可以独立于 PyTorch 的使用。MNIST 数据集属于图像,我们可以在 torchvision.datasets 包中加载 MNIST。「加载的 MNIST 数据集是 ndarray 数组类型,因此我们需要将其转换成 Tensor。实验证明输入数据在 0 附近均匀分布,神经网络模型会有所提升(在本小节的神经网络模型架构下,对数据进行标准化准确率能够提升 10%),因此我们还需要对 MNIST 数据集进行标准化的转换,torchvision.transforms 包提供了这些转换方法。」

代码语言:txt
复制
import torchvision

train_data = torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ]))

print(len(train_data)) # 60000
# 训练集中的第1张手写数字图片以及对应的标签
X_train_0, label_train_0 = train_data[0]
print(X_train_0.shape) # torch.Size([1, 28, 28])
print(label_train_0) # 5

在 torchvision.datasets 中有很多类似 MNIST 的数据集,下面来简单介绍 torchvision.datasets.MNIST 中的一些参数:

  • 'mnist_data':MNIST 数据集所在的文件夹,我直接设置在当前路径。如果你也传入 'mnist_data',你会在当前路径下发现一个 mnist_data 的文件夹;
  • train = True:可选参数。如果设置为 True,则从 ./mnist_data/MNIST/processed/training.pt 中加载训练集(使用 len(train_data) 可以看出共有 60000 张手写数字图片)。如果设置为 False,则从 ./mnist_data/MNIST/processed/test.pt 中加载测试集;
  • download = True:可选参数。如果设置为 True,且路径下没有 MNIST 数据集,则会从网络上下载 MNIST 数据集,如果路径下已经存在 MNIST 数据集,则不会再次下载;
  • transform = torchvision.transforms.Compose:transform 进行数据的预处理操作:
    • ToTensor:将 ndarray 数组转换为 Tensor 数据类型;
    • Normalize:进行数据的标准化,即减去均值除以方差,此时均值 0.1307 和方差 0.3081 是 MNIST 数据集计算好的数据,直接使用即可;

加载完了 MNIST 数据集中的训练集,我们可以设置 train = False 来加载 10000 张测试集。

代码语言:txt
复制
import torchvision

test_data = torchvision.datasets.MNIST('mnist_data', train = False, download = True,
                                       transform=torchvision.transforms.Compose([
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize(
                                               (0.1307,), (0.3081,))
                               ]))

print(len(test_data)) # 10000
# 测试集中的第1张手写数字图片以及对应的标签
X_test_0, label_test_0 = test_data[0]
print(X_test_0.shape) # torch.Size([1, 28, 28])
print(label_test_0) # 7

至此 60000 张训练集以及 10000 张测试集都加载进来了,不过我们通常使用更为方便的数据集加载器 DataLoader,DataLoader 结合了数据集和取样器,提供了多个线程处理数据集,并且里面提供了很多方便处理数据集的功能。DataLoader 在 torch.utils.data 包下。

代码语言:txt
复制
import torch
import utils # 加载我们自己写的工具类

batch_size = 512

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size = batch_size, # batch_size
                                           shuffle = True) # 是否打乱数据集
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size = batch_size,
                                          # 测试集只用于验证模型性能不需要打乱数据集
                                          shuffle = False) 
# 迭代器加载数据集,每次都加载batch_size个
# X: [batch_size, channel, width, hight]
# label: 数字编码
X, label = next(iter(train_loader))
print(X.shape, label.shape, X.min(), label.max())
utils.plot_image(X, label, 'image sample')
代码语言:txt
复制
torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(9)

References: 1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm

原文地址:https://mp.weixin.qq.com/s/JTMcPCUL-F8kd3CRnUvOUg

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-10-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 加载数据集
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档