pytorch入门教程 | 第五章:训练和测试CNN

我们 按照 pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧。

按照 pytorch入门教程(三):构造一个小型CNN构建好一个神经网络,唯一不同的地方就是我们这次训练的是彩色图片,所以第一层卷积层的输入应为3个channel。修改完毕如下:

我们准备了训练集和测试集,并构造了一个CNN。与之前LeNet不同在于conv1的第一个参数1改成了3

现在咱们开始训练

我们训练这个网络必须经过4步:

第一步:将输入input向前传播,进行运算后得到输出output

第二步:将output再输入loss函数,计算loss值(是个标量)

第三步:将梯度反向传播到每个参数

第四步:利用下面公式进行权重更新

新权重w = 旧权重w + 学习速率? x 梯度向量g

非常幸运,pytorch帮我们写好了计算loss的函数和优化的函数。

我们先初始化loss和优化函数:

criterion = nn.CrossEntropyLoss() #叉熵损失函数 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #使用SGD(随机梯度下降)优化,学习率为0.001,动量为0.9

待会我们就要用到这两个函数

假设我们需要对训练数据完全遍历两次,人话就是:我们把所有训练集的数据扔进去进行训练,但是扔一次怎么够呢,扔一次并不能保证我的网络的参数就训练的很完美了,那么我们就会反复将训练集的数据扔进去训练,每次扔的时候,数据的顺序是不一样的。

这里我们就先扔两次练练。

训练网络

先不管running_loss,它是我们待会用来统计loss的平均值的。

我们先看data,data是从trainloader中枚举出来的,数据的结构看上面注释。

我们在训练前,会将网络中每个参数的grad值清空为0,这样做是因为grad值是累加的,设置为0后,每次bp后的grad更新后的值才是正确的。

我们将inputs输入net之后,得到outputs,将outputs和labels输入之前定义的叉熵函数计算loss值。除了叉熵方式计算外还有其他计算loss的方法。

loss算完后,我们就使用backward向后传播啦!我们稍微想一下传播会怎么进行,传播应该会让每一个网络参数的grad值进行更新,我们网络中的每一个参数都是Variable类型,并且均是叶子节点,grad值必然会进行更新。

接下来,每个参数利用自身的grad值进行梯度下降法的更新就好了,我们利用先前定义好的optimizer使用step()函数进行更新。好了!讲了这么久,我们将代码下载下来溜溜,看看是什么情况!下载cnn.py(https://pan.baidu.com/s/1hrNeyEw)

如果没错的话,跑完你应该会看到下图(loss平均值每次跑都会有变化的,因为咱们的loader设置了shuffle=True):

如图,我们的训练数据被我们扔进去了两遍,而且每2000批数据我们打印一次平均loss值,请注意不断减小的loss值,证明我们的网络正在被优化啊!!!!

好了,训练完之后,我们当然我测试一下我们的网络的分类的正确率到底是多少

上代码:

测试部分

关于total值我们可以设为10000,因为我们知道训练集中的图片数量就是10000,但是为了泛化,我们还是老老实实的点人头。一开始我们设置correct和total都为0。

我们要计算正确率,就用

正确数/全部数量

我解释一下第92行代码,outputs.data是一个4x10张量,max函数会将每一行的最大的那一列的值和序号各自组成一个一维张量返回,第一个是值的张量,第二个是序号的张量。我想还是举个例子吧:

随机生成了4x10的tensor,然后max函数会帮我们挑出每一行最大的那个值,比如第一行第10个,第二行是第9个,第三行是第5个,第四行是第10个。而[9,8,5,9]正是表示这些数的位置(从0开始算)

那么为啥输出的outputs是个4x10的张量呢,我们试着想一下,假设我们现在输入的是一张图片,那么出来的是一个10维的特征向量,因为我们同时输入了4张,所以就是4x10啦!

第93行,我们的labels是4维的向量,size(0)就是4,即没次total都加4。

第94行,两个4维向量逐行对比,相同的行记为1,不同的行记为0,再利用sum(),求各元素总和,得到相同的个数。这个不懂也可以自己命令行上试试:

只有第1和第3个元素相同,使用sum的话则会等于2。

完整代码在这(https://pan.baidu.com/s/1qYEWJFm),我们下载后跑一下:

正确率为52%

童鞋们可以看到,训练2遍结果是不够好的,才52%!大家可以回去把循环次数2改成10试试。

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2017-09-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏MelonTeam专栏

ArrayList源码完全分析

导语: 这里分析的ArrayList是使用的JDK1.8里面的类,AndroidSDK里面的ArrayList基本和这个一样。 分析的方式是逐个API进行解析 ...

4479
来自专栏项勇

笔记68 | 切换fragmengt的replace和add方法笔记

1444
来自专栏开发与安全

算法:AOV网(Activity on Vextex Network)与拓扑排序

在一个表示工程的有向图中,用顶点表示活动,用弧表示活动之间的优先关系,这样的有向图为顶点表示活动的网,我们称之为AOV网(Activity on Vextex ...

2517
来自专栏Phoenix的Android之旅

Java 集合 Vector

List有三种实现,ArrayList, LinkedList, Vector, 它们的区别在于, ArrayList是非线程安全的, Vector则是线程安全...

662
来自专栏计算机视觉与深度学习基础

Leetcode 114 Flatten Binary Tree to Linked List

Given a binary tree, flatten it to a linked list in-place. For example, Given...

1938
来自专栏alexqdjay

HashMap 多线程下死循环分析及JDK8修复

1K4
来自专栏java闲聊

JDK1.8 ArrayList 源码解析

当运行 ArrayList<Integer> list = new ArrayList<>() ; ,因为它没有指定初始容量,所以它调用的是它的无参构造

1192
来自专栏刘君君

JDK8的HashMap源码学习笔记

3008
来自专栏xingoo, 一个梦想做发明家的程序员

AOE关键路径

这个算法来求关键路径,其实就是利用拓扑排序,首先求出,每个节点最晚开始时间,再倒退求每个最早开始的时间。 从而算出活动最早开始的时间和最晚开始的时间,如果这两个...

2507
来自专栏xingoo, 一个梦想做发明家的程序员

20120918-向量实现《数据结构与算法分析》

#include <iostream> #include <list> #include <string> #include <vector> #include...

1716

扫码关注云+社区