上一篇文章我们对torch.utils.data.Dataset类进行了学习,并且也封装了一个Tomdataset类
#继承data.Dataset
# __init__方法
# __getitem__必须创建,作用:对数据切片
#__len__必须创建,作用:返回对象长度
class Tomdataset(data.Dataset):
def __init__(self,root):
self.imgs_path=root
def __getitem__(self,index):
img_path=self.imgs_path[index]
return img_path
def __len__(self):
return len(self.imgs_path)
那么今天我们直接使用一个新的类来处理我们这次训练的数据集。
这次训练的数据集是1100张天气的照片,并且图片已经打好标签,也就是每一张图片的文件名则是该图片的分类
一共有四种天气的图片,分别是多云,下雨,晴天和日出。
首先我们使用python的glob库读取这些文件
all_imgs_path=glob.glob(r'D:\codingSpace\DeepLearning\weather\*.jpg')
all_imgs_path[:10]
我们打印图片列表的前10张:
['D:\\codingSpace\\DeepLearning\\weather\\cloudy1.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy10.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy100.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy101.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy102.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy103.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy104.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy105.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy106.jpg',
'D:\\codingSpace\\DeepLearning\\weather\\cloudy107.jpg']
然后定义分类的一个数组
species=['cloudy','rain','shine','sunrise']
species_to_idx=dict((c,i) for i,c in enumerate(species))
species_to_idx是这样的一个字典:
我们再转化一下
idx_to_species=dict((v,k) for k,v in species_to_idx.items())
然后就是我们需要的格式的
然后就是生成一个标签数组和我们的图片对应起来
all_labels=[]
for img in all_imgs_path:
for i,c in enumerate(species):
if c in img:
all_labels.append(i)
然后我们可以验证一下标签是否可以和图片对应
可以看到经过检验,label的最后五个输出和图片本身的标签一致
然后就是进入我们今天的主题,自定义一个数据加载类
但是在此之前我们需要定义一个transform
transform=transforms.Compose([
transforms.Resize((96,96)),
transforms.ToTensor(),
])
class MyDataset(data.Dataset):
def __init__(self,img_paths,labels,transform):
self.imgs=img_paths
self.labels=labels
self.transforms=transform
def __getitem__(self,index):
img=self.imgs[index]
label=self.labels[index]
pil_img=Image.open(img)
pil_img=pil_img.convert("RGB") #可选
data=self.transforms(pil_img)
return data,label
def __len__(self):
return len(self.imgs)
这里的MyDataset类就是我们自定义的数据加载类,我们说过自写类然后继承torch.utils.data.Dataset父类 ,这样的做法可玩性很高。
然后我们实例化一个MyDataset类,并且传入我们上面写的参数。
wheather_dataset=MyDataset(all_imgs_path,all_labels,transform)
这样我们的dataset的数据预处理就完成了,接下来的任务就可以交给dataloader了
wheather_dl=data.DataLoader(wheather_dataset,batch_size=16,shuffle=True)
imgs_batch,labels_batch=next(iter(wheather_dl))
然后最后就是绘图了 直接上图片
至此我们对Dataset的讲解就告一段落,感谢大家的观看。