pytorch入门教程 | 第五章:训练和测试CNN

我们 按照 pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧。

按照 pytorch入门教程(三):构造一个小型CNN构建好一个神经网络,唯一不同的地方就是我们这次训练的是彩色图片,所以第一层卷积层的输入应为3个channel。修改完毕如下:

我们准备了训练集和测试集,并构造了一个CNN。与之前LeNet不同在于conv1的第一个参数1改成了3

现在咱们开始训练

我们训练这个网络必须经过4步:

第一步:将输入input向前传播,进行运算后得到输出output

第二步:将output再输入loss函数,计算loss值(是个标量)

第三步:将梯度反向传播到每个参数

第四步:利用下面公式进行权重更新

新权重w = 旧权重w + 学习速率? x 梯度向量g

非常幸运,pytorch帮我们写好了计算loss的函数和优化的函数。

我们先初始化loss和优化函数:

criterion = nn.CrossEntropyLoss() #叉熵损失函数 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #使用SGD(随机梯度下降)优化,学习率为0.001,动量为0.9

待会我们就要用到这两个函数

假设我们需要对训练数据完全遍历两次,人话就是:我们把所有训练集的数据扔进去进行训练,但是扔一次怎么够呢,扔一次并不能保证我的网络的参数就训练的很完美了,那么我们就会反复将训练集的数据扔进去训练,每次扔的时候,数据的顺序是不一样的。

这里我们就先扔两次练练。

训练网络

先不管running_loss,它是我们待会用来统计loss的平均值的。

我们先看data,data是从trainloader中枚举出来的,数据的结构看上面注释。

我们在训练前,会将网络中每个参数的grad值清空为0,这样做是因为grad值是累加的,设置为0后,每次bp后的grad更新后的值才是正确的。

我们将inputs输入net之后,得到outputs,将outputs和labels输入之前定义的叉熵函数计算loss值。除了叉熵方式计算外还有其他计算loss的方法。

loss算完后,我们就使用backward向后传播啦!我们稍微想一下传播会怎么进行,传播应该会让每一个网络参数的grad值进行更新,我们网络中的每一个参数都是Variable类型,并且均是叶子节点,grad值必然会进行更新。

接下来,每个参数利用自身的grad值进行梯度下降法的更新就好了,我们利用先前定义好的optimizer使用step()函数进行更新。好了!讲了这么久,我们将代码下载下来溜溜,看看是什么情况!下载cnn.py(https://pan.baidu.com/s/1hrNeyEw)

如果没错的话,跑完你应该会看到下图(loss平均值每次跑都会有变化的,因为咱们的loader设置了shuffle=True):

如图,我们的训练数据被我们扔进去了两遍,而且每2000批数据我们打印一次平均loss值,请注意不断减小的loss值,证明我们的网络正在被优化啊!!!!

好了,训练完之后,我们当然我测试一下我们的网络的分类的正确率到底是多少

上代码:

测试部分

关于total值我们可以设为10000,因为我们知道训练集中的图片数量就是10000,但是为了泛化,我们还是老老实实的点人头。一开始我们设置correct和total都为0。

我们要计算正确率,就用

正确数/全部数量

我解释一下第92行代码,outputs.data是一个4x10张量,max函数会将每一行的最大的那一列的值和序号各自组成一个一维张量返回,第一个是值的张量,第二个是序号的张量。我想还是举个例子吧:

随机生成了4x10的tensor,然后max函数会帮我们挑出每一行最大的那个值,比如第一行第10个,第二行是第9个,第三行是第5个,第四行是第10个。而[9,8,5,9]正是表示这些数的位置(从0开始算)

那么为啥输出的outputs是个4x10的张量呢,我们试着想一下,假设我们现在输入的是一张图片,那么出来的是一个10维的特征向量,因为我们同时输入了4张,所以就是4x10啦!

第93行,我们的labels是4维的向量,size(0)就是4,即没次total都加4。

第94行,两个4维向量逐行对比,相同的行记为1,不同的行记为0,再利用sum(),求各元素总和,得到相同的个数。这个不懂也可以自己命令行上试试:

只有第1和第3个元素相同,使用sum的话则会等于2。

完整代码在这(https://pan.baidu.com/s/1qYEWJFm),我们下载后跑一下:

正确率为52%

童鞋们可以看到,训练2遍结果是不够好的,才52%!大家可以回去把循环次数2改成10试试。

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2017-09-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏量化投资与机器学习

深度学习Matlab工具箱代码注释之cnnsetup.m

%%========================================================================= %...

2545
来自专栏有趣的Python

6- OpenCV+TensorFlow 入门人工智能图像处理-图片移位

1987
来自专栏张俊红

机器学习中的参数调整

总第102篇 前言 我们知道每个模型都有很多参数是可以调节的,比如SVM中使用什么样的核函数以及C值的大小,决策树中树的深度等。在特征选好、基础模型选好以后我们...

3637
来自专栏贾志刚-OpenCV学堂

基于OpenCV实现手写体数字训练与识别

OpenCV实现手写体数字训练与识别 机器学习(ML)是OpenCV模块之一,对于常见的数字识别与英文字母识别都可以做到很高的识别率,完成这类应用的主要思想与方...

4116
来自专栏落影的专栏

Metal视频处理——绿幕视频合成

Metal入门教程总结 Metal图像处理——直方图均衡化 本文介绍如何用Metal把一个带绿幕的视频和一个普通视频进行合并。

1003
来自专栏人工智能LeadAI

Char RNN原理介绍以及文本生成实践

Char-RNN,字符级循环神经网络,出自于Andrej Karpathy写的The Unreasonable Effectiveness of Recurre...

871
来自专栏李智的专栏

Python针对图像的基础操作

5. 返回目录中所有JPG 图像的文件名列表,直方图均衡化,平均图像,主成分分析等

942
来自专栏利炳根的专栏

学习笔记CB012: LSTM 简单实现、完整实现、torch、小说训练word2vec lstm机器人

LSTM(Long Short Tem Memory)特殊递归神经网络,神经元保存历史记忆,解决自然语言处理统计方法只能考虑最近n个词语而忽略更久前词语的问题。...

4146
来自专栏数据处理

scikit-learning小试牛刀

1292
来自专栏Python数据科学

Seaborn从零开始学习教程(四)

数据集中的数据类型有很多种,除了连续的特征变量之外,最常见的就是类目型的数据类型了,常见的比如人的性别,学历,爱好等。这些数据类型都不能用连续的变量来表示,而是...

722

扫描关注云+社区