pytorch学习笔记(六):自定义Datasets

什么是Datasets:

输入流水线中,我们看到准备数据的代码是这么写的data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

为什么要定义Datasets:

PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。

如何自定义Datasets

下面是一个自定义Datasets的框架

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

下面看一下官方MNIST的例子(代码被缩减,只留下了重要的部分):

class MNIST(data.Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return 60000
        else:
            return 10000

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏落影的专栏

GPUImage详细解析(九)图像的输入输出和滤镜通道

前言 GPUImage系列解析已经接近尾声,这次介绍的是: 纹理输入输出GPUImageTextureOutput 和 GPUImageTextureOutpu...

4588
来自专栏wOw的Android小站

[Tensorflow] 在Android运行TensorFlow模型

以下代码来自于TensorFlowObjectDetectionAPIModel.java

2231
来自专栏应兆康的专栏

100个Numpy练习【5】

翻译:YingJoy 网址: https://www.yingjoy.cn/ 来源: https://github.com/rougier/numpy-100...

57412
来自专栏Flutter入门

Android OpenGL ES(六) - 将输入源换成视频

上文中,我们是将相机提供的预览画面输入到Surface当中,然后进行滤镜处理和录制。 那我们可以切换输入源为视频吗?

1554
来自专栏大数据文摘

你的数据科学python编程能力过关吗?看看这40道题你能得几分

1303
来自专栏生信技能树

比对NR库看看物种分布【直播】我的基因组88

前面我提前了我的基因组测序数据里面的未成功比对到人类基因组上面的那些fastq序列,也用了软件把它们组装成fasta序列,这些序列的功能是未知的,可以通过比对到...

4228
来自专栏CDA数据分析师

excel隔行选取新方法,99.99%的人没用过!

文 | 赵志东 有一个excel表格,需要隔N行选取。可能是隔一行,也可能是隔2行、3行。 ? 常见的有两种方法: 一、按ctrl键不松用鼠标逐行选取。如果需要...

1766
来自专栏应兆康的专栏

100个Numpy练习【5】

Numpy是Python做数据分析必须掌握的基础库之一,非常适合刚学习完Numpy基础的同学,完成以下习题可以帮助你更好的掌握这个基础库。

53810
来自专栏月色的自留地

Metal并行计算以及Metal程序的命令行编译

2664
来自专栏Golang语言社区

抽奖问题分析

普通抽奖问题问题描述 用户随机抽奖,数据如下: // map中,key代表用户名,value代表成用户下单数 var users map[string]int...

3313

扫码关注云+社区