pytorch中制作数据集是要基于Dataset类来进行
首先查看一下Dataset的官方教程
如图,Dataset是一个抽象类,只能被继承,不能被实例化,我们要构建自己的数据集类时需要继承Dataset类,并且所有的子类需要重写Dataset中的__getitem__和__len__函数,前者是我们构建数据集的重点,而后者只是返回数据集的长度。
需要读取的数据存放在名为dataset的文件夹下,文件结构如图:
数据就是.jpg的图片,标签是文件夹名ants,ants下的所有图片都是关于蚂蚁的图片,另有文件夹bees,与ants类似。
from torch.utils.data import Dataset#import Dataset类
from PIL import Image#图像处理
import os#操作系统相关库,用来根据路径读取数据
class Mydata(Dataset):#我们读取数据的类要继承Dataset类
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(root_dir,label_dir)#将两个路径合并为一个
self.img_path=os.listdir(self.path)#获取所有图片的文件名列表
def __getitem__(self, idx):#返回数据标签对的函数
img_name=self.img_path[idx]#通过idx(索引)访问数据,最终实例化Mydata后可直接##通过索引访问数据-标签对
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)#合并路##径
img=Image.open(img_item_path)#获得图片对象
label=self.label_dir
return img,label#最终返回数据-标签对即可
def __len__(self):
return len(self.img_path)
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=Mydata(root_dir,ants_label_dir)#蚂蚁数据集
bees_dataset=Mydata(root_dir,bees_label_dir)#蜜蜂数据集
train_dataset=ants_dataset+bees_dataset#合并两个数据集
img,label=ants_dataset[0]#通过索引读取数据对
img.show()#打印图片
print(label)#打印label