前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch4:模块总览&torch.utils.data

PyTorch4:模块总览&torch.utils.data

作者头像
小胡胡说
修改2020-08-07 09:53:40
1K0
修改2020-08-07 09:53:40
举报
文章被收录于专栏:小胡学编程小胡学编程

1. Pytorch模块总览


相比TensorFlow,PyTorch 是非常轻量级的:相比 TensorFlow 追求兼容并包,PyTorch 把外围功能放在了扩展包中,比如torchtext,以保持主体的轻便。

根据PyTorch 的 API,可知其核心大概如下:

  • torch.nn & torch.nn.functional:构建神经网络
  • torch.nn.init:初始化权重
  • torch.optim:优化器
  • torch.utils.data:载入数据

可以说,掌握了上面四个模块和前文中提到的底层 API,至少 80% 的 PyTorch 任务都可以完成。剩下的外围事物则有如下的模块支持:

  • torch.cuda:管理 GPU 资源
  • torch.distributed:分布式训练
  • torch.jit:构建静态图提升性能
  • torch.tensorboard:神经网络的可视化

如果额外掌握了上面的四个的模块,PyTorch 就只剩下一些边边角角的特殊需求了。

2.torch.utils.data


这个功能包的作用是收集、打包数据,给数据索引,然后按照 batch 将数据分批喂给神经网络。

数据读取的核心是 torch.utils.data.DataLoader 类。它是一个数据迭代读取器,支持

  • 映射方式和迭代方式读取数据;
  • 自定义数据读取顺序;
  • 自动批;
  • 单线程或多线程数据读取;
  • 自动内存定位。

所有上述功能都可以在 torch.utils.data.DataLoader 的变量中定义:

代码语言:javascript
复制
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

最重要的变量为 dataset,它指明了数据的来源。

DataLoader 支持两种数据类型:

  • 映射风格的数据封装(map-style datasets):这种数据结构拥有自定义的 __getitem__()__len__() 属性,可以以“索引/值”的方式读取数据,对应 torch.utils.data.Dataset 类;
  • 迭代风格的数据封装(iterable-style datasets):这种数据结构拥有自定义的 __iter__() 属性,通常适用于不方便随机获取数据或不定长数据集的读取上,对应 torch.utils.data.IterableDataset 类。

下面我们从顶层的 torch.utils.data.DataLoader 开始,然后一步一步深入到自定义的细节上。为了方便讨论,我们先人工构建一个数据集:

代码语言:javascript
复制
>>> samples = torch.arange(100)
>>> labels = torch.cat([torch.zeros(50), torch.ones(50)], dim=0)

2.1 torch.utils.data.DataLoader 数据加载器


首先看一下常用的变量:

  • dataset:数据源;
  • batch_size:一个整数,定义每一批读取的元素个数;
  • shuffle:一个布尔值,定义是否随机读取;
  • sampler:定义获取数据的策略,必须与 shuffle 互斥;
  • num_workers:一个整数,读取数据使用的线程数;
  • collate_fn:一个将读取的数据处理、聚合成一个一个 batch 的自定义函数;
  • drop_last:一个布尔值,如果最后一批数据的个数不足 batch 的大小,是否保留这个 batch。

datasetsamplercollate_fn 是自定义的类或功能,我们从后往前看。

2.2  数据集的分割


在介绍这三个变量以前,我们先看看如何将数据集分割,比如分成训练集和测试集。

  • torch.utils.data.Subset(dataset, indices)

这个函数可以根据索引indices将数据集dataset分割。

代码语言:javascript
复制
>>> even = [i for i in range(100) if i % 2 == 0]
>>> new1 = torch.utils.data.Subset(samples, even)
>>> print(new1[:5])
tensor([0, 2, 4, 6, 8])
  • torch.utils.data.random_split(dataset, lengths)

先将数据随机排列,然后按照指定的长度进行选择。长度的和必须等于数据集中的数据数量。

代码语言:javascript
复制
>>> train, test = torch.utils.data.random_split(samples, [90, 10])
>>> print(torch.tensor(test))
tensor([79, 60, 98, 74, 31, 43, 21, 69, 55, 76])

2.3. collate_fn 核对函数


这个变量的功能是在数据被读取后,送进模型前对所有数据进行处理打包

比如我们有一个不定长度的视频数据集或文本数据集,我们可以自定义一个函数将它们的长度归一化。比如:

代码语言:javascript
复制
>>> a = [[1,2,3],[4,5],[6,7,8,9]]
>>> def collate_fn(data):
...     '''
...     padding data, so they have same length.
...     '''
...     max_len = max([len(feature) for feature in data])
...     new = torch.zeros(len(data), max_len)
    
...     for i in range(len(data)):
...         tmp = torch.as_tensor(data[i])
...         j = len(tmp)
...         new[i][:j] = tmp
        
...     return new

>>> collate_fn(a)
tensor([[1., 2., 3., 0.],
        [4., 5., 0., 0.],
        [6., 7., 8., 9.]])

将这个函数赋值给 collate_fn,在读取数据的时候就可以自动对数据进行 padding 并打包成一个 batch。

2.4 sampler 采样器


这个变量决定了数据读取的顺序。

注意,sampler 只对 iterable-style datasets 有效。

除了可以自定义采样器,Python 内置了几种不同的采样器:

  • torch.utils.data.SequentialSampler(data_source)默认的采样器。
  • torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)

随机选择数据。可以指定一次读取 num_samples 个数据。replacementTrue 的话可以指定 num_samples

代码语言:javascript
复制
>>> batch = torch.utils.data.RandomSampler(samples, replacement=True, num_samples=5) # 生成一个迭代器
>>> print(list(batch))
[85, 70, 5, 63, 79]

还有三个采样器无法独立使用,必须先实例化,然后放进 DataLoader

  • torch.utils.data.SubsetRandomSampler(indices):先按照索引选取数据,然后随机排列。
  • torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True):字面意思是按照概率选择不同类别的元素。
  • torch.utils.data.BatchSampler(sampler, batch_size, drop_last):在一个 batch 中应用另外一个采样器。

2.5 dataset 数据集生成器


  • torch.utils.data.Dataset

这个类需要覆写 __getitem____len__ 属性。

代码语言:javascript
复制
>>> class MyData(torch.utils.data.Dataset):
...     def __init__(self, data):
...         super(MyData, self).__init__()
...         self.data = data
        
...     def __len__(self, data):
...         return len(self.data)
    
...     def __getitem__(self, index):
...         return self.data[index]
        
>>> mydata = MyData(samples)
>>> mydata[0]
tensor(0)
>>> mydata[10:15]
tensor([10, 11, 12, 13, 14])

除此以外,还有若干个 wrapper:

  • torch.utils.data.IterableDataset
  • torch.utils.data.TensorDataset(*tensors)
  • torch.utils.data.ConcatDataset(datasets)
  • torch.utils.data.ChainDataset(datasets)

2.6 总结


选择让我们把所有知识应用一下。假设我们想以 10 为一个 batch,随机选择数据:

代码语言:javascript
复制
>>> train = data.TensorDataset(torch.as_tensor(samples), torch.as_tensor(labels))
>>> ds = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)

>>> for _ in range(5):
...     print(iter(ds).next())
[tensor([35, 19, 99, 58, 59, 10, 26, 86, 24, 25]), tensor([0., 0., 1., 1., 1., 0., 0., 1., 0., 0.])]
[tensor([ 6, 37, 24, 98, 96, 18, 88, 90, 19, 87]), tensor([0., 0., 0., 1., 1., 0., 1., 1., 0., 1.])]
[tensor([80, 75, 48, 34, 90, 67,  8, 63, 47, 32]), tensor([1., 1., 0., 0., 1., 1., 0., 1., 0., 0.])]
[tensor([48, 68, 64, 54, 87, 76, 18, 53, 65, 17]), tensor([0., 1., 1., 1., 1., 1., 0., 1., 1., 0.])]
[tensor([65, 26, 67,  5,  4,  8, 35, 47, 40, 96]), tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 1.])]

本文系转载,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文系转载前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. Pytorch模块总览
  • 2.torch.utils.data
    • 2.1 torch.utils.data.DataLoader 数据加载器
      • 2.2  数据集的分割
        • 2.3. collate_fn 核对函数
          • 2.4 sampler 采样器
            • 2.5 dataset 数据集生成器
              • 2.6 总结
              相关产品与服务
              批量计算
              批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档