首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch 中Datasets And DataLoaders的使用 | PyTorch系列(十二)

PyTorch 中Datasets And DataLoaders的使用 | PyTorch系列(十二)

作者头像
AI算法与图像处理
发布2020-05-25 17:45:17
1.3K0
发布2020-05-25 17:45:17
举报

文 |AI_study

在这篇文章中,我们将看到如何使用Dataset和DataLoader 的PyTorch类。

在这篇文章中,我们的目标是熟悉如何使用dataset 和 data loader 对象,并对我们的训练集有一个初步的了解。

从高层次的角度来看,我们的深度学习项目仍处于数据准备阶段。

  • 准备数据
  • 构建模型
  • 训练模型
  • 分析模型的结果

在这篇文章中,我们将看到如何使用我们在前一篇文章中创建的dataset 和 data loader对象。请记住,在前一篇文章中,我们有两个PyTorch对象、Dataset和 DataLoader。

  • train_set
  • train_loader

PyTorch Dataset:使用训练集

让我们先来看看我们可以执行哪些操作来更好地理解我们的数据。

探索数据

要查看我们的训练集中有多少图像,我们可以使用Python len()函数检查数据集的长度:

> len(train_set)
60000

这个60000的数字是有意义的,基于我们在 [Fashion-MNIST dataset](https://deeplizard.com/learn/video/EqpzfvxBx30)一文中所学到的。假设我们想查看每个图像的标签。可以这样做:

注意,torchvision API从版本0.2.1开始进行了更改。参见GitHub上的发布说明。

> https://github.com/pytorch/vision/releases/tag/v0.2.2

# Before torchvision 0.2.2
> train_set.train_labels
tensor([9, 0, 0, ..., 3, 0, 5])
# Starting with torchvision 0.2.2
> train_set.targets
tensor([9, 0, 0, ..., 3, 0, 5])

第一个图像是 9,接下来的两个是0。请记住,在以前的文章中,这些值编码实际的类名或标签。例如,9是短靴,而0是t恤。

如果我们想要查看数据集中每个标签的数量,我们可以像这样使用PyTorch bincount()函数:

注意,torchvision API从版本0.2.1开始进行了更改。参见GitHub上的发布说明。

https://github.com/pytorch/vision/releases/tag/v0.2.2

# Before torchvision 0.2.2
> train_set.train_labels.bincount()
tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])

# starting torchvision 0.2.2
> train_set.targets.bincount()
tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])
```

Class Imbalance: Balanced And Unbalanced Datasets

这向我们表明,Fashion-MNIST数据集在每个类的样本数量方面是一致的。这意味着我们每个类有6000个样本。因此,这个数据集被认为是平衡的。如果类具有不同数量的样本,我们将该集合称为不平衡数据集。

类别不平衡是一个常见的问题,但在我们的例子中,我们刚刚看到Fashion-MNIST数据集确实是平衡的,所以我们的项目不需要担心这个问题。

要了解更多关于在深度学习中减轻不平衡数据集的方法,请看这篇论文:卷积神经网络中的类不平衡问题的系统研究。

https://arxiv.org/abs/1710.05381

访问训练集中的数据

要访问训练集中的单个元素,我们首先将train_set对象传递给Python的iter()内置函数,该函数返回一个表示数据流的对象。

对于数据流,我们可以使用Python内置的next()函数来获取数据流中的下一个数据元素。从这我们期待得到一个单一的样本,所以我们将命名相应的结果:

> sample = next(iter(train_set))
> len(sample)
2

将样例传递给len()函数后,我们可以看到样例包含两个样本,这是因为数据集包含图像-标签对。我们从训练集中检索的每个样本都包含一个张量的图像数据和相应的张量标签。

由于样本是一个序列类型([sequence type](https://docs.python.org/3/library/stdtypes.html#typesseq)),我们可以使用序列解压( *sequence unpacking* )来分配图像和标签。现在我们将检查图像的类型和标签,看看他们都是 torch.Tensor 对象:

> type(image)
torch.Tensor

# Before torchvision 0.2.2
> type(label)
torch.Tensor

# Starting at torchvision 0.2.2
> type(label)
int

我们将检查形状,图像是一个1 x 28 x 28的张量,而标签是一个标量值的张量:

> image.shape
torch.Size([1, 28, 28])

> torch.tensor(label).shape
torch.Size([])

我们还将调用图像上的squeeze() 函数,以查看如何删除size 1的维度。

> image.squeeze().shape
torch.Size([28, 28])

同样,基于我们之前对Fashion-MNIST数据集的讨论,我们希望看到图像的28 x 28的形状。我们在张量的第一维看到1的原因是因为需要表示通道的数量。与有3个颜色通道的RGB图像相反,灰度图像只有一个颜色通道。这就是为什么我们有一个1×28×28张量。我们有一个颜色通道,大小是28x28。

现在我们来画出图像,我们会看到为什么一开始我们压缩了这个张量。我们首先压缩这个张量,然后把它传递给imshow() 函数。

> plt.imshow(image.squeeze(), cmap="gray")
> torch.tensor(label)
tensor(9)

我们拿回了一个ankle-boot 和一个9号的标签。我们知道标签9代表了ankle-boot,因为它是我们在前一篇文章中看到的那篇论文中指定的。

好吧。现在让我们看看如何使用数据加载器。

PyTorch DataLoader:处理批量数据

我们将开始创建一个新的数据加载器与较小的批处理大小为10,以便很容易演示发生了什么:

> display_loader = torch.utils.data.DataLoader(
    train_set, batch_size=10
)

我们以与训练集相同的方式从loader中 获得一个batch。我们使用iter() 和next() 函数。

使用数据加载器时要注意一件事。如果shuffle = True,则每次调用next时批次将不同。如果shuffle = True,则在第一次调用next时将返回训练集中的第一个样本。shuffle 函数默认情况下处于关闭状态。

# note that each batch will be different when shuffle=True
> batch = next(iter(display_loader))
> print('len:', len(batch))
len: 2

检查返回batch的长度,就像训练集一样,我们得到2。让我们拆开每个batch,看看两个张量及其形状:

> images, labels = batch

> print('types:', type(images), type(labels))
> print('shapes:', images.shape, labels.shape)
types: <class 'torch.Tensor'> <class 'torch.Tensor'>
shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])

由于batch_size = 10,我们知道我们正在处理一批10张图像和10个相应的标签。这就是为什么我们对变量名使用复数形式的原因。

类型是我们期望的张量。但是,形状与我们在单个样品中看到的形状不同。我们没有一个标量值作为标签,而是有一个带有10个值的一阶张量。张量中包含图像数据的每个维度的大小由以下每个值定义:

> (batch size, number of color channels, image height, image width)

批量大小为10,这就是为什么现在张量的第一个尺寸为10,每个图像是一个索引。以下是我们之前看到的第一个ankle-boot :

> images[0].shape
torch.Size([1, 28, 28])

> labels[0]
tensor(9)

要绘制一批图像,我们可以使用torchvision.utils.make_grid() 函数创建一个可以如下绘制的网格:

> grid = torchvision.utils.make_grid(images, nrow=10)

> plt.figure(figsize=(15,15))
> plt.imshow(np.transpose(grid, (1,2,0)))

> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

感谢Amit Chaudhary指出,可以使用 PyTorch张量方法 permute()代替np.transpose()。就像这样:

> grid = torchvision.utils.make_grid(images, nrow=10)

> plt.figure(figsize=(15,15))
> plt.imshow(grid.permute(1,2,0))

> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

回想一下,我们有下面这个表,它显示了标签映射到下面的类名:

Index

Label

0

T-shirt/top

1

Trouser

2

Pullover

3

Dress

4

Coat

5

Sandal

6

Shirt

7

Sneaker

8

Bag

9

Ankle boot

如何使用PyTorch DataLoader绘制图像

这里是另一个是使用PyTorch DataLoader来绘制图像。这个方法的灵感来自Barry Mitchell。享受吧!

how_many_to_plot = 20

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=1, shuffle=True
)

mapping = {
    0:'Top', 1:'Trousers', 2:'Pullover', 3:'Dress', 4:'Coat'
    ,5:'Sandal', 6:'Shirt', 7:'Sneaker', 8:'Bag', 9:'Ankle Boot'
}

plt.figure(figsize=(50,50))
for i, batch in enumerate(train_loader, start=1):
    image, label = batch
    plt.subplot(10,10,i)
    fig = plt.imshow(image.reshape(28,28), cmap='gray')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(mapping[label.item()], fontsize=28)
    if (i >= how_many_to_plot): break
plt.show()

下一步是构建模型

我们现在应该很好地理解了如何探索和与`Dataset`s and `DataLoader`交互。当我们开始构建卷积神经网络和训练回路时,这两种方法都将被证明是重要的。事实上,数据加载器将直接在我们的训练循环中使用。

让我们继续前进,因为我们已经准备好在下一篇文章中构建我们的模型。到时候见!

文章中内容都是经过仔细研究的,本人水平有限,翻译无法做到完美,但是真的是费了很大功夫,希望小伙伴能动动你性感的小手,分享朋友圈或点个“在看”,支持一下我 ^_^

英文原文链接是:

https://deeplizard.com/learn/video/mUueSPmcOBc

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI算法与图像处理 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • PyTorch DataLoader:处理批量数据
  • 如何使用PyTorch DataLoader绘制图像
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档