下面来简单回顾上一小节的嵌套非线性模型:
对 MNIST 手写数字识别进行分类大致分为四个步骤,这四个步骤也是训练大多数深度学习模型的基本步骤:
不过在这之前我们需要构建一个 utils.py
文件,其中包含着三个工具方法:
plot_curve(loss_list)
方法绘制损失函数曲线;plot_image(x, label, name)
方法显示 6 张手写数字图片以及对应的数字标签;one_hot(label, depth = 10)
方法将 0~9 的数字编码标签转换为 one-hot 编码的标签。比如将数字编码 5 转换为 one-hot 编码为 [0,0,0,0,1,0,0,0,0,0](由于此时假设为十个类别,因此 one-hot 编码后的向量维度为 10 维)。import torch
from matplotlib import pyplot as plt
def plot_curve(loss_list):
"""
根据存放loss值的列表绘制曲线
"""
plt.plot(range(len(loss_list)), loss_list, color = 'blue')
# 添加图例并放置在右上角
plt.legend(['train_loss'], loc = 'upper right')
plt.xlabel('step') # 设置横坐标轴名称
plt.ylabel('train_loss') # 设置纵坐标轴名称
plt.show()
def plot_image(x, label, name):
"""
显示6张手写数字图片以及对应的数字标签
"""
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(x[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
plt.title("{}: {}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def one_hot(label, depth = 10):
'''
将数字编码标签label转换为one-hot编码y
'''
y = torch.zeros(label.size(0), depth)
idx = torch.LongTensor(label).view(-1, 1)
y.scatter_(dim = 1, index = idx, value = 1)
return y
MNIST 是比较重要和经典的数据集,目前常用的机器学习和深度学习框架都内置了 MNIST 数据集,通过几行代码就可以自动下载、管理以及加载 MNIST 数据集。基于 PyTorch 有很多工具集,比如:处理自然语言的 torchtext,处理音频的 torchaudio 和 处理图像视频的 torchvision,这些工具集可以独立于 PyTorch 的使用。MNIST 数据集属于图像,我们可以在 torchvision.datasets 包中加载 MNIST。「加载的 MNIST 数据集是 ndarray 数组类型,因此我们需要将其转换成 Tensor。实验证明输入数据在 0 附近均匀分布,神经网络模型会有所提升(在本小节的神经网络模型架构下,对数据进行标准化准确率能够提升 10%),因此我们还需要对 MNIST 数据集进行标准化的转换,torchvision.transforms 包提供了这些转换方法。」
import torchvision
train_data = torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
]))
print(len(train_data)) # 60000
# 训练集中的第1张手写数字图片以及对应的标签
X_train_0, label_train_0 = train_data[0]
print(X_train_0.shape) # torch.Size([1, 28, 28])
print(label_train_0) # 5
在 torchvision.datasets 中有很多类似 MNIST 的数据集,下面来简单介绍 torchvision.datasets.MNIST 中的一些参数:
./mnist_data/MNIST/processed/training.pt
中加载训练集(使用 len(train_data) 可以看出共有 60000 张手写数字图片)。如果设置为 False,则从 ./mnist_data/MNIST/processed/test.pt
中加载测试集;加载完了 MNIST 数据集中的训练集,我们可以设置 train = False 来加载 10000 张测试集。
import torchvision
test_data = torchvision.datasets.MNIST('mnist_data', train = False, download = True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
]))
print(len(test_data)) # 10000
# 测试集中的第1张手写数字图片以及对应的标签
X_test_0, label_test_0 = test_data[0]
print(X_test_0.shape) # torch.Size([1, 28, 28])
print(label_test_0) # 7
至此 60000 张训练集以及 10000 张测试集都加载进来了,不过我们通常使用更为方便的数据集加载器 DataLoader,DataLoader 结合了数据集和取样器,提供了多个线程处理数据集,并且里面提供了很多方便处理数据集的功能。DataLoader 在 torch.utils.data 包下。
import torch
import utils # 加载我们自己写的工具类
batch_size = 512
train_loader = torch.utils.data.DataLoader(train_data,
batch_size = batch_size, # batch_size
shuffle = True) # 是否打乱数据集
test_loader = torch.utils.data.DataLoader(test_data,
batch_size = batch_size,
# 测试集只用于验证模型性能不需要打乱数据集
shuffle = False)
# 迭代器加载数据集,每次都加载batch_size个
# X: [batch_size, channel, width, hight]
# label: 数字编码
X, label = next(iter(train_loader))
print(X.shape, label.shape, X.min(), label.max())
utils.plot_image(X, label, 'image sample')
torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(9)
References: 1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm
本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!