前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch制作数据集

Pytorch制作数据集

作者头像
月梦@剑心
发布2022-09-14 13:57:07
3020
发布2022-09-14 13:57:07
举报
文章被收录于专栏:月梦·剑心的技术专栏

pytorch中制作数据集是要基于Dataset类来进行

首先查看一下Dataset的官方教程

如图,Dataset是一个抽象类,只能被继承,不能被实例化,我们要构建自己的数据集类时需要继承Dataset类,并且所有的子类需要重写Dataset中的__getitem__和__len__函数,前者是我们构建数据集的重点,而后者只是返回数据集的长度。

需要读取的数据存放在名为dataset的文件夹下,文件结构如图:

数据就是.jpg的图片,标签是文件夹名ants,ants下的所有图片都是关于蚂蚁的图片,另有文件夹bees,与ants类似。

代码语言:javascript
复制
from torch.utils.data import Dataset#import Dataset类
from PIL import Image#图像处理
import os#操作系统相关库,用来根据路径读取数据

class Mydata(Dataset):#我们读取数据的类要继承Dataset类
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(root_dir,label_dir)#将两个路径合并为一个
        self.img_path=os.listdir(self.path)#获取所有图片的文件名列表

    def __getitem__(self, idx):#返回数据标签对的函数
        img_name=self.img_path[idx]#通过idx(索引)访问数据,最终实例化Mydata后可直接##通过索引访问数据-标签对
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)#合并路##径
        img=Image.open(img_item_path)#获得图片对象
        label=self.label_dir
        return img,label#最终返回数据-标签对即可
    def __len__(self):
        return len(self.img_path)

root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=Mydata(root_dir,ants_label_dir)#蚂蚁数据集
bees_dataset=Mydata(root_dir,bees_label_dir)#蜜蜂数据集
train_dataset=ants_dataset+bees_dataset#合并两个数据集
img,label=ants_dataset[0]#通过索引读取数据对
img.show()#打印图片
print(label)#打印label
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-01-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档