前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >从原始图片数据开始构建卷积神经网络(Pytorch)

从原始图片数据开始构建卷积神经网络(Pytorch)

原创
作者头像
机器视觉CV
修改2019-07-15 10:30:45
7130
修改2019-07-15 10:30:45
举报
文章被收录于专栏:机器视觉CV机器视觉CV
说在前面

入门机器学习的时候,我们往往使用的是框架自带的数据集来进行学习的,这样其实跳过了机器学习最重要的步骤,数据预处理,本文通过从原始数据(图片格式)到卷积神经网络的设计,逐步实现 MNIST 的分类

本文使用的是 Facebook 的深度学习框架 Pytorch

MNIST 数据集是机器学习界的 HelloWorld ,主要是手写字符(0-9)

数据下载:后台回复 MNIST 获取下载链接

代码语言:javascript
复制
# 导入所需要的包
import torch # 1.1.0 版本
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import shutil
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
原始数据处理

10 个文件夹下面各有 10000 张图片,我们对原始数据集进行分配

原始数据集
原始数据集

原始数据集

将其分成训练集、测试集、验证集,各自按照类别文件夹放置。智能一点,让程序自己根据设定的比例拆分

代码语言:javascript
复制
def path_init(src_path, dst_path, rate=(0.6, 0.2, 0.2)):
    """
    将原始数据按比较分配成 train validation test
    :param src_path: 原始数据路径,要求格式如下
    - src_path
        - class_1
        - class_2
        ...
    :param dst_path: 目标路径
    :param rate: 分配比例,加起来一定要等于 1
    :return:
    """
    # 以下几行是创建如下格式的文件夹
    """
    - img_data
        - train
            - class_1
            - class_2
            ...
        - validation
            - class_1
            - class_2
            ...
        - test
            - class_1
            - class_2
            ...
    """
    try:
        class_names = os.listdir(src_path)  # 获取原始数据所有类别的纯文件名
        dst_path = dst_path + '/' + 'MNIST100000_init'
        os.mkdir(dst_path)  # 创建目标文件夹
        three_paths = [dst_path + '/' +
                       i for i in ['train', 'validation', 'test']]  # 三个文件夹的路径
        for three_path in three_paths:
            os.mkdir(three_path)
            for class_name in class_names:
                os.mkdir(three_path+'/'+class_name)
        # -----------------------------

        dst_train = dst_path + '/' + 'train'
        dst_validation = dst_path + '/' + 'validation'
        dst_test = dst_path + '/' + 'test'

        class_names_list = [src_path + '/' +
                            class_name for class_name in class_names]  # 获取原始数据所有类别的路径

        for class_li in class_names_list:
            imgs = os.listdir(class_li)  # 当前类别所有图片的文件名,不包括路径
            # 得到当前类别的所有图片的路径,指定后缀
            imgs_list = [class_li + '/' +
                         img for img in imgs if img.endswith("png")]
            print(len(imgs_list))
            img_num = len(imgs_list)  # 当前类别的图片数量
            # 三个文件夹的数量
            train_num = int(rate[0]*img_num)
            validation_num = int(rate[1]*img_num)
            # test_num = int(rate[2]*img_num)

            for img in imgs_list[0:train_num]:
                # 训练集复制
                src = img
                dst = dst_train + '/' + \
                    img.split('/')[-2] + '/' + img.split('/')[-1]
                # print(src, " ", dst)
                shutil.copy(src=img, dst=dst)
            print("训练集数量:", len(imgs_list[0:train_num]))

            for img in imgs_list[train_num:train_num+validation_num]:
                # 验证集复制
                src = img
                dst = dst_validation + '/' + \
                    img.split('/')[-2] + '/' + img.split('/')[-1]
                # print(src, " ", dst)
                shutil.copy(src=img, dst=dst)
            print("验证集数量:", len(imgs_list[train_num:train_num+validation_num]))

            for img in imgs_list[train_num + validation_num:]:
                # 测试集复制
                src = img
                dst = dst_test + '/' + \
                    img.split('/')[-2] + '/' + img.split('/')[-1]
                # print(src, " ", dst)
                shutil.copy(src=img, dst=dst)
            print("测试集数量:", len(imgs_list[train_num + validation_num:]))

    except:
        print("目标文件夹已经存在或原始文件夹不存在,请检查!")


# # 例程
src_path = './data/MNIST100000/'
dst_path = './data/'    
path_init(src_path, dst_path, rate=(0.6, 0.2, 0.2))
根据原始数据创建数据集自己的类

制作自己的数据集类,需要继承 torch.utils.data.dataset.Dataset 并重写 __getitem____len__ 方法

可以参考框架中 MNIST 数据集类的写法:https://pytorch.org/docs/stable/_modules/torchvision/datasets/mnist.html#MNIST

代码语言:javascript
复制
# 创建一个数据集类:继承 Dataset
class My_DataSet(Dataset):
    def __init__(self, img_dir, transform=None):
        super(My_DataSet, self).__init__()
        self.img_dir = img_dir
        class_dir = [self.img_dir + '/' + i for i in os.listdir(self.img_dir)] # 10 个数字的路径
        img_list = []
        for num in range(len(class_dir)):
            img_list += [class_dir[num]+'/'+img_name for img_name in os.listdir(class_dir[num]) if img_name.endswith("png")]
        self.img_list = img_list # 得到所有图片的路径
        self.transform = transform

    def __getitem__(self, index):
        label = self.img_list[index].split("/")[-2]
        img = np.array(Image.open(self.img_list[index]))

        if self.transform is not None:
            img = self.transform(img)
        return img, int(label) # 得到的是字符串,故要进行类型转换

    def __len__(self):
        return len(self.img_list)

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307, ), (0.3081, ))]) # 这个归一化值是数据提供方给的
简单可视化
代码语言:javascript
复制
# 这里可视化就没有进行 transform 操作
for data, label in DataLoader(My_DataSet("./data/MNIST100000_init/train/"), batch_size=32, shuffle=True):
    break

fig, ax = plt.subplots(4, 8, figsize=(6, 6))
for i, axi in enumerate(ax.flat):
    axi.imshow(data[i], cmap='binary')
    axi.set(xticks=[], yticks=[])
    axi.set_title(str(label[i].item())) 
可视化结果
可视化结果

可视化结果

构建训练、测试数据加载器
代码语言:javascript
复制
BATCH_SIZE = 512  # 大概需要2G的显存
EPOCHS = 20  # 总共训练批次
DEVICE = torch.device("cuda" if torch.cuda.is_available()
                      else "cpu")  # 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多
代码语言:javascript
复制
train_dataSet = My_DataSet("./data/MNIST100000_init/train/", transform=transform)
train_dataSet_loader = DataLoader(train_dataSet, batch_size=BATCH_SIZE, shuffle=True)
test_dataSet = My_DataSet("./data/MNIST100000_init/test/", transform=transform)
test_dataSet_loader = DataLoader(test_dataSet, batch_size=BATCH_SIZE, shuffle=True)
构建卷积神经网络
代码语言:javascript
复制
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1,28x28
        self.conv1 = nn.Conv2d(1, 10, 5)  # 10, 24x24  # 图片输入只有一个通道,10 个卷积,kernel_size=5
        self.conv2 = nn.Conv2d(10, 20, 3)  # 20, 10x10 # 20 个卷积 ,kernel_size=3
        self.fc1 = nn.Linear(20*10*10, 500) # 20 个特征图,每个 10*10
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)  # 24
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)  # 12
        out = self.conv2(out)  # 10
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out

model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters())
代码语言:javascript
复制
# 将训练的过程封装成函数
def train(model, device, train_loader, optimizer, epoch):
    model.train() # 训练模式
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # 清除前一步的梯度
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print(('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item())))

# 将测试的过程封装成函数
def test(model, device, test_loader):
    model.eval() # 评估模式
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
代码语言:javascript
复制
# 开始训练
for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_dataSet_loader, optimizer, epoch)
    test(model, DEVICE, test_dataSet_loader)
代码语言:javascript
复制
# 部分训练结果
Train Epoch: 1 [14848/60000 (25%)]    Loss: 0.328185
Train Epoch: 1 [30208/60000 (50%)]    Loss: 0.207250
Train Epoch: 1 [45568/60000 (75%)]    Loss: 0.140810
Test set: Average loss: 0.1036, Accuracy: 19401/20000 (97%)
...
Train Epoch: 20 [14848/60000 (25%)]    Loss: 0.000147
Train Epoch: 20 [30208/60000 (50%)]    Loss: 0.000108
Train Epoch: 20 [45568/60000 (75%)]    Loss: 0.000108

Test set: Average loss: 0.0224, Accuracy: 19922/20000 (100%)

训练的精度已经到 1 了!说明效果相当不错,记得保存一下模型,方便下次使用

代码语言:javascript
复制
# 保存整个模型
torch.save(model, './data/model.pkl')

# # 仅保存和加载模型参数(推荐使用)
torch.save(model.state_dict(), './data/model_only_weighrt.pkl')
预测
代码语言:javascript
复制
def predict(model, data, label):
    model.eval()
    with torch.no_grad():
        data, label = data.to(DEVICE), label.to(DEVICE)
        output = model(data)
        pred = output.max(-1)[1]
        print(pred.eq(label.view_as(pred)).sum().item()/BATCH_SIZE)


val_dataSet = My_DataSet("./data/MNIST100000_init/validation/", transform=transform)
val_dataSet_loader = DataLoader(val_dataSet, batch_size=BATCH_SIZE, shuffle=True)
for data, label in val_dataSet_loader:
    break
代码语言:javascript
复制
model_load = torch.load('./data/model.pkl')
predict(model_load, data, label)
# 预测结果
# 0.994140625
总结

本文实现从原始数据(图片)到卷积神经网络的设计,一步一步的实现 MNIST 的分类器的训练,练习了如何制作自己的数据集类,当面对一个新的问题,就知道懂得建立自己的数据集类,而不是无从下手!

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 说在前面
  • 原始数据处理
  • 根据原始数据创建数据集自己的类
  • 简单可视化
  • 构建训练、测试数据加载器
  • 构建卷积神经网络
  • 预测
  • 总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档