专栏首页FSociety基于PyTorch实现MNIST手写字识别

基于PyTorch实现MNIST手写字识别

本篇不涉及模型原理,只是分享下代码。想要了解模型原理的可以去看网上很多大牛的博客。

目前代码实现了CNN和LSTM两个网络,整个代码分为四部分:

  • Config:项目中涉及的参数;
  • CNN:卷积神经网络结构;
  • LSTM:长短期记忆网络结构;
  • TrainProcess: 模型训练及评估,参数model控制训练何种模型(CNN or LSTM)。

完整代码

Talk is cheap, show me the code.

# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-04-05
# @version: python3.7

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime


class Config:
    batch_size = 64
    epoch = 10
    alpha = 1e-3

    print_per_step = 100  # 控制输出


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        """
        Conv2d参数:
        第一位:input channels  输入通道数
        第二位:output channels 输出通道数
        第三位:kernel size 卷积核尺寸
        第四位:stride 步长,默认为1
        第五位:padding size 默认为0,不补
        """
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(64 * 5 * 5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),  # 加快收敛速度的方法(注:批标准化一般放在全连接层后面,激活函数层的前面)
            nn.ReLU()
        )

        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.output = nn.Linear(64, 10)

    def forward(self, x):
        r_out, (_, _) = self.lstm(x, None)

        out = self.output(r_out[:, -1, :])
        return out


class TrainProcess:

    def __init__(self, model="CNN"):
        self.train, self.test = self.load_data()
        self.model = model
        if self.model == "CNN":
            self.net = CNN()
        elif self.model == "LSTM":
            self.net = LSTM()
        else:
            raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
        self.criterion = nn.CrossEntropyLoss()  # 定义损失函数
        self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)

    @staticmethod
    def load_data():
        print("Loading Data......")
        """加载MNIST数据集,本地数据不存在会自动下载"""
        train_data = datasets.MNIST(root='./data/',
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)

        test_data = datasets.MNIST(root='./data/',
                                   train=False,
                                   transform=transforms.ToTensor())

        # 返回一个数据迭代器
        # shuffle:是否打乱顺序
        train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                   batch_size=Config.batch_size,
                                                   shuffle=True)

        test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                                  batch_size=Config.batch_size,
                                                  shuffle=False)
        return train_loader, test_loader

    def train_step(self):
        steps = 0
        start_time = datetime.now()

        print("Training & Evaluating based on '%s'......" % self.model)
        for epoch in range(Config.epoch):
            print("Epoch {:3}.".format(epoch + 1))

            for data, label in self.train:
                data, label = Variable(data.cpu()), Variable(label.cpu())
                # LSTM输入为3维,CNN输入为4维
                if self.model == "LSTM":
                    data = data.view(-1, 28, 28)
                self.optimizer.zero_grad()  # 将梯度归零
                outputs = self.net(data)  # 将数据传入网络进行前向运算
                loss = self.criterion(outputs, label)  # 得到损失函数
                loss.backward()  # 反向传播
                self.optimizer.step()  # 通过梯度做一步参数更新

                # 每100次打印一次结果
                if steps % Config.print_per_step == 0:
                    _, predicted = torch.max(outputs, 1)
                    correct = int(sum(predicted == label))  # 计算预测正确个数
                    accuracy = correct / Config.batch_size  # 计算准确率
                    end_time = datetime.now()
                    time_diff = (end_time - start_time).seconds
                    time_usage = '{:3}m{:3}s'.format(int(time_diff / 60), time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
                    print(msg.format(steps, loss, accuracy, time_usage))

                steps += 1

        test_loss = 0.
        test_correct = 0
        for data, label in self.test:
            data, label = Variable(data.cpu()), Variable(label.cpu())
            if self.model == "LSTM":
                data = data.view(-1, 28, 28)
            outputs = self.net(data)
            loss = self.criterion(outputs, label)
            test_loss += loss * Config.batch_size
            _, predicted = torch.max(outputs, 1)
            correct = int(sum(predicted == label))
            test_correct += correct

        accuracy = test_correct / len(self.test.dataset)
        loss = test_loss / len(self.test.dataset)
        print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))

        end_time = datetime.now()
        time_diff = (end_time - start_time).seconds
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))


if __name__ == "__main__":
    p = TrainProcess(model='CNN')
    p.train_step()

Peace~~

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Python爬取猫眼「碟中谍」全部评论

    评论算保存完了,近期会再做一个关于此次数据的可视化分析。另外阿汤哥真心太帅了,全程打肾上腺素,各位还没去看的赶紧~

    Awesome_Tang
  • Python爬取猫眼「碟中谍」全部评论

    昨天晚上看完碟中谍后,有点小激动,然后就有了这片文章。 我们将猫眼上碟中谍的全部评论保存下来,用于后期分析~ 总共评论3W条左右。

    Awesome_Tang
  • 推荐算法之协同过滤介绍以及Python实现

    以上来自于百度百科介绍,协同过滤(collaborative filtering)在我们推荐系统中发挥了巨大作用,譬如抖音会基于你的点赞记录等推送视频,淘宝会基...

    Awesome_Tang
  • PyTorch最佳实践,怎样才能写出一手风格优美的代码

    虽然这是一个非官方的 指南,但本文总结了一年多使用 PyTorch 框架的经验,尤其是用它开发深度学习相关工作的最优解决方案。请注意,我们分享的经验大多是从研...

    磐创AI
  • PyTorch最佳实践,怎样才能写出一手风格优美的代码

    虽然这是一个非官方的 PyTorch 指南,但本文总结了一年多使用 PyTorch 框架的经验,尤其是用它开发深度学习相关工作的最优解决方案。请注意,我们分享的...

    机器之心
  • Github项目推荐 | PyTorch代码规范最佳实践和样式指南

    AI 科技评论按,本文不是 Python 的官方风格指南。本文总结了使用 PyTorch 框架进行深入学习的一年多经验中的最佳实践。本文分享的知识主要是以研究的...

    AI科技评论
  • 小白学PyTorch | 6 模型的构建访问遍历存储(附代码)

    torch.nn.Module是所有网络的基类,在PyTorch实现模型的类中都要继承这个类(这个在之前的课程中已经提到)。在构建Module中,Module是...

    机器学习炼丹术
  • PHP数据源架构模式之表入口模式实例分析

    martin fowler在《企业应用架构模式》一书中将我们平常接触到的应用开发分为三层:表现层、领域层和数据源层。

    砸漏
  • 七夕快到了,教你用python去表白!

    py3study
  • 剑指offer【50~59】

    排序数组,很明显二分查找,找到第一个 >= k 的元素索引以及第一个 > k 的元素索引,两者相减即为答案,即 lowerBound - upperBound。...

    echobingo

扫码关注云+社区

领取腾讯云代金券