前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch加载自己的数据集(使用DataLoader读取Dataset)[通俗易懂]

Pytorch加载自己的数据集(使用DataLoader读取Dataset)[通俗易懂]

作者头像
全栈程序员站长
发布2022-08-18 13:01:25
2K0
发布2022-08-18 13:01:25
举报

大家好,又见面了,我是你们的朋友全栈君。

1. 我们经常可以看到Pytorch加载数据集会用到官方整理好的数据集。很多时候我们需要加载自己的数据集,这时候我们需要使用Dataset和DataLoader

  • Dataset:是被封装进DataLoader里,实现该方法封装自己的数据和标签。
  • DataLoader:被封装入DataLoaderIter里,实现该方法达到数据的划分。

2.Dataset

阅读源码后,我们可以指导,继承该方法必须实现两个方法:

  • _getitem_()
  • _len_() 因此,在实现过程中我们测试如下:
代码语言:javascript
复制
import torch
import numpy as np


# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
	# 初始化函数,得到数据
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)

3.DataLoader

提供对Dataset的操作,操作如下:

代码语言:javascript
复制
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)

参数含义如下:

  • d a t a s e t \color{HotPink}{dataset} dataset: 加载torch.utils.data.Dataset对象数据
  • b a t c h _ s i z e \color{HotPink}{batch\_size} batch_size: 每个batch的大小
  • s h u f f l e \color{HotPink}{shuffle} shuffle:是否对数据进行打乱
  • d r o p _ l a s t \color{HotPink}{drop\_last} drop_last:是否对无法整除的最后一个datasize进行丢弃
  • n u m _ w o r k e r s \color{HotPink}{num\_workers} num_workers:表示加载的时候子进程数

因此,在实现过程中我们测试如下(紧跟上述用例):

代码语言:javascript
复制
from torch.utils.data import DataLoader

# 读取数据
datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)

此时,我们的数据已经加载完毕了,只需要在训练过程中使用即可。

4.查看数据

我们可以通过迭代器(enumerate)进行输出数据,测试如下:

代码语言:javascript
复制
for i, data in enumerate(datas):
	# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
    print("第 {} 个Batch \n{}".format(i, data))

输出结果如下图:

Pytorch加载自己的数据集(使用DataLoader读取Dataset)[通俗易懂]
Pytorch加载自己的数据集(使用DataLoader读取Dataset)[通俗易懂]

结果说明:由于数据的是10个,batchsize大小为6,且drop_last=False,因此第一个大小为6,第二个为4。每一个batch中包含data和对应的labels。 当我们想取出data和对应的labels时候,只需要用下表就可以啦,测试如下:

代码语言:javascript
复制
# 表示输出数据
print(data[0])
# 表示输出标签
print(data[1])

结果如图:

在这里插入图片描述
在这里插入图片描述

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/134589.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022年5月2,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 我们经常可以看到Pytorch加载数据集会用到官方整理好的数据集。很多时候我们需要加载自己的数据集,这时候我们需要使用Dataset和DataLoader
  • 2.Dataset
  • 3.DataLoader
  • 4.查看数据
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档