前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用Dataset加载数据

使用Dataset加载数据

作者头像
Tom2Code
发布2022-11-21 12:06:29
7320
发布2022-11-21 12:06:29
举报
文章被收录于专栏:Tom

上一篇文章我们对torch.utils.data.Dataset类进行了学习,并且也封装了一个Tomdataset类

pytorch之Dataset

代码语言:javascript
复制
#继承data.Dataset
# __init__方法
# __getitem__必须创建,作用:对数据切片
#__len__必须创建,作用:返回对象长度
class Tomdataset(data.Dataset):
    def __init__(self,root):
        self.imgs_path=root
    def __getitem__(self,index):
        img_path=self.imgs_path[index]
        return img_path
    def __len__(self):
        return len(self.imgs_path)

那么今天我们直接使用一个新的类来处理我们这次训练的数据集。

这次训练的数据集是1100张天气的照片,并且图片已经打好标签,也就是每一张图片的文件名则是该图片的分类

一共有四种天气的图片,分别是多云,下雨,晴天和日出。

首先我们使用python的glob库读取这些文件

代码语言:javascript
复制
all_imgs_path=glob.glob(r'D:\codingSpace\DeepLearning\weather\*.jpg')
all_imgs_path[:10]

我们打印图片列表的前10张:

代码语言:javascript
复制
['D:\\codingSpace\\DeepLearning\\weather\\cloudy1.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy10.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy100.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy101.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy102.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy103.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy104.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy105.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy106.jpg',
 'D:\\codingSpace\\DeepLearning\\weather\\cloudy107.jpg']

然后定义分类的一个数组

代码语言:javascript
复制
species=['cloudy','rain','shine','sunrise']
species_to_idx=dict((c,i) for i,c in enumerate(species))

species_to_idx是这样的一个字典:

我们再转化一下

代码语言:javascript
复制
idx_to_species=dict((v,k) for k,v in species_to_idx.items())

然后就是我们需要的格式的

然后就是生成一个标签数组和我们的图片对应起来

代码语言:javascript
复制
all_labels=[]
for img in all_imgs_path:
    for i,c in enumerate(species):
        if c in img:
            all_labels.append(i)

然后我们可以验证一下标签是否可以和图片对应

可以看到经过检验,label的最后五个输出和图片本身的标签一致

然后就是进入我们今天的主题,自定义一个数据加载类

但是在此之前我们需要定义一个transform

代码语言:javascript
复制
transform=transforms.Compose([
    transforms.Resize((96,96)),
    transforms.ToTensor(),  
])

class MyDataset(data.Dataset):
    def __init__(self,img_paths,labels,transform):
        self.imgs=img_paths
        self.labels=labels
        self.transforms=transform
        
    def __getitem__(self,index):
        img=self.imgs[index]
        label=self.labels[index]
        pil_img=Image.open(img)
        pil_img=pil_img.convert("RGB") #可选
        data=self.transforms(pil_img)
        return data,label
    
    def __len__(self):
        return len(self.imgs)

这里的MyDataset类就是我们自定义的数据加载类,我们说过自写类然后继承torch.utils.data.Dataset父类 ,这样的做法可玩性很高。

然后我们实例化一个MyDataset类,并且传入我们上面写的参数。

代码语言:javascript
复制
wheather_dataset=MyDataset(all_imgs_path,all_labels,transform)

这样我们的dataset的数据预处理就完成了,接下来的任务就可以交给dataloader了

代码语言:javascript
复制
wheather_dl=data.DataLoader(wheather_dataset,batch_size=16,shuffle=True)
imgs_batch,labels_batch=next(iter(wheather_dl))

然后最后就是绘图了 直接上图片

至此我们对Dataset的讲解就告一段落,感谢大家的观看。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-06-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Tom的小院 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档