前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR10数据集实战-数据读取部分(上)

CIFAR10数据集实战-数据读取部分(上)

作者头像
用户6719124
发布2019-12-19 10:17:10
2.2K0
发布2019-12-19 10:17:10
举报
文章被收录于专栏:python pytorch AI机器学习实践

本节课主要介绍CIFAR10数据集

登录http://www.cs.toronto.edu/~kriz/cifar.html网站,可以自行下载数据集。

打开页面后

前讲的MNIST数据集为0~9的数字识别,而这里的为10类物品识别。由上可见物品包含有飞机、汽车、鸟、猫等。照片大小为32*32的彩色图片。每一个类别大概有6000张照片,其中随机筛选出5000作为学习,余下的1000用于测试。

首先在pycharm软件中新建文件夹,并创建main.py文件。

首先引入一部分相关的工具包

代码语言:javascript
复制
import torch
from torchvision import datasets
# 引入pytorch、datasets工具包

定义main函数

代码语言:javascript
复制
def main():
    
    
    
if __name__ == '__main__':
    main()

下面开始在里面写入代码

首先开始加载数据集

代码语言:javascript
复制
def main():

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),

继续

在前面引入工具包处加入代码

代码语言:javascript
复制
from torchvision import transforms
# 引入数据变换工具包

继续定义数据集代码

代码语言:javascript
复制
def main():

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        # .Compose相当于一个数据转换的集合
        # 进行数据转换,首先将图片统一为32*32
        transforms.ToTensor()
        # 将数据转化到Tensor中

    ]))
    # 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中

这里暂时不写Normalize函数

写到这里别忘了让pytorch自己下载数据集

在代码后面加入download=True即可实现

代码语言:javascript
复制
]), download=True)

Cifar_train 的代码部分已经写好

写到这里要注意这里只是建立了一次加载一张的代码

若想一次性加载一批,则要利用其多线程的特性

继续在引入工具包部分加入相关工具包

代码语言:javascript
复制
from torch.utils.data import DataLoader
# 多线程数据读取

继续书写数据读取部分代码

按照其提示,写入相关参数

cifar_train = DataLoader(cifar_train, batch_size=batchsz, ) # 按照其要求,这里的参数需要有batch_size, # 在该部分代码前面定义batch_size

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档