前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >PyTorch从入门到放弃之数据模块

PyTorch从入门到放弃之数据模块

作者头像
愷龍
发布于 2024-09-07 04:42:18
发布于 2024-09-07 04:42:18
10000
代码可运行
举报
文章被收录于专栏:愷龍的Blog愷龍的Blog
运行总次数:0
代码可运行

Dataset 和 DataLoader 都 是 用 来 帮 助 我 们 加 载 数 据 集 的 两 个 重 要 工 具类。 Dataset 用来构造支持索引的数据集。 在训练时需要在全部样本中拿出小批量数据参与每次的训练,因此我们需要使用 DataLoader ,即 DataLoader 是用来在 Dataset 里取出一组数据 (mini-batch)供训练时快速使用的。

Dataset简介及用法

Dataset 本质上就是一个抽象类,可以把数据封装成 Python 可以识别的数据结构。Dataset 类不能实例化,所以在使用 Dataset 的时候,我们需要定义自己的数据集类,也是 Dataset 的子类,来继承 Dataset 类的属性和方法。Dataset 可作为 DataLoader 的参数传入 DataLoader ,实现基于张量的数据预处理。Dataset 主要有两种类型,分别为 Map-style datasets 和 Iterable-style datasets 。

Map-style datasets类型

该类型实现了 getitem() 和 len() 方法,它代表数据的索引到真正数据样本的映射。也就是说,使用这种方式读取的数据并非直接直接把所有数据读取出来,而是读取数据的索引或者键值。其中,列表或者数组类型的数据读取的就是索引,而字典类型的数据读取的就是键值。在访问时,用dataset[idx]访问idx对应的真实数据。这种类型的数据也是使用最多的类型。

Iterable-style datasets类型

该类型实现了 iter() 方法,与上述类型不同之处在于,他会将真实的数据全部载入,然后在整个数据集上进行迭代。如果随机读取的情况不能实现或者代价太大就用这种读取方式。这种读取数据的方式比较适合处理流数据

Dataset 作为一个抽象类,需要定义其子类来实例化。所以需要自己定义其子类或者使用已经定义好的子类。

(1)自定义子类

  • 必须要继承已经内置的抽象类 dataset
  • 必须要重写其中的 init() 方法、 getitem() 方法和 len() 方法
  • 其中 getitem() 方法实现通过给定的索引遍历数据样本, len() 方法实现返回数据的条数

定义一个MyDataset类继承Dataset抽象类,其中pass为占位符,并且改写其中的三个方法

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    
    def __init__(self):
        pass
    
    def __getitem__(self, index):
        pass
    
    def __len__(self):
        pass

这里定义了一个MyDataset类继承Dataset抽象类,并且改写其中的三个方法。在创建的dataset类中可根据用户本身的需求对数据进行处理。可独立编写的数据处理函数,在__getitem__()函数中进行调用;或者直接将数据处理方法写在__getitem__()函数中或者__init__()函数中,但__getitem__()函数必须根据index返回响应的值,该值会通过index传到DataLoader中进行厚涂的Batch批量处理。

在创建的dataset类中可根据自己的需求对数据进行处理,以时间序列使用为示例,输入3个时间步,输出1个时间步,batch_size=5

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch 
from torch.utils.data import Dataset

class GetTrainTestData(Dataset):
    def __init__(self, input_len, output_len, train_rate, is_train=True):
        super().__init__()
        # 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
        # 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里
        self.x = torch.sin(torch.arange(0, 1000, 0.1))
        self.sample_num = len(self.x)
        self.input_len = input_len
        self.output_len = output_len
        self.train_rate = train_rate
        self.src, self.trg = [], []
        if is_train:
            for i in range(int(self.sample_num*train_rate)-self.input_len-self.output_len):
                self.src.append(self.x[i:(i+input_len)])
                self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])
        else:
            for i in range(int(self.sample_num*train_rate), self.sample_num-self.input_len-self.output_len):
                self.src.append(self.x[i:(i+input_len)])
                self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])
        print(len(self.src), len(self.trg))

    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
        return len(self.src)  # 或者return len(self.trg), src和trg长度一样

实例化定义好的Dataset子类GetTrainTestData

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
data_train = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=True)
data_test = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=False)

(2)已经定义好的内置子类

除了自己定义子类继承Dataset外,还可以使用PyTorch提供的已经被定义好的子类,如TensorDataset和IterableDataset。

对 于 给 定 的 tensor 数 据 , TensorDataset 是 一 个 包 装 了 Tensor 的Dataset 子类,传入的参数就是张量,每个样本都可以通过 Tensor 第一个维度的索引获取,所以传入张量的第一个维度必须一致。

PyTorch官方给出的TensorDataset类的定义:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

所以这个类的实例化有两个参数,分别为data_tensor(Tensor)样本数据和target_tensor(Tensor)样本标签。

使用TensorDataset:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
from torch.utils.data import TensorDataset

src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))

于是可以直接实例化已定义好的Dataset子类TensorDataset

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
data = TensorDataset(src, trg)

DataLoader简介及用法

Dataset 和 DataLoader 是一起使用的,在模型训练的过程中不断为模型提供数据,同时,使用 Dataset 加载出来的数据集也是 DataLoader 的第一个参数。所以, DataLoader 本质上就是用来将已经加载好的数据以模型能够接收的方式输入到即将训练的模型中去。

几个深度学习模型训练时涉及的参数:

(1)Data_size:所有数据的样本数量。

(2)Batch_size:每个Batch加载多少个样本。

(3)Batch:每一批放进module训练的样本叫一个Batch。

(4)Epoch:模型把所有样本训练完毕一次叫做一个Epoch。

(5)Iteration:所有数据共分成了几个Batch,即训练几次才能够便利所有样本/数据。

(6)Shuffle:在抽取Batch之前是否将样本全部打乱顺序。

数据的输入过程如下图所示。

Data_size=10 , Batch_size=3 ,一次 Epoch 需要四次 Iteration ,第一列为所有样本,第二列为打乱之后的所有样本,由于 Batch_size=3 ,所以通过 DataLoader输入了 4 个 batch ,包括最后一个数量已经不够 3 个的 Batch4 ,里边只包含sample3

官方给出的DataLoader定义:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None,*, prefetch_factor=2,
           persistent_workers=False)

参数说明:

dataset: 通过Dataset加载进来的数据集。

batch_size:每个Batch加载多少个样本。

shuffle: 是否打乱输入数据的顺序,设置为True时,调用RandomSample进行随机索引。

sampler: 定义从数据集中提取样本的策略,若指定,就不能用shuffle函数随机索引,其取值必须为False。

batch_sampler: 批量采样,每次返回一个Batch大小的索引,默认设置为None,和batch_size、shuffle等参数是互斥的。

num_workers: 用多少子进程加载数据。0表示数据将在主进程中加载,根据自己的计算资源配置选定。

collate_fn: 将一小段数据合并成数据列表以形成一个Batch。

pin_memory:是否在将张量返回之前将其复制到Cuda固定的内存中。

drop_last: 设置了batch_size的数目后,最后一批数据未必是设置的数目,有可能会小一些,这时需要丢弃这些数据。

timeout:设置数据表读取的超时时间,但超过这个时间还没读取到数据就会报错,不能为负。

worker_init_fn:是否在数据导入前和步长结束后根据工作子进程的ID逐个按照顺序导入数据,默认为None。

prefetch_factor:每个worker提前加载的Sample数量。

persistent_workers: 如果为True,DataLoader将不会终值worker进程,直到dataset迭代完成。

将Dataset读取的数据输入到DataLoader中。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

class GetTrainTestData(Dataset):
    def __init__(self, input_len, output_len, train_rate, is_train=True):
        super().__init__()
        # 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
        # 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里
        self.x = torch.sin(torch.arange(1, 1000, 0.1))
        self.sample_num = len(self.x)
        self.input_len = input_len
        self.output_len = output_len
        self.train_rate = train_rate
        self.src,  self.trg = [], []
        if is_train:
            for i in range(int(self.sample_num*train_rate)-self.input_len-self.output_len):
                self.src.append(self.x[i:(i+input_len)])
                self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])
        else:
            for i in range(int(self.sample_num*train_rate), self.sample_num-self.input_len-self.output_len):
                self.src.append(self.x[i:(i+input_len)])
                self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])
        print(len(self.src), len(self.trg))

    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
        return len(self.src)  # 或者return len(self.trg), src和trg长度一样


data_train = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=True)
data_test = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=False)
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
for idx, train in enumerate(data_loader_train):
    print(idx, train)
    break
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-09-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Dataset和DataLoader
Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。
lyhue1991
2020/07/20
2.4K0
Dataset和DataLoader
PyTorch中 Datasets & DataLoader 的介绍
用于处理数据样本的代码可能很快就会变得混乱且难以维护。理想情况下,为了获得更好的可读性和模块化,我们希望处理数据集的代码与模型训练代码分离。
JOYCE_Leo16
2024/04/16
2430
PyTorch中 Datasets & DataLoader 的介绍
Pytorch实现基于卷积神经网络的面部表情识别(详细步骤)「建议收藏」
另外,我整理了整个项目的精简版本,完整代码,开箱即用,教程详细,方便快捷!下载:Pytorch实现基于卷积神经网络的面部表情识别项目源码
全栈程序员站长
2022/08/29
1.3K0
Pytorch实现基于卷积神经网络的面部表情识别(详细步骤)「建议收藏」
【转载】PyTorch系列 (二): pytorch数据读取
本文首先介绍了有关预处理包的源码,接着介绍了在数据处理中的具体应用; 其主要目录如下:
marsggbo
2019/03/08
2.1K0
PyTorch 源码解读之 torch.utils.data:解析数据处理全流程
来源丨https://zhuanlan.zhihu.com/p/337850513
BBuf
2021/07/01
1.6K0
pytorch-DataLoader(数据迭代器)
本博客讲解了pytorch框架下DataLoader的多种用法,每一种方法都展示了实例,虽然有一点复杂,但是小伙伴静下心看一定能看懂哦 :)
全栈程序员站长
2022/07/01
1.3K0
PyTorch 小课堂开课啦!带你解析数据处理全流程(一)
最近被迫开始了居家办公,这不,每天认真工(mo)作(yu)之余,也有了更多时间重新学习分析起了 PyTorch 源码分享,属于是直接站在巨人的肩膀上了。在简单捋一捋思路之后,就从 torch.utils.data 数据处理模块开始,一步步重新学习 PyTorch 的一些源码模块解析,希望也能让大家重新认识已经不陌生的 PyTorch 这个小伙伴。
OpenMMLab 官方账号
2022/05/25
1.1K0
PyTorch 小课堂开课啦!带你解析数据处理全流程(一)
【深度学习】Pytorch 教程(十四):PyTorch数据结构:6、数据集(Dataset)与数据加载器(DataLoader):自定义鸢尾花数据类
  Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。
Qomolangma
2024/07/30
2010
【深度学习】Pytorch 教程(十四):PyTorch数据结构:6、数据集(Dataset)与数据加载器(DataLoader):自定义鸢尾花数据类
AI:使用pytorch通过BERT模型进行文本分类
BERT 是一个强大的语言模型,至少有两个原因:它使用从 BooksCorpus (有 8 亿字)和 Wikipedia(有 25 亿字)中提取的未标记数据进行预训练。它是通过利用编码器堆栈的双向特性进行预训练的。这意味着 BERT 不仅从左到右,而且从右到左从单词序列中学习信息。
Freedom123
2024/03/29
1.3K1
AI:使用pytorch通过BERT模型进行文本分类
小白学PyTorch | 8 实战之MNIST小试牛刀
在这个文章中,主要是来做一下MNIST手写数字集的分类任务。这是一个基础的、经典的分类任务。建议大家一定要跟着代码做一做,源码和数据已经上传到公众号。回复【pytorch】获取数据和源码哦~
机器学习炼丹术
2020/09/14
7910
小白学PyTorch | 8 实战之MNIST小试牛刀
小白学PyTorch | 3 浅谈Dataset和Dataloader
PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
机器学习炼丹术
2020/09/03
2.4K0
小白学PyTorch | 3 浅谈Dataset和Dataloader
Pytorch划分数据集的方法
之前用过sklearn提供的划分数据集的函数,觉得超级方便。但是在使用TensorFlow和Pytorch的时候一直找不到类似的功能,之前搜索的关键字都是“pytorch split dataset”之类的,但是搜出来还是没有我想要的。结果今天见鬼了突然看见了这么一个函数torch.utils.data.Subset。我的天,为什么超级开心hhhh。终于不用每次都手动划分数据集了。 torch.utils.data Pytorch提供的对数据集进行操作的函数详见:https://pytorch.or
marsggbo
2019/05/26
4.5K0
【深度学习】Pytorch 教程(十五):PyTorch数据结构:7、模块(Module)详解(自定义神经网络模型并训练、评估)
  Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。
Qomolangma
2024/07/30
3990
【深度学习】Pytorch 教程(十五):PyTorch数据结构:7、模块(Module)详解(自定义神经网络模型并训练、评估)
pytorch笔记
pytorch刚上手确实不太容易适应。特别是Andrew给出的1.x的tensorflow代码,和当前torch的差异还是很大的。这里的用法挺琐碎的,用作备忘性质。
Sarlren
2022/10/28
2820
pytorch demo 实践
相关环境 python opencv pytorch ubuntu 14.04 pytorch 基本内容 60分钟快速入门,参考:https://blog.csdn.net/u014630987/article/details/78669051 需要学习的内容包括 1、基本概念Tensors、Variable、Numpy等 2、如何搭建神经网络模型(包括卷积神经网络) 3、如何定义损失函数和优化器(包括不同分类器和优化器的含义) 4、如何训练(包括如何读取数据、如何在GPU上
张俊怡
2018/04/24
2.1K0
pytorch demo 实践
时间序列数据建模流程范例
最开始在学习神经网络,PyTorch 的时候,懂的都还不多,虽然也知道 RNN, CNN 这些网络的原理,但真正自己实现起来又是另一回事,代码往往也都是从网上 copy 过来然后再自己魔改的,这也就导致了一系列的问题,代码格式不统一,没弄懂具体实现细节等等。当然,凭这些 copy 过来的代码让模型运行起来还是不难的,你只需要知晓一定的原理。显而易见,这些时间往往最后都是要“还”的。
EmoryHuang
2022/10/31
1.2K0
时间序列数据建模流程范例
【深度学习实验】前馈神经网络(七):批量加载数据(直接加载数据→定义类封装数据)
在本系列先前的代码中,借助深度学习框架的帮助,已经完成了前馈神经网络的大部分功能。本文将基于鸢尾花数据集构建一个数据迭代器,以便在每次迭代时从全部数据集中获取指定数量的数据。(借助深度学习框架中的Dataset类和DataLoader类来实现此功能)
Qomolangma
2024/07/30
1580
【深度学习实验】前馈神经网络(七):批量加载数据(直接加载数据→定义类封装数据)
基于PyTorch深度学习框架的序列图像数据装载器
如今,深度学习和机器学习算法正在统治世界。PyTorch是最常用的深度学习框架之一,用于实现各种深度学习算法。另一方面,基于学习的方法本质上需要一些带注释的训练数据集,这些数据集可以被模型用来提取输入数据和标签之间的关系。为了给神经网络提供数据,我们定义了一个数据加载器。
磐创AI
2021/09/03
6140
pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类
Pytorch深度学习框架优势之一是python优先,源代码由python代码层和C语言代码层组成,一般只需要理解python代码层就可以深入理解pytorch框架的计算原理。所以学习pytorch源码需要熟练掌握python语言的各种使用技巧。
全栈程序员站长
2022/08/24
1K0
pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类
Pytorch(五)入门:DataLoader 和 Dataset
构建模型的基本方法,我们了解了。 接下来,我们就要弄明白怎么对数据进行预处理,然后加载数据,我们以前手动加载数据的方式,在数据量小的时候,并没有太大问题,但是到了大数据量,我们需要使用 shuffle, 分割成mini-batch 等操作的时候,我们可以使用PyTorch的API快速地完成这些操作。
全栈程序员站长
2022/07/01
4440
Pytorch(五)入门:DataLoader 和 Dataset
推荐阅读
相关推荐
Dataset和DataLoader
更多 >
领券
社区富文本编辑器全新改版!诚邀体验~
全新交互,全新视觉,新增快捷键、悬浮工具栏、高亮块等功能并同时优化现有功能,全面提升创作效率和体验
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文