前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-手写数字问题

PyTorch入门笔记-手写数字问题

作者头像
触摸壹缕阳光
修改2021-03-25 19:23:24
9960
修改2021-03-25 19:23:24
举报
文章被收录于专栏:AI机器学习与深度学习算法

前面介绍了能够对连续值进行预测的简单线性回归模型,并使用梯度下降算法进行迭代求解。当然深度学习不仅能够处理连续值预测的回归问题,还能够处理预测固定离散值的分类问题。分类问题的一个典型应用就是自动识别图像中物体的种类,手写数字识别是常见的图像识别任务。

为了方便统一测试和评估算法,Yann LeCun 发布了名为 MNIST 的手写数字图片的数据集,MNIST 数据集包含 0~9 共 10 种数字的手写图片,每种数字一共有 7000 张图片,采集自不同书写风格的真实手写图片,一共 70000 张图片。70000 张手写数字图片使用 train_test_split 方法划分为 60000 张训练集(Training Set)和 10000 张测试集(Test Set)。如果将 70000 张手写数字图片全部作为模型的训练集,模型很可能过拟合,模型在训练集上表现很好,但是给模型一个新的数字图片进行预测,模型预测的结果会非常不好。

MNIST 数据集中每张图片都被缩放到 (28 x 28) 的大小,同时只保留了灰度信息。下图为 MNIST 数据集的样例图片。

下图为某个手写数字 8 的图片表示示意图。

我们现在还没有学到将这种图片表示的数字矩阵直接作为输入输入到网络中。简单的方法是将这种数字矩阵的特征图打平成特征向量,打平操作非常简单。比如下面将一个 (2 x 2) 的矩阵的打平成 (4, ) 的向量。

\left[\begin{matrix} 0 & 1 \\ 2 & 3 \end{matrix} \right] =>\begin{bmatrix} 0 \\ 1 \\ 2 \\ 3 \end{bmatrix}

将 (28 x 28) 的数字矩阵打平成 (784, ) 的特征向量,打平后的特征没有了位置信息。由于特征比较多,如果依然使用 for 循环等进行计算会耗费大量的时间,而使用 Numpy 模块中的矩阵运算可以利用 Numpy 中的并行化操作大幅度提高运算效率。打平后的图片特征为 (784, ) 的向量,如果想要使用矩阵运算需要为向量增加一个维度变成 (1 x 784) 的矩阵,此时的 1 代表的图片的数量,即输入的X = [图片数量, 图片特征]矩阵。

上一小节介绍了简单的线性模型 y = wx + b,显然一个简单的线性模型是不可能分类手写数字识别任务的,我们通常将几个线性函数进行嵌套:

\begin{aligned}H_1 &= XW_1 + b_1 \\H_2 &= H_1W_2 + b_2 \\H_3 &= H_2W_3 + b_3\end{aligned}

如果将 X = [1, d_0] 输入到嵌套线性函数中,各个变量的维度变化如下:

[1, d_0] @ [d_0, d_1] + [d_1] = [1, d_1]
[1, d_1] @ [d_1, d_2] + [d_2] = [1, d_2]
[1, d_2] @ [d_2, d_3] + [d_3] = [1, d_3]

将上面的线性函数结合在一起:

H_3 = \{[XW_1 + b_1]W_2 + b_2 \}W_3 + b_3

可以发现即使使用多个嵌套的线性函数结果依然是线性函数,但是处理这种复杂的图片识别,光靠线性的关系是不能够学习更深层次的知识,所以我们需要添加非线性的部分:激活函数。本小节使用 ReLU 激活函数,ReLU 也是现在比较流行的激活函数。加入 ReLU 激活函数的嵌套函数表达式:

H_3 = relu(\{relu([relu(XW_1 + b_1)]W_2 + b_2) \}W_3 + b_3)

三个非线性模型叠加依然是非线性模型,而且使模型的表达能力进一步增强。

如何将类别标签进行编码呢?

  • 如果将类别标签转换成数字编码,即用一个数字来表示标签信息,此时的输出只需要一个节点就可以表示网络的预测类别,即 d_3 = 1。但是数字编码有一个很大的问题,数字之间存在天然的大小关系,手写数字图片的 0~9 十个类别之间并没有大小关系,但是如果使用数字编码标签信息,可能导致模型迫使去学习 0 < 1 < 2这种数字大小的关系。如本小节题图所示;
  • 如果将类别标签转码成 one-hot 编码,即用一个包含 0 和 1 的向量来表示标签信息,向量的维度为标签类别的个数,由于手写数字识别的类别为 0~9 的十个类别,此时的输出需要十个节点,即 d_3 = 10。假设某个手写图片属于类别 i,即手写图片中的数字为 i,只需要一个长度为 10 的向量 y,向量 y 的索引号为 i 的元素设置为1,其余位置设置为 0;

「使用 one-hot 编码类别标签没有使用数字编码中的问题,所以通常类别标签使用这种 one-hot 编码的方式。」

有了这些准备接下来就可以使用梯度下降算法进行迭代求解,由于标签采用 one-hot 编码方式,预测输出 H_3 和真实标签 y 都是一个十维的向量,我们需要找到使得 H_3y 之间距离最小的参数 W, b,衡量两个向量之间距离最简单的方式是使用欧式距离: \sum(H_3 - y)^2,此时的损失函数 L = \sum(H_3 - y)^2。需要注意参数 W, b 包含了:

  • 第一个非线性函数的 W_1, b_1
  • 第二个非线性函数的 W_2, b_2
  • 第三个非线性函数的 W_3, b_3

References: 1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm

原文地址:https://mp.weixin.qq.com/s/DvrRsW42GALhrppEAU8XbA

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-10-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档