前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch实现简单的数字识别(上)

Pytorch实现简单的数字识别(上)

作者头像
用户6719124
发布2019-11-18 00:26:39
1.4K0
发布2019-11-18 00:26:39
举报

使用深度学习神经网络对数字识别,大体需要4个步骤:①读取数据。②建立模型。③训练。④测试、验证。

其基本流程示意图如下:

上图由左至右依次为输入层、神经层a、神经层b、输出层。即为input layer、function layer a 、function layer b、output layer。

  1. 读取数据

首先到http://yann.lecun.com/exdb/mnist/网站上下载mnist数据集,或者在代码中加入download代码,但速度较慢。

建立utils.py文件,写入工具代码

首先引入pytorch包

代码语言:javascript
复制
import torch
import matplotlib.pyplot as plt

定义第一个工具:用曲线表示梯度下降过程。

代码语言:javascript
复制
def plot_curve(data):
    # 先画一个曲线,以表示training下降的过程
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    # 将辅助牌放置在上右侧
    plt.xlabel('step')
    # 输入x轴名称
    plt.ylabel('value')
    # 输入y轴名称
    plt.show()

定义第二个工具:用图像表示识别结果

代码语言:javascript
复制
def plot_result_image(img, label, name):
    # 以图像的方式输出识别出的结果
    fig = plt.figure()
    # 先输出空白图像
    for i in range(9):
        # 以迭代的方式,一次性输出9个图像
        plt.subplot(3, 3, i+1)
        # 3 * 3 的图片输出样式
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title('{}: {}'.format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
        # 分别将x和y的刻度值设定为坐标轴刻度
    plt.show()

定义第三个工具:对输出的结果采用one-hot编码

代码语言:javascript
复制
def one_hot(label, depth=10):
    out = torch.zero_(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

下面开始写main.py主代码

首先引入pytorch中的相关工具包

代码语言:javascript
复制
import torch
from torch import nn
# nn用于完成神经网络间的相关操作
from torch.nn import functional as F
# F为神经网络运算的常用计算包
from torch import optim
# 引入优化工具包
import torchvision
# 视觉类深度学习需要引入torchvision视觉工具包
from utils import plot_curve, plot_result_image, one_hot
# 从utils.py中导入定义的工具
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt

首先开始导入数据

代码语言:javascript
复制
batch_size = 512
# 同时并行处理512张图片
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist data', train=True, download=False,
                               # 更改download=True,以在线下载
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   # numpy数据转换为Tensor格式
                                   torchvision.transforms.Normalize(
                                       # 正则化
                                       (0.1307,), (0.3081,)
                                       # 设置参数,使数据均匀分布在0,1附近,以方便优化
                                   )
                               ])),
    batch_size=batch_size, shuffle=True
    # shuffle=True:随机打散
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist data', train=False, download=False,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)

下面开始构建神经网络

代码语言:javascript
复制
x, y = next(iter(train_loader))
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 准备构建三层神经网络,每一层都是xw+b函数
        self.fc1 = nn.Linear(28*28, 256)
        # 构建线性层
        # 初始图片的像素是28*28,默认是从“大层到小层”的过程
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        # 由于是十分类问题,因此最终输出是10个

    def forward(self, x):
        # 创建计算过程
        # x:[batch, 1, 28, 28]
        # 本次使用relu作为激活函数
        x = F.relu(self.fc1(x))
        # h1 = relu(xw + b1)
        x = F.relu(self.fc2(x))
        # h2 = relu(h1w2 + b2)
        # h3 = h2w3 + b3
        x = F.softmax(self.fc3(x))
        # 分类问题,可以加入softmax函数
        return x
        # 返回预测值

下节将介绍训练和测试部分

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-09-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档