大家好,今天我们进一步学习Pytorch的用法之正向传播(FeedForward)网络的用法。
在开始本次分享之前先跟我明确先要强调一下深度神经网络的训练思路,一般是这样一个四部曲。
(1) 明确输入内容
(2) 明确输出内容
(3) 明确网络结构
(4) 明确损失函数
不论是你自己搭建网络,还是阅读别人的代码来了解一个模型的结构,都是按照这样的路径去理解的。
只要这样几个东西明白以后,我们就可以依靠Pytorch提供的各种库来轻松实现一个训练过程了。当然,前提是咱们对于题设的分析和建模要正确。其实上次的线性回归的训练就是一个只有一层,一个节点的最简单的神经网络,大家想想看是不是。那么本次的任务是什么呢?
任务:
如图所示,我们需要有这样一个网络,让这个网络能够输入一张图片,输出一个分类标签。
注意点:
(1) 输入格式
输入的图片是一个28*28像素的图片,每个像素是一个灰度值
(2) 输出标签
训练的时候给到网络的是一个独热标签(one-hot)。例如,在这个例子中,我们期望输出的内容是“2”,但是并不是用实数“2”来做标签,而是用了表示实数2的分类概率的独热向量[0,0,1,0,0,0,0,0,0,0]做了标签;同理,如果是数字0,那么就是[1,0,0,0,0,0,0,0,0,0]来做标签。在多分类问题(非回归问题)中我们多用这种标签类型。
好,输入和输出确定了,我们就可以明确网络结构了。
(3) 网络结构
请先用pwd确认一下自己是否已经在实验环境的目录下,我们应该进入
~/pytorch-tutorial/tutorials/01-basics/feedforward_neural_network
里面有两个文件,一个是gpu版本的,一个是cpu版本的,我们只要读懂一个结构就可以了,以cpu版本为例好了。
使用命令:
vim main.py
图中框处了前33行的部分,这33行的部分可以分成4部分内容,分别是:
库导入,
超参数设置,
训练集和测试集下载,
训练集和测试集载入。
这4个部分其实都不重要,为什么这么说呢。库导入和超参数设置我们在前面的内容中见过了,没什么新鲜的。而后面的训练集和测试集下载则是torchvision模块中提供的MNIST数据集下载功能,封装得很好,不用我们再输入下载位置了。最后面的训练集和测试集载入,则是把刚刚下载的的训练集合测试集进行了“载入”,使它们成为我们想要的输入格式(这些部分都不是本模型的重点核心内容,是官方提供的一些帮助我们读取数据的工具。)。
是什么格式呢?其实是把每个图片处理成了一个1*784的向量,即1行784列的矩阵。这么处理主要是为了做全连接网络输入方便一些。
MNIST数据集是一个著名的第三方数据集,有60000张图片组成,其中50000张是训练集,10000张是测试集,看上去就是下面这个样子,是手写数字信息,每张图片都是一个28*28的图片
所谓处理成1*784的尺寸也就是把每张图片的像素都“拉直”变成一根横着的“直线”。
看一下36行到47行这一部分,定义了网络,网络长什么样子呢?
在网上我们随便就能找到一个神经网络的拓扑示意图,比如上面这个图吧,我们解释一下,输入是3维,两个隐藏层都是5个神经元,输出是一个4维的向量。
我们现在说的深度学习通常就是指基于深度神经网络的建模方式,而通常大于2个隐藏层的都可以统称为深度神经网络或者称这种机器学习方式叫深度学习。那么为什么大家都喜欢用深度学习的方式来建立模型呢?简单说有这样几个原因可以做个补充,给大家做以参考。
原因1、每个神经元都是一个小的模型或者分类器。大量的神经元进行叠加,可以使得分布的映射种类比较丰富,组合比较多样,使得分类能力提高。咱就这么简单理解吧,你说是一个If…Else…语句能够叙述的程序分支逻辑复杂,还是1000个If…Else…彼此组合嵌套所组成的程序分支逻辑复杂呢?当然是后者,这种复杂度越高,就越能够描述复杂的逻辑场景。
原因2、每个神经元的表达式比较简单,通常是一个y=wz+b的线性函数,模型含义相对好解释。而其后通常跟随一个非线性的激励函数(Activation Function),用来拟合输入与输出之间存在的非线性逻辑关系。有线性和非线性两个部分组合之后,通常就能对付那些输入与输出之间有着线性非线性两部分逻辑组合的场景,或者说大部分场景都能适用。
这两个原因就是我们喜欢用深度学习的比较重要的原因了,当然我这里解释得比较口语化,具体的学术层面的论述大家还是要去找相关的学术书籍了。由于单个神经元的分类能力比较有限,而随着神经元数量的加多加深,整个网络的分类能力就有了不断的提高,就能对付很多原来传统机器学习很难处理的内容。
我们再回来看36~47行,就比较容易懂了,通过代换里面的变量关系,可以得到这样一个图:
输入为x,784维(1*784)的矩阵。
fc1指的是一个全连接层(FullConnection),有500个神经元,其实是一个784*500的矩阵(具体怎么算我们很快会说)。
后面的relu是说的激励函数,对于fc1输出的这个1*500的矩阵,每一个维度值都过了一个非线性激励函数relu。
最后fc2又是一个全连接层,相当于一个500*10的矩阵。
中间橘色框表示的是一个完整的隐藏层,就是一个线性函数y=wz+b,和一个非线性函数relu叠加的过程。
为了把这个模型的内容说得比较明白,我再画一张详细一些的图解。
要先强调一下,这个图和上面那张图是完全等价的!只不过上面那张图画的是一个网络连接的拓扑图,这个图是表示的相同逻辑下的,矩阵乘法和向后传递的过程。
首先一个x是一个1*784维的向量,这在前面已经由26~33行部分的程序处理过了得到的。
nn.Linear定义了一个y=wx+b,只不过这是一个x是[1,784],w是[784,500]的矩阵相乘,然后再加上偏置(Bias)b。根据矩阵乘法的定义:
一个[1,784]的矩阵和一个[784,500]的矩阵相乘的话,相当于一个500次的循环。每一次循环都是这个[1,784]的矩阵和这个[784,500]的矩阵的第i列做点积,看上去就是两个784维的向量,每两个对应的维度做乘法,最后再把这784个值加在一起,再加上一个b,这就是y=wx+b的表达式,如果忘了这个部分什么含义的话,请复习一下上上节课的内容。因此,有500列,那就循环500次,每一次都会输出后面这个[1,500]矩阵的第1行的第i个元素,一共输出500个,结果为一个[1,500]尺寸的矩阵。
然后就过relu函数,就是这个红方块表示的部分。
表达式为f(x)=max(0,x),画出图像来就是:
请注意,这是一个典型非线性激活函数,一个x经过这样的映射后,就会在小于0的时候输出0,而大于0的时候输出x值。现在绝大部分的神经网络中都会使用到relu激活函数。
由于这个[1,500]的值每个元素都通过了relu函数,那么它的输出内容仍然是一个[1,500]的函数,只不过每个元素对应的都是被relu函数加工过的输入值而已。
最后那个500个元素和10个输出节点之间的全连接,在数学上就是一个[1,500]的矩阵和一个[500,10]的矩阵相乘。过程就不赘述了。
网络结构就是这个样子,当x通过这一系列的计算与映射就会跑到输出端产生输出结果。这个结果当然最开始非常不准确,因为图里那些橙色的矩阵的待定系数w和b都还没有被训练求出来。
(4) 损失函数
第53行定义了损失函数为交叉熵损失函数CrossEntropyLoss,54行声明了用Adam优化器(一种工作原理和随机梯度下降很相近的优化器)来进行优化。
什么是交叉熵损失函数呢?
首先,输出的内容是一个[1,10]的矩阵,标签内容也是一个[1,10]的矩阵,那么怎么来描述他们之间误差的大小呢?
注意这里例子中没有使用Softmax激励函数(这与Tensorflow有所不同)。我们怎么理解这个交叉熵的含义呢?这么想吧,当一个样本通过网络产生一个[1,10]的拟合值,必然和它期望的标签值——那个独热编码有差距,我们随便看一个小例子吧。
例如,输出x样本通过网络输出了一个矩阵为[0.5,0.5,0,0,0,0,0,0,0,0],
期望的标签为[0,0,1,0,0,0,0,0,0,0]。
那么你的网络分类器的分类能力和期待的分类能力的差距就是把每一位上的分类误差进行计算加和了。
公式里面y就是期望的标签值,那么y=1的时候,后面1-y=0后面一项全都削掉了,这时候期望lna=0。这里的a是指对应位置的输出值,也就是拟合出来的标签值,因为比0大都是误差值,如果lna=0,那么a必须等于1,也就是所谓标签为y,最好拟合出来的对应位置的输出值为1,因为此时的误差为0;后面那项是反过来的情况,就是y=0的时候生效,大家自己理解一下看是不是同一个道理。
那刚刚这个样本产生的损失值有多大呢?大概是这么理解
其余7项都是0。
损失函数的值就是所有的训练数据(一个epoch中),平均在每个样本上的误差值。那么优化的方向,就是使得整个网络中的待定系数们朝着减小这个误差值的方向做调整喽(例如让那些应该拟合出3的,对应的3的向量位置为1,而其余位置为0)。
(5) 训练
57行到72行的部分似曾相识。
65行就是正向传播一张图片;
66行计算损失函数大小;
67行做一次反向传播;
68行进行一次优化,整个网络的w和b都向着减小误差的方向挪一轮。
最外面两层循环,一个是epoch的数量,一个对train_loader中的每个元素(图片,标签)做循环。
(6) 测试
最后这个部分是在测试集上做验证,以防过拟合。
77行的循环就是遍历test_loader这个测试集了。
80行的这个部分中有一个torch.max函数,返回指定列中最大值的那个元素,且返回索引值。你如果打印出labels和predicted,你就会发现,它们都是用索引值1,2,3这些值来表示的,而不是一个独热向量。
最后做统计,计算正确预测的数量有多少。
(7) 运行
python main.py
如果你运行的话,你会发现,这个数据集上效果还不错,在测试集上正确率也是96~98%左右。
好了,这次我们就学会了建立一个最简单的一层的全连接网络。如果你有兴趣可以照猫画虎地把这个网络尝试加深一下,好好玩玩这个入门玩具。