前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch-数据处理流程

PyTorch-数据处理流程

作者头像
GIS指北
发布2022-11-14 11:00:07
6760
发布2022-11-14 11:00:07
举报
文章被收录于专栏:GIS指北

在此记录对PyTorch框架学习过程中的思考。

数据加载处理是深度学习模型训练的前奏,是很重要的一部分。这一过程需要把原始数据,影像或者文本等进行封装、转换,并以合适的格式传递给模型。这个过程依赖torch.utils.data模块,常用以上三个类:

torch.utils.data.Dataset

torch.utils.data.Sampler

torch.utils.data.DataLoader

01

三者关系

三者的关系可以表示如下图:

三个类形成对数据的层层封装。

Dataset对原始数据进行封装,暴露数据提取的接口。

Sampler决定了采样策略,根据不同索引方式来从Dataset中提取部分数据。

DataLoader通过封装Dataset和Sampler,设定batch_size等参数,构造了方便快速遍历的mini batch数据集。

02

Dataset

Dataset是一个抽象类,迭代器。负责对原始数据进行封装,形成模型可以识别的数据结构,其暴露了获取单个数据的接口。

Dataset有两种:Map-style datasets 和 Iterable-style datasets

torch.utils.data.Dataset通过实现__len__()和__getitem__()来获取数据。

torch.utils.data.IterableDataset 通过实现__iter()__来获取数据。

可以通过集成Dataset类来自定义自己的数据集,如下示例,通过改写__getitem__()方法自定义提取数据,可以在其中加入数据增强的方法。

代码语言:javascript
复制
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

03

Sampler

Sampler定义了对数据集的采样策略。通过Sampler类中的__iter__()方法来获取数据集的索引,其基类如下。

代码语言:javascript
复制
class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

所有子类都继承自Sampler,通过改写__iter__()方法来实现。

torch.utils.data.SequentialSampler 指定顺序采样样本。

torch.utils.data.RandomSampler 随机采样,可指定是否放回样本

torch.utils.data.DistributeSampler 数据加载限制为数据集子集,每个进程都可以把一个DistributeSampler实例作为DataLoader采样器传递

torch.utils.data.BatchSampler 在一个小batch中封装一个Sampler,返回小batch的索引

04

DataLoader

DataLoader是数据加载的核心,它对Dataset和Sampler进行封装,以mini batch的形式加载数据。支持单进程和多进程.

代码语言:javascript
复制
torch.utils.data.DataLoader(dataset, batch_size=1, 
          shuffle=False, sampler=None, batch_sampler=None, 
          num_workers=0, collate_fn=None, pin_memory=False, 
          drop_last=False, timeout=0, worker_init_fn=None, 
          multiprocessing_context=None, generator=None, *, 
          prefetch_factor=2, persistent_workers=False)

DataLoader是数据加载的核心,它对Dataset和Sampler进行封装,以mini batch的形式加载数据。支持单进程和多进程.

Dataset,加载的数据集,Dataset实例

batch_size,每个batch的样本数

shuffle:设置为True,在每个epoch开始前,都会随机抽取数据,调用了RandomSampler

sampler:定义从数据集的抽取策略,指定了sampler,shuffle必须为False

batch_sampler:和sampler功能一样,传入BatchSampler,和batch_size

shuffle互斥。

num_workers:指定进程数。默认0,只在主进程加载数据

drop_last:True的话,会删除最后一个不完整的batch数据。

总结来讲,DataLoader通过Sampler定义的索引策略,从Dataset中遍历提取数据。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-05-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GIS指北 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档