专栏首页FSociety基于PyTorch实现MNIST手写字识别

基于PyTorch实现MNIST手写字识别

本篇不涉及模型原理,只是分享下代码。想要了解模型原理的可以去看网上很多大牛的博客。

目前代码实现了CNN和LSTM两个网络,整个代码分为四部分:

  • Config:项目中涉及的参数;
  • CNN:卷积神经网络结构;
  • LSTM:长短期记忆网络结构;
  • TrainProcess: 模型训练及评估,参数model控制训练何种模型(CNN or LSTM)。

完整代码

Talk is cheap, show me the code.

# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-04-05
# @version: python3.7

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime


class Config:
    batch_size = 64
    epoch = 10
    alpha = 1e-3

    print_per_step = 100  # 控制输出


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        """
        Conv2d参数:
        第一位:input channels  输入通道数
        第二位:output channels 输出通道数
        第三位:kernel size 卷积核尺寸
        第四位:stride 步长,默认为1
        第五位:padding size 默认为0,不补
        """
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(64 * 5 * 5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),  # 加快收敛速度的方法(注:批标准化一般放在全连接层后面,激活函数层的前面)
            nn.ReLU()
        )

        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.output = nn.Linear(64, 10)

    def forward(self, x):
        r_out, (_, _) = self.lstm(x, None)

        out = self.output(r_out[:, -1, :])
        return out


class TrainProcess:

    def __init__(self, model="CNN"):
        self.train, self.test = self.load_data()
        self.model = model
        if self.model == "CNN":
            self.net = CNN()
        elif self.model == "LSTM":
            self.net = LSTM()
        else:
            raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
        self.criterion = nn.CrossEntropyLoss()  # 定义损失函数
        self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)

    @staticmethod
    def load_data():
        print("Loading Data......")
        """加载MNIST数据集,本地数据不存在会自动下载"""
        train_data = datasets.MNIST(root='./data/',
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)

        test_data = datasets.MNIST(root='./data/',
                                   train=False,
                                   transform=transforms.ToTensor())

        # 返回一个数据迭代器
        # shuffle:是否打乱顺序
        train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                   batch_size=Config.batch_size,
                                                   shuffle=True)

        test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                                  batch_size=Config.batch_size,
                                                  shuffle=False)
        return train_loader, test_loader

    def train_step(self):
        steps = 0
        start_time = datetime.now()

        print("Training & Evaluating based on '%s'......" % self.model)
        for epoch in range(Config.epoch):
            print("Epoch {:3}.".format(epoch + 1))

            for data, label in self.train:
                data, label = Variable(data.cpu()), Variable(label.cpu())
                # LSTM输入为3维,CNN输入为4维
                if self.model == "LSTM":
                    data = data.view(-1, 28, 28)
                self.optimizer.zero_grad()  # 将梯度归零
                outputs = self.net(data)  # 将数据传入网络进行前向运算
                loss = self.criterion(outputs, label)  # 得到损失函数
                loss.backward()  # 反向传播
                self.optimizer.step()  # 通过梯度做一步参数更新

                # 每100次打印一次结果
                if steps % Config.print_per_step == 0:
                    _, predicted = torch.max(outputs, 1)
                    correct = int(sum(predicted == label))  # 计算预测正确个数
                    accuracy = correct / Config.batch_size  # 计算准确率
                    end_time = datetime.now()
                    time_diff = (end_time - start_time).seconds
                    time_usage = '{:3}m{:3}s'.format(int(time_diff / 60), time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
                    print(msg.format(steps, loss, accuracy, time_usage))

                steps += 1

        test_loss = 0.
        test_correct = 0
        for data, label in self.test:
            data, label = Variable(data.cpu()), Variable(label.cpu())
            if self.model == "LSTM":
                data = data.view(-1, 28, 28)
            outputs = self.net(data)
            loss = self.criterion(outputs, label)
            test_loss += loss * Config.batch_size
            _, predicted = torch.max(outputs, 1)
            correct = int(sum(predicted == label))
            test_correct += correct

        accuracy = test_correct / len(self.test.dataset)
        loss = test_loss / len(self.test.dataset)
        print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))

        end_time = datetime.now()
        time_diff = (end_time - start_time).seconds
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))


if __name__ == "__main__":
    p = TrainProcess(model='CNN')
    p.train_step()

Peace~~

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 基于TensorFlow的CNN实现Mnist手写数字识别

    本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下

    砸漏
  • 用PyTorch实现MNIST手写数字识别(非常详细)

    MNIST可以说是机器学习入门的hello word了!导师一般第一个就让你研究MNIST,研究透了,也算基本入门了。好的,今天就来扯一扯学一学。

    小锋学长
  • 基于Tensorflow的MNIST手写数字识别分类

    本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

    砸漏
  • Tensorflow | MNIST手写字识别

    原始的网址:https://www.tensorflow.org/versions/r0.12/tutorials/mnist/beginners/index....

    努力在北京混出人样
  • Pytorch框架实现mnist手写库识别(与tensorflow对比)

    前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tens...

    砸漏
  • [机器学习] 用KNN识别MNIST手写字符实战

    Hi, 好久不见,粉丝涨了不少,我要再不更新,估计要掉粉了,今天有时间把最近做的一些工作做个总结,我用KNN来识别MNIST手写字符,主要是代码部分,全部纯手写...

    用户1622570
  • Caffe2 - (九)MNIST 手写字体识别

    AIHGF
  • MXNet | 手写字MNIST识别比赛

    比赛的官网:https://www.kaggle.com/c/digit-recognizer

    努力在北京混出人样
  • TensorFlow 2.0 识别MNIST手写数字

    用户6021899
  • Tensorflow MNIST CNN 手写数字识别

    Tesorflow实现基于MNIST数据集上简单CNN: https://github.com/Asurada2015/TF_Cookbook/blob/mas...

    演化计算与人工智能
  • 教程 | 基于LSTM实现手写数字识别

    基于tensorflow,如何实现一个简单的循环神经网络,完成手写数字识别,附完整演示代码。

    OpenCV学堂
  • 专栏 | 在PaddlePaddle上实现MNIST手写体数字识别

    机器之心专栏 来源:百度PaddlePaddle 不久之前,机器之心联合百度推出 PaddlePaddle 专栏,为想要学习这一平台的技术人员推荐相关教程与资源...

    机器之心
  • tensorflow基于CNN实战mnist手写识别(小白必看)

    很荣幸您能看到这篇文章,相信通过标题打开这篇文章的都是对tensorflow感兴趣的,特别是对卷积神经网络在mnist手写识别这个实例感兴趣。不管你是什么基础,...

    砸漏
  • 【典型案例】用 PyTorch 实现图像识别

    PyTorch 是一个开源的深度学习框架,对深度学习算法进行训练和优化,它被广泛应用在人工智能领域。更多详细介绍可参考 PyTorch 官网。

    腾讯云TI平台
  • 轻松学Pytorch –构建循环神经网络

    大家好,使用pytorch实现简单的人工神经网络跟卷积神经网络的mnist手写识别案例之后,今天给大家分享一下如何基于循环神经网络实现mnist手写数字识别。这...

    OpenCV学堂
  • 基于tensorflow的MNIST数字识别

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会作为深度学习的入门样例。下面大致介绍这个数据集的基本情况,并介绍temsorflow...

    狼啸风云
  • 上手必备!不可错过的TensorFlow、PyTorch和Keras样例资源

    TensorFlow、Keras和PyTorch是目前深度学习的主要框架,也是入门深度学习必须掌握的三大框架,但是官方文档相对内容较多,初学者往往无从下手。本人...

    AI科技大本营
  • 基于OpenCV实现手写体数字训练与识别

    OpenCV实现手写体数字训练与识别 机器学习(ML)是OpenCV模块之一,对于常见的数字识别与英文字母识别都可以做到很高的识别率,完成这类应用的主要思想与方...

    OpenCV学堂
  • Tensorflow入门-白话mnist手写数字识别

    文章目录 mnist数据集 简介 图片和标签 One-hot编码(独热编码) 神经网络的重要概念 输入(x)输出(y)、标签(label) 损失函数(loss ...

    小莹莹

扫码关注云+社区

领取腾讯云代金券