pytorch学习笔记(六):自定义Datasets

什么是Datasets:

输入流水线中,我们看到准备数据的代码是这么写的data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

为什么要定义Datasets:

PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。

如何自定义Datasets

下面是一个自定义Datasets的框架

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

下面看一下官方MNIST的例子(代码被缩减,只留下了重要的部分):

class MNIST(data.Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return 60000
        else:
            return 10000

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI科技评论

开发 | TensorFlow全新的数据读取方式:Dataset API入门教程

AI科技评论按:本文作者何之源,该文首发于知乎专栏AI Insight (https://zhuanlan.zhihu.com/ai-insight),AI科技...

33650
来自专栏聊聊技术

原 初学图论-DAG单源最短路径算法

391120
来自专栏xingoo, 一个梦想做发明家的程序员

Spark源码分析之分区器的作用

最近因为手抖,在Spark中给自己挖了一个数据倾斜的坑。为了解决这个问题,顺便研究了下Spark分区器的原理,趁着周末加班总结一下~ 先说说数据倾斜 数据...

249100
来自专栏人工智能头条

TensorFlow全新的数据读取方式:Dataset API入门教程

25030
来自专栏LhWorld哥陪你聊算法

【推荐系统篇】--推荐系统之训练模型

经过之前的训练数据的构建可以得到所有特征值为1的模型文件,本文将继续构建训练数据特征并构建模型。

19210
来自专栏HTML5学堂

Javascript中的Label语句

HTML5学堂:在JavaScript中,我们可能很少会去用到 Label 语句,但是熟练的应用 Label 语句,尤其是在嵌套循环中熟练应用 break, c...

42070
来自专栏一英里广度一英寸深度的学习

Intellij idea配置Spark开发环境,统计哈姆雷特词频(2)

中间层Spark,即核心模块Spark Core,必须在maven中引用。 编译Spark还要声明java8编译工具。

23520
来自专栏章鱼的慢慢技术路

Direct3D 11 Tutorial 7:Texture Mapping and Constant Buffers_Direct3D 11 教程7:纹理映射和常量缓冲区

在上一个教程中,我们为项目引入了照明。 现在我们将通过向我们的立方体添加纹理来构建它。 此外,我们将介绍常量缓冲区的概念,并解释如何使用缓冲区通过最小化带宽使用...

10640
来自专栏xingoo, 一个梦想做发明家的程序员

在Java Web中使用Spark MLlib训练的模型

模型下载到本地,重新命名为xml。 可以看到默认四个特征分别叫做feild_0,field_1...目标为target

20520
来自专栏人工智能LeadAI

TensorFlow应用实战 | TensorFlow基础知识

hw = tf.constant("Hello World! Mtianyan love TensorFlow!")

18140

扫码关注云+社区

领取腾讯云代金券