首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch打怪路pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

Pytorch打怪路pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

作者头像
TeeyoHuang
发布2019-05-25 22:32:32
9530
发布2019-05-25 22:32:32
举报

1.大致流程

一般来说,使用深度学习框架我们会经过下面几个流程:

模型定义(包括损失函数的选择) --->数据处理和加载 ---> 训练(可能包含训练过程可视化) ---> 测试

所以我们在自己写代码的时候也基本上就按照这四个大模块四步走就ok了

官方给的这个例子呢,是先进行的第二步数据处理和加载,然后定义网络,这其实没什么关系。

所以本篇博文讲解的是 数据处理和加载 这一步的内容,当然会接着在后续博文写其他步骤。

此例的步骤: A、Load and normalizing the CIFAR10 training and test datasets using torchvision B、Define a Convolution Neural Network C、Define a loss function D、Train the network on the training data E、Test the network on the test data

下面我就直接上程序,并且添加我自己的一些注解,觉得有问题的欢迎提出,希望和大家多交流。

2、代码分析

首先使用torchvision加载和归一化我们的训练数据和测试数据。 a、torchvision这个东西,实现了常用的一些深度学习的相关的图像数据的加载功能,比如cifar10、Imagenet、Mnist等等的,保存在torchvision.datasets模块中。 b、同时,也封装了一些处理数据的方法。保存在torchvision.transforms模块中 c、还封装了一些模型和工具封装在相应模型中。可以从下图一窥大貌:

#  首先当然肯定要导入torch和torchvision,至于第三个是用于进行数据预处理的模块
import torch
import torchvision
import torchvision.transforms as transforms

#  **由于torchvision的datasets的输出是[0,1]的PILImage,所以我们先先归一化为[-1,1]的Tensor**
    #  首先定义了一个变换transform,利用的是上面提到的transforms模块中的Compose( )
    #  把多个变换组合在一起,可以看到这里面组合了ToTensor和Normalize这两个变换
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 

    # 定义了我们的训练集,名字就叫trainset,至于后面这一堆,其实就是一个类:
    # torchvision.datasets.CIFAR10( )也是封装好了的,就在我前面提到的torchvision.datasets
    # 模块中,不必深究,如果想深究就看我这段代码后面贴的图1,其实就是在下载数据
    #(不访问外国网站可能会慢一点吧)然后进行变换,可以看到transform就是我们上面定义的transform
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
    # trainloader其实是一个比较重要的东西,我们后面就是通过trainloader把数据传入网
    # 络,当然这里的trainloader其实是个变量名,可以随便取,重点是他是由后面的
    # torch.utils.data.DataLoader()定义的,这个东西来源于torch.utils.data模块,
    #  网页链接http://pytorch.org/docs/0.3.0/data.html,这个类可见我后面图2
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
    # 对于测试集的操作和训练集一样,我就不赘述了
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
    # 类别信息也是需要我们给定的
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3、图片

图1

root表示存放dataset的位置,本例就是' ./data'

train,如果为True,就创建的是trainning set,可以看到我们的trainset调用它时用的是True

而testset调用它时,参数里填的是False

transform,这个transform是形参名,由于我们定义的变换也叫transform,所以就有transform = transform,

看起来可能有点怪,其实我们的之前的变换可以随便命名

download,如果为True,就从网上下载,如果已经有下载好的数据就不会重复下载了

------------------------------------------------------------------------------------------------------------------------------------------

图2

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

dataset:就是数据的来源,比如训练集就添入我们定义的trainset

batch_size:每批次进入多少数据,本例中填的是4

shuffle:如果为真,就打乱数据的顺序,本例为True

num_workers:用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

本例中为2。这个值是什么意思呢,就是数据读入的速度到底有多快,你选的用来加载数据的

子进程越多,那么显然数据读的就越快,这样的话消耗CPU的资源也就越多,所以这个值在自己

跑实验的时候,可以自己试一试,既不要让花在加载数据上的时间太多,也不要占用太多电脑资源

所以这第一步----数据加载和处理,要注意的就是这些内容,如果程序运行完毕,会显示:

.

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年01月30日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.大致流程
  • 2、代码分析
  • 3、图片
相关产品与服务
腾讯云代码分析
腾讯云代码分析(内部代号CodeDog)是集众多代码分析工具的云原生、分布式、高性能的代码综合分析跟踪管理平台,其主要功能是持续跟踪分析代码,观测项目代码质量,支撑团队传承代码文化。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档