前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch之Dataset

pytorch之Dataset

作者头像
Tom2Code
发布2022-11-21 12:05:13
5530
发布2022-11-21 12:05:13
举报
文章被收录于专栏:Tom

一.前言

训练模型一般都是先处理 数据的输入问题 和 预处理问题 。Pytorch提供了几个有用的工具:torch.utils.data.Dataset 类和 torch.utils.data.DataLoader 类 。

而这也是我们在之前的文章里说过的三件套之一。

流程是先把原始数据转变成 torch.utils.data.Dataset 类,随后再把得到的 torch.utils.data.Dataset 类当作一个参数传递给 torch.utils.data.DataLoader 类,得到一个数据加载器,这个数据加载器每次可以返回一个 Batch 的数据供模型训练使用。

所以整体的流程是

数据=》Datasets=》DataLoader

在 pytorch 中,提供了一种十分方便的数据读取机制,即使用 torch.utils.data.Dataset 与 Dataloader 组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个 batch 数据,并能在输出时对数据进行相应的预处理或数据增广操作。

二.Datasets类

如果我们要自己定义一个读取数据的方法,就得继承torch.utils.data.Dataset这个父类,并且需要重写两个方法

我们可以看一下Dataset父类的源码:

代码语言:javascript
复制
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

可以清楚的看到我们需要重写两个方法分别是getitem和len方法

下面我们自定义我们自己的数据读取类

代码语言:javascript
复制
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms

#继承data.Dataset
# __init__方法
# __getitem__必须创建,作用:对数据切片
#__len__必须创建,作用:返回对象长度
class Tomdataset(data.Dataset):
    def __init__(self,root):
        self.imgs_path=root
    def __getitem__(self,index):
        img_path=self.imgs_path[index]
        return img_path
    def __len__(self):
        return len(Self.imgs_path)

这里的Tomdataset类就简单的实现了一下数据加载类的方法

因为可以重写方法的实现,所以我觉得可玩性还是很高的,比如在getitem方法内我们不仅可以返回一个单纯的元素,如果在构造方法中有其他的参数也可以一并返回。所以客制性很高,可玩性也很高。

在后面的文章中,我们会使用Tomdataset类对我们的数据进行加载和处理。

未完,待续

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-06-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Tom的小院 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档