pytorch入门教程 | 第四章:准备图片数据集

在训练神经网络之前,我们必须有数据,作为资深伸手党,必须知道以下几个数据提供源:

1 CIFAR-10

CIFAR-10图片样本截图

CIFAR-10是多伦多大学提供的图片数据库,图片分辨率压缩至32x32,一共有10种图片分类,均进行了标注。适合监督式学习。CIFAR-10数据下载页面:http://www.cs.toronto.edu/~kriz/cifar.html

2 ImageNet

imagenet首页

ImageNet首页:http://image-net.org/

3

ImageFolder

imagefolder首页

ImageFolder首页:https://www.tugraz.at/institute/icg/home/

4

LSUN Classification

LSUN Classification

LSUN 图片下载地址:

http://lsun.cs.princeton.edu/2016/

5

COCO (Captioning and Detection)

coco首页

COCO首页地址:

http://mscoco.org/

6

我们进入正题

为了方便加载以上五种数据库的数据,pytorch团队帮我们写了一个torchvision包。使用torchvision就可以轻松实现数据的加载和预处理。

我们以使用CIFAR10为例:

导入torchvision的库:

import torchvision import torchvision.transforms as transforms # transforms用于数据预处理

使用datasets.CIFAR10()函数加载数据库。CIFAR10有60000张图片,其中50000张是训练集,10000张是测试集。

#训练集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据(50000张图片作为训练数据)加载到内存中,若download为True时,会自动从网上下载数据并解压trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=None)

下面简单讲解root、train、download、transform这四个参数

1.root,表示cifar10数据的加载的相对目录

2.train,表示是否加载数据库的训练集,false的时候加载测试集

3.download,表示是否自动下载cifar数据集

4.transform,表示是否需要对数据进行预处理,none为不进行预处理

由于美帝路途遥远,靠命令台进程下载100多M的数据速度很慢,所以我们可以自己去到cifar10的官网上把CIFAR-10 python version下载下来,然后解压为cifar-10-batches-py文件夹,并复制到相对目录./data下。(若设置download=True,则程序会自动从网上下载cifar10数据到相对目录./data下,但这样小伙伴们可能要等一个世纪了),并对训练集进行加载(train=True)。

如图所示,在脚本文件下建一个data文件夹,然后把数据集文件夹丢到里面去就好了,注意cifar-10-batches-py文件夹名字不能自己任意改。

我们在写完上面三行代码后,在写一行print一下trainset的大小看看:

print len(trainset) #结果:50000

我们在训练神经网络时,使用的是mini-batch(一次输入多张图片),所以我们在使用一个叫DataLoader的工具为我们将50000张图分成每四张图一分,一共12500份的数据包。

#将训练集的50000张图片划分成12500份,每份4张图,用于mini-batch输入。shffule=True在表示不同批次的数据遍历时,打乱顺序(这个需要在训练神经网络时再来讲)。num_workers=2表示使用两个子进程来加载数据 import torch trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=2)

那么我们就写下了这几行代码:

print的结果为50000和12500

下面我们需要对数据进行预处理,什么是预处理?为什么要预处理?如果不知道的小盆友可以看看下面几个链接,或许对你有帮助。神经网络为什么要归一(http://blog.csdn.net/liuheng0111/article/details/52841838),深度学习-----数据预处理(http://blog.csdn.net/dcxhun3/article/details/47999281)。还无法理解也没关系,只要记住,预处理会帮助我们加快神经网络的训练。

在pytorch中我们预处理用到了transforms函数:

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

compose函数会将多个transforms包在一起。

我们的transforms有好几种,例如transforms.ToTensor(), transforms.Scale()等,完整列表在这。好好学习吧!

我只讲现在用到了两种:

1.ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C) 从0到255的值映射到0到1的范围内,并转化成Tensor格式。

2.Normalize(mean,std)是通过下面公式实现数据归一化

channel=(channel-mean)/std

那么经过上面两个转换一折腾,我们的数据中的每个值就变成了[-1,1]的数了。

1到22行,我们从硬盘中读取数据,并将数据预处理(第13行,transform=transform),然后转换成4张图为一批的数据结构。26行到47行,为我们显示出一个图片例子,可有可无,不再作代码解释。

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2017-09-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏GAN&CV

使用tensorflow 的slim模块fine-tune resnet/densenet/inception网络,解决batchnorm问题

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/d...

2112
来自专栏CNN

从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)

上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》中介绍了如何从pb模型文件中提取网络结构图并实现可视化,本文介绍如何从CKPT模...

1673
来自专栏ATYUN订阅号

使用Tensorflow对象检测在安卓手机上“寻找”皮卡丘

在TensorFlow的许多功能和工具中,隐藏着一个名为TensorFlow对象探测API(TensorFlow Object Detection API)的组...

4325
来自专栏AI研习社

Github 项目推荐 | 微软开源 MMdnn,模型可在多框架间转换

近期,微软开源了 MMdnn,这是一套能让用户在不同深度学习框架间做相互操作的工具。比如,模型的转换和可视化,并且可以让模型在 Caffe、Keras、MXNe...

3548
来自专栏HansBug's Lab

算法模板——sap网络最大流 1(非递归+邻接矩阵)

实现功能:首行输入N,M,S,T,代表这张图N个点,M条边,源点为S,汇点为T;接下来T行输入个边的出发点、终点和权值;输出最大流 原理:sap网络流算法(详见...

2757
来自专栏ATYUN订阅号

浣熊检测器实例, 如何用TensorFlow的Object Detector API来训练你的物体检测器

这篇文章是“用Tensorflow和OpenCV构建实时对象识别应用”的后续文章。具体来说,我在自己收集和标记的数据集上训练了我的浣熊检测器。完整的数据集可以在...

4367
来自专栏ATYUN订阅号

使用Java部署训练好的Keras深度学习模型

Keras库为深度学习提供了一个相对简单的接口,使神经网络可以被大众使用。然而,我们面临的挑战之一是将Keras的探索模型转化为产品模型。Keras是用Pyth...

974
来自专栏机器学习实践二三事

使用Faster-Rcnn进行目标检测(实践篇)

原理 上一篇文章,已经说过了,大家可以参考一下,Faster-Rcnn进行目标检测(原理篇) 实验 我使用的代码是python版本的Faster Rcnn,官方...

6676
来自专栏北京马哥教育

搭建python机器学习环境以及一个机器学习例子

作者 | hzyido 来源 | 简书 糖豆贴心提醒,本文阅读时间6分钟,文末有秘密! 这篇文章介绍了Python机器学习环境的搭建,我用的机器学习开...

50112
来自专栏AI研习社

YOLO 升级到 v3 版,速度相比 RetinaNet 快 3.8 倍

雷锋网 AI 研习社按,YOLO 是一种非常流行的目标检测算法,速度快且结构简单。日前,YOLO 作者推出 YOLOv3 版,在 Titan X 上训练时,在 ...

1053

扫码关注云+社区