pytorch学习笔记(六):自定义Datasets

什么是Datasets:

输入流水线中,我们看到准备数据的代码是这么写的data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

为什么要定义Datasets:

PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。

如何自定义Datasets

下面是一个自定义Datasets的框架

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

下面看一下官方MNIST的例子(代码被缩减,只留下了重要的部分):

class MNIST(data.Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return 60000
        else:
            return 10000

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏素质云笔记

极简主义︱利用apple机器学习平台Turicreate实现图像相似性检索(二)

apple开源机器学习框架turicreate内容非常广阔,本篇介绍冰山一角的图像相似,极简主义的代表!!!! github:https://apple.g...

2678
来自专栏Hongten

让你一看就明白什么是单列模式(和静态静态工厂模式)--java版本_源码下载

=================================================

361
来自专栏BestSDK

Pytorch 0.3.0 发布:新增张量函数,支持模型移植

根据官方文档的介绍,此次增加了多个函数和功能,多方面的性能均实现提升。 重大变更 0.3 版本中删掉了 Variable.reinforce() 等随机函数,因...

3418
来自专栏企鹅号快讯

深度学习系列教程(六)tf.data API 使用方法介绍

"玩转TensorFlow与深度学习模型”系列文字教程,本周带来tf.data 使用方法介绍! 大家在学习和实操过程中,有任何疑问都可以通过学院微信交流群进行提...

2467
来自专栏UAI人工智能

实战派 | PaddlePaddle 你其实也可以真正地上手

1574
来自专栏TensorFlow从0到N

TensorFlow从1到2 - 5 - 非专家莫入!TensorFlow实现CNN

当看到本篇时,根据TensorFlow官方标准《Deep MNIST for Experts》,你已经达到Expert Level,要恭喜了。 且不说是否夸大...

6229
来自专栏瓜大三哥

HLS Lesson8-基本操作

1.算术操作 ? 如果是定点数处理时候,需要遵循的原则是:大数据不溢出,小数据不损失 2.算数赋值 ? ? #include<iostream> #includ...

2187
来自专栏WD学习记录

数据结构与算法2016-06-01

1.数据结构包括数据的逻辑结构和物理结构。数据的逻辑结构就是从具体问题抽象出来的数学模型,是为了讨论问题的方便,与数据在计算机中的具体存储没有关系。讨论数据结构...

782
来自专栏算法channel

动态规划中篇:爬楼梯

主要推送关于对算法的思考以及应用的消息。培养思维能力,注重过程,挖掘背后的原理,刨根问底。本着严谨和准确的态度,目标是撰写实用和启发性的文章,欢迎您的关注。 0...

3579
来自专栏专知

【干货】使用TensorFlow官方Java API调用TensorFlow模型(附代码)

3264

扫码关注云+社区