Loading [MathJax]/jax/input/TeX/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >PyTorch 中自定义数据集的读取方法

PyTorch 中自定义数据集的读取方法

原创
作者头像
陶陶name
发布于 2022-05-12 00:51:02
发布于 2022-05-12 00:51:02
9520
举报
文章被收录于专栏:陶陶计算机陶陶计算机

显然我们在学习深度学习时,不能只局限于通过使用官方提供的MNSIT、CIFAR-10、CIFAR-100这样的数据集,很多时候我们还是需要根据自己遇到的实际问题自己去搜集数据,然后制作数据集(收集数据集的方法有很多,这里就不过多的展开了)。这里只介绍数据集的读取。 1. 自定义数据集的方法: 首先创建一个Dataset类

在这里插入图片描述
在这里插入图片描述
在代码中: def init() 一些初始化的过程写在这个函数下 def len() 返回所有数据的数量,比如我们这里将数据划分好之后,这里仅仅返回的是被处理后的关系 def getitem() 回数据和标签补充代码 上述已经将框架打出来了,接下来就是将框架填充完整就行了,下面是完整的代码,代码的解释说明我也已经写在其中了# -*- coding: utf-8 -*- # @Author : 胡子旋 # @Email :1017190168@qq.com import torch import os,glob import visdom import time import torchvision import random,csv from torch.utils.data import Dataset,DataLoader from torchvision import transforms from PIL import Image class pokemom(Dataset): def __init__(self,root,resize,mode,): super(pokemom,self).__init__() # 保存参数 self.root=root self.resize=resize # 给每一个类做映射 self.name2label={} # "squirtle":0 ,"pikachu":1…… for name in sorted(os.listdir(os.path.join(root))): # 过滤掉文件夹 if not os.path.isdir(os.path.join(root,name)): continue # 保存在表中;将最长的映射作为最新的元素的label的值 self.name2label[name]=len(self.name2label.keys()) print(self.name2label) # 加载文件 self.images,self.labels=self.load_csv('images.csv') # 裁剪数据 if mode=='train': self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合 self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合 elif mode=='val': self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方 self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))] else: self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾 self.labels = self.labels[int(0.8 * len(self.labels)):] # image+label 的路径 def load_csv(self,filename): # 将所有的图片加载进来 # 如果不存在的话才进行创建 if not os.path.exists(os.path.join(self.root,filename)): images=[] for name in self.name2label.keys(): images+=glob.glob(os.path.join(self.root,name,'*.png')) images+=glob.glob(os.path.join(self.root, name, '*.jpg')) images += glob.glob(os.path.join(self.root, name, '*.jpeg')) print(len(images),images) # 1167 'pokeman\\bulbasaur\\00000000.png' # 将文件以上述的格式保存在csv文件内 random.shuffle(images) with open(os.path.join(self.root,filename),mode='w',newline='') as f: writer=csv.writer(f) for img in images: # 'pokeman\\bulbasaur\\00000000.png' name=img.split(os.sep)[-2] label=self.name2label[name] writer.writerow([img,label]) print("write into csv into :",filename) # 如果存在的话就直接的跳到这个地方 images,labels=[],[] with open(os.path.join(self.root, filename)) as f: reader=csv.reader(f) for row in reader: # 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象 img,label=row # 将label转码为int类型 label=int(label) images.append(img) labels.append(label) # 保证images和labels的长度是一致的 assert len(images)==len(labels) return images,labels # 返回数据的数量 def __len__(self): return len(self.images) # 返回的是被裁剪之后的关系 def denormalize(self, x_hat): mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1) std = torch.tensor(std).unsqueeze(1).unsqueeze(1) # print(mean.shape, std.shape) x = x_hat * std + mean return x # 返回idx的数据和当前图片的label def __getitem__(self,idx): # idex-[0-总长度] # retrun images,labels # 将图片,label的路径取出来 # 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png' # 然而label得到的则是 0,1,2 这样的整形的格式 img,label=self.images[idx],self.labels[idx] tf=transforms.Compose([ lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据 # 进行数据加强 transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))), # 随机旋转 transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度 # 中心裁剪 transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂 transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) ]) img=tf(img) label=torch.tensor(label) return img,label def main(): # 验证工作 viz=visdom.Visdom() db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看 # 可视化样本 x,y=next(iter(db)) print('sample:',x.shape,y.shape,y) viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x')) # 加载batch_size的数据 loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8) for x,y in loader: viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch')) viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y')) # 每一次加载后,休息10s time.sleep(10) if __name__ == '__main__': main()

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Transfer Learning
通过网络上收集宝可梦的图片,制作图像分类数据集。我收集了5种宝可梦,分别是皮卡丘,超梦,杰尼龟,小火龙,妙蛙种子
mathor
2020/02/17
4500
深度学习实战之手写签名识别(100%准确率、语音播报)
在完成了上述的环境搭建后,即可进入到准备阶段了。这里准备的有数据集的准备、以及相关代码的主备。
陶陶name
2022/05/13
1.6K0
深度学习实战之垃圾分类
垃圾分类,指按一定规定或标准将垃圾分类储存、分类投放和分类搬运,从而转变成公共资源的一系列活动的总称。分类的目的是提高垃圾的资源价值和经济价值,力争物尽其用;然而我们在日常生活中认为对垃圾分类还是有些不知所措的,对干垃圾、湿垃圾……分的不是很清楚,由此我们就想到了使用深度学习的方法进行分类。简介 本篇博文主要会带领大家进行数据的预处理、网络搭建、模型训练、模型测试 1. 获取数据集 这里笔者已经为大家提供了一个比较完整的数据集,所以大家不必再自己去收集数据了 数据集链接:https://pan.baidu
陶陶name
2022/05/13
6270
PyTorch 自定义数据集
准备 COCO128[1] 数据集,其是 COCO[2] train2017 前 128 个数据。按 YOLOv5 组织的目录:
GoCoding
2021/05/06
8270
PyTorch 自定义数据集
Pytorch打怪路(三)Pytorch创建自己的数据集2
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Teeyohuang/article/details/82108203
TeeyoHuang
2019/05/25
9840
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
红色石头
2022/01/10
1.6K0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
PyTorch-24h 04_自定义数据集
如果我们正在构建像 Nutrify 这样的食物图像分类应用程序,我们的自定义数据集可能是食物图像。 如果我们试图建立一个模型来分类网站上基于文本的评论是正面的还是负面的,我们的自定义数据集可能是现有客户评论及其评级的示例。 如果我们试图构建一个声音分类应用程序,我们的自定义数据集可能是声音样本及其样本标签。
一只大鸽子
2023/02/24
8620
PyTorch-24h  04_自定义数据集
05-PyTorch自定义数据集Datasets、Loader和tranform
对于机器学习中的许多不同问题,我们采取的步骤都是相似的。PyTorch 有许多内置数据集,用于大量机器学习基准测试。除此之外也可以自定义数据集,本问将使用我们自己的披萨、牛排和寿司图像数据集,而不是使用内置的 PyTorch 数据集。具体来说,我们将使用 torchvision.datasets 以及我们自己的自定义 Dataset 类来加载食物图像,然后我们将构建一个 PyTorch 计算机视觉模型,希望对三种物体进行分类。
renhai
2023/11/24
1.1K0
05-PyTorch自定义数据集Datasets、Loader和tranform
十分钟搞懂Pytorch如何读取MNIST数据集
本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧…
全栈程序员站长
2022/07/23
5.4K0
十分钟搞懂Pytorch如何读取MNIST数据集
轻松学pytorch – 使用多标签损失函数训练卷积网络
大家好,我还在坚持继续写,如果我没有记错的话,这个是系列文章的第十五篇,pytorch中有很多非常方便使用的损失函数,本文就演示了如何通过多标签损失函数训练验证码识别网络,实现验证码识别。
OpenCV学堂
2020/07/16
1.2K0
轻松学Pytorch-迁移学习实现表面缺陷检查
大家好,我又又好久没有更新这个系列了,主要原因归根结底只有一个懒,所谓 一勤天下无难事,百思心中有良谋。以后还争取每周更新,这次隔了一周没有更新,对不起大家了。今天给大家更新的是如何基于torchvision自带的模型完成图像分类任务的迁移学习,前面我们已经完成了对对象检测任务的迁移学习,这里补上针对图像分类任务的迁移学习,官方的文档比较啰嗦,看了之后其实可操作性很低,特别是对于初学者,估计看了之后就发懵的那种。本人重新改写了一波,代码简洁易懂,然后把训练结果导出ONNX,使用OpenCV DNN调用部署,非常实用!废话不多说了,少吹水!
OpenCV学堂
2020/09/22
1.5K0
轻松学Pytorch-迁移学习实现表面缺陷检查
【pytorch-ssd目标检测】训练自己创建的数据集
制作类似pascal voc格式的目标检测数据集:https://www.cnblogs.com/xiximayou/p/12546061.html
西西嘛呦
2020/08/26
1.4K0
【pytorch-ssd目标检测】训练自己创建的数据集
【小白学习PyTorch教程】十七、 PyTorch 中 数据集torchvision和torchtext
对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。
润森
2022/08/18
1.2K0
轻松学Pytorch-实现自定义对象检测器
大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发、点赞、留言支持!上一篇文章使用了torchvision中提供的预训练对象检测网络Faster-RCNN实现了常见的对象检测,基于COCO数据集,支持90个类型对象检测,非常的实用。本文将介绍如何使用自定义数据集,使用Faster-RCNN预训练模型实现迁移学习,完成自定义对象检测。
OpenCV学堂
2020/07/30
8630
轻松学Pytorch-实现自定义对象检测器
PyTorch数据Pipeline标准化代码模板
PyTorch作为一款流行深度学习框架其热度大有超越TensorFlow的感觉。根据此前的统计,目前TensorFlow虽然仍然占据着工业界,但PyTorch在视觉和NLP领域的顶级会议上已呈一统之势。
机器视觉CV
2019/12/17
1.7K0
PyTorch数据Pipeline标准化代码模板
pytorch Dataset, DataLoader产生自定义的训练数据「建议收藏」
datasets这是一个pytorch定义的dataset的源码集合。下面是一个自定义Datasets的基本框架,初始化放在__init__()中,其中__getitem__()和__len__()两个方法是必须重写的。__getitem__()返回训练数据,如图片和label,而__len__()返回数据长度。
全栈程序员站长
2022/09/06
2K0
vggnet pytorch_Javaweb项目
VGG网络是在2014年由牛津大学著名研究组VGG (Visual Geometry Group) 提出。
全栈程序员站长
2022/11/17
3980
vggnet pytorch_Javaweb项目
VGG16 训练猫狗数据集
准备数据应该是一件比较麻烦的过程,所以一般都去找那种公开的数据集。在网上找到的可以用于猫狗分类的数据集有 Kaggle 的 “Dogs vs. Cats”数据集,还有牛津大学提供的 Oxford-IIIT Pet 数据集,包含猫和狗的图片,都是非常适合做猫狗分类任务的公开数据集。
繁依Fanyi
2025/03/24
1240
Github复现之TransUnet更新[通俗易懂]
上一篇关于TransUnet的GitHub复现,大家反映效果不好,调参也不好调,我把模型单独拿出来,放到另外一个框架,供大家参考学习(上一篇链接:https://blog.csdn.net/qq_20373723/article/details/115548900) 我这里训练了20个epoch,下面先给出效果正常的情况: 原图
全栈程序员站长
2022/10/04
1.7K0
Github复现之TransUnet更新[通俗易懂]
使用关键点进行小目标检测
【GiantPandaCV导语】本文是笔者出于兴趣搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要采用了回归和heatmap两种方式来回归关键点,是一个很简单基础的项目,代码量很小,可供新手学习。
BBuf
2020/09/10
9390
推荐阅读
相关推荐
Transfer Learning
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档