前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【Pytorch基础】加载数据集

【Pytorch基础】加载数据集

作者头像
yhlin
发布2023-02-27 16:59:25
7800
发布2023-02-27 16:59:25
举报
文章被收录于专栏:yhlin's blogyhlin's blog

回顾

  上一篇训练神经网络是用的是批梯度下降,容易陷入鞍点中。Pytorch 提供了一个数据集加载工具,使得我们可以方便地用小批量随机梯度下降来训练网络。其包含两个部分:

  • Dataset: 用于构造数据集(支持索引)
  • DataLoader: 每次拿出一个 Mini-Batch 用于训练更新

Epoch,Batch-Size,Iterations 概念释义

  • Epoch: 表示一个训练周期,所有样本都进行一次前馈、反馈计算
  • Batch-Size: 表示一个 Mini-Batch 包含的样本数量,即每次训练(一次更新)时用到的样本数量
  • Iterations: 全部样本被划分的 Mini-Batch 的数量,如 1000 个样本,Batch-Size=100,那么 Iteration=10
代码语言:javascript
复制
# 训练循环
for epoch in range(trainning_epochs):
    # 用所有 Mini-Batch 训练
    for i in range(total_batch): # 执行 Iteration 次
        pass

DataLoader

代码语言:javascript
复制
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

参数释义:

  • dataset (Dataset) – dataset from which to load the data.
  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

假设 batch_dize=2,shuffle=True, 经 DataLoader 过程如下:

【Pytorch 基础】加载数据集
【Pytorch 基础】加载数据集

从左至右先打乱样本顺序最终得到一个可迭代的 Loader,每次迭代将(yield)产生一个 Mini-Batch 用于训练网络。

Dataset

 Dataset 是一个抽象类,无法被实例化。只能被其子类继承,再实例化。因此,若要实例化 Dataset 我们必须自己写一个类来继承自它。其结构大致为:

代码语言:javascript
复制
from torch.utils.data import Dataset # Dataset 是一个抽象类, 不能实例化
from torch.utils.data import DataLoader

class MyDataset(Dataset):
    def __init__(self, filepath):
        # 加载数据集,如 csv 文件等,有两种方式:# 1. All in: 将所有数据加载到内存 (适用于数据集不大的情况)
        # 2. 如果数据集很大,可以分割成内存允许大小的文件,用一个列表放文件名,然后训练时用 getitem 函数时在将其读取到内存中
        pass

    def __getitem__(self,index): # 使对象支持下标操作 dataset[index]
        pass

    def __len__(self): # 返回数据集中的样本数
        pass

实例化数据集对象:

代码语言:javascript
复制
dataset = DiabetesDataset(filepath)
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2)

糖尿病数据集

加载数据集:

代码语言:javascript
复制
import torch
import numpy as np
from torch.utils.data import Dataset # Dataset 是一个抽象类, 不能实例化
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # 数据集很小,直接加载进内存
        xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len = xy.shape[0] # 取 行 数(样本数)
        self.x_data = torch.from_numpy(xy[:,:-1]) # 所有行,最后一列不要
        self.y_data = torch.from_numpy(xy[:,[-1]]) # 所有行,只要最后一列,- 1 加 [] 表示拿出来一个矩阵,而不是向量

    def __getitem__(self,index):
        return self.x_data[index], self.y_data[index] # 返回一个元组(x,y)def __len__(self):
        return self.len

dataset = DiabetesDataset('https://project-preview-1257022783.cos.ap-chengdu.myqcloud.com/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2)

训练:

代码语言:javascript
复制
for epoch in range(100):
    for i, data in enumerate(train_loader,0): # for i, (data,labels) in enumerate(train_loader,0):
        # 准备数据
        inputs, labels = data # 自动转换成 Tensor
        # 前馈计算
        y_pred = model(inputs)
        loss = criterion(y_pred,labels)
        #print(epoch, i, loss.item())
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        # 更新参数
        optimizer.step()

Pytorchvision 库内置的数据集

  • MNIST
  • Fashion-MNIST
  • EMNIST
  • COCO
  • LSUN
  • ImageFolder
  • DatasetFolder
  • Imagenet-12
  • CIFAR
  • STL10
  • PhotoTour

这些数据集都继承与 torch.utils.data.Dataset, 都具有 getitem 和 len 函数的实现,可以直接用 torch.utils.data.DataLoader 进行加载。

引入方式

代码语言:javascript
复制
import torch
from torch.utils.data import DataLoader
from torchvision import transforms 
from torchvision import datasets

# Mnist
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms 
from torchvision import datasets

# FashionMnist
train_dataset = torchvision.datasets.FashionMNIST(root='./dataset/fmnist/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset =  torchvision.datasets.FashionMNIST(root='./dataset/fmnist/',train=False,transform=transforms.ToTensor(),download=True)

train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=4,shuffle=False) # 测试不需要 shuffle 打乱顺序,保证结果的顺序
# 训练
for epoch in range(epoch_trainning):
    for batch_idx, (inputs, target) in enumerate(train_loader,0):
        pass # 前馈、反馈计算
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2023-01-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 回顾
  • Epoch,Batch-Size,Iterations 概念释义
  • DataLoader
  • Dataset
  • 糖尿病数据集
    • 加载数据集:
      • 训练:
      • Pytorchvision 库内置的数据集
        • 引入方式
        相关产品与服务
        批量计算
        批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档