前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[PyTorch小试牛刀]实战六·准备自己的数据集用于训练(基于猫狗大战数据集)

[PyTorch小试牛刀]实战六·准备自己的数据集用于训练(基于猫狗大战数据集)

作者头像
小宋是呢
发布2019-06-27 12:56:26
1.6K1
发布2019-06-27 12:56:26
举报
文章被收录于专栏:深度应用深度应用

[PyTorch小试牛刀]实战六·准备自己的数据集用于训练(基于猫狗大战数据集)

在上面几个实战中,我们使用的是Pytorch官方准备好的FashionMNIST数据集进行的训练与测试。本篇博文介绍我们如何自己去准备数据集,以应对更多的场景。

我们此次使用的是猫狗大战数据集,开始之前我们要先把数据处理一下,形式如下

datas │ └───train │ │ │ └───cats │ │ │ cat1000.jpg │ │ │ cat1001.jpg │ │ │ … │ └───dogs │ │ │ dog1000.jpg │ │ │ dog1001.jpg │ │ │ … └───valid │ │ │ └───cats │ │ │ cat0.jpg │ │ │ cat1.jpg │ │ │ … │ └───dogs │ │ │ dog0.jpg │ │ │ dog1.jpg │ │ │ …

train数据集中有23000张数据,valid数据集中有2000数据用于验证网络性能

代码部分 1.采用隐形字典形式,代码简练,不易理解

代码语言:javascript
复制
import torch as t
import torchvision as tv
import os

data_dir = "./datas"

BATCH_SIZE = 100

EPOCH = 10

transform = {
    x:tv.transforms.Compose(
        [tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]#tv.transforms.Resize 用于重设图片大小
    ) 
    for x in ["train","valid"]
}

datasets = {
    x:tv.datasets.ImageFolder(root = os.path.join(data_dir,x),transform=transform[x])
    for x in ["train","valid"]
}

dataloader = {
    x:t.utils.data.DataLoader(dataset= datasets[x],
        batch_size=BATCH_SIZE,
        shuffle=True
    ) 
    for x in ["train","valid"]
}

b_x,b_y = next(iter(dataloader["train"]))

print(b_x.shape,b_y.shape)

index_classes = datasets["train"].class_to_idx

print(index_classes)

2.采用显性字典形式,代码稍多,易于理解

代码语言:javascript
复制
import torch as t
import torchvision as tv


data_dir = "./datas"

BATCH_SIZE = 100

EPOCH = 10

transform = {
    "train":tv.transforms.Compose(
        [tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]
    ),
    "valid":tv.transforms.Compose(
        [tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]
    ),
}

datasets = {
    "train":tv.datasets.ImageFolder(root = os.path.join(data_dir,"train"),transform=transform["train"]),
    "vaild":tv.datasets.ImageFolder(root = os.path.join(data_dir,"vaild"),transform=transform["vaild"]),
}
dataloader = {
    "train":t.utils.data.DataLoader(dataset= datasets["train"],
        batch_size=BATCH_SIZE,
        shuffle=True
    ),
    "valid":t.utils.data.DataLoader(dataset= datasets["valid"],
        batch_size=100,
        shuffle=True
    )
}

b_x,b_y = next(iter(dataloader["train"]))

print(b_x.shape,b_y.shape)

index_classes = datasets["train"].class_to_idx

print(index_classes)

输出结果

代码语言:javascript
复制
torch.Size([100, 3, 64, 64]) torch.Size([100])
{'cats': 0, 'dogs': 1}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年12月23日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • [PyTorch小试牛刀]实战六·准备自己的数据集用于训练(基于猫狗大战数据集)
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档