全连接层解决MNIST

Tensorflow

一开始呢,让我们先了解一下tensorflow的运行方式。简单来说,我们使用tensorflow的步骤一共有三个:创建图,运行图,保存图。

Tensorflow标志

Tensorflow的计算是在图(graph)里面计算的,因此我们必须按照自己的需求来设计一张图。当然图的意思不是代表图片,而是代表一种结构。当创建好图之后,我们导入数据(也叫喂数据)来运行这张图得到我们要的结果。运行的过程中我们需要调整自己的参数。假如结果符合我们的要求,我们就保存这张图和里面的数据。

即使看的不明所以也没关系,接下来我们会用最简单的一种结构来解决MNIST数据集。在使用的途中你会对tensorflow更加了解。

MNIST

MNIST数据集是一个手写数字训练集(handwritten digit database)。里面有0到9的手写数字图片,并帮你打上了标签。打上标签的意思它有一个文件写明了图片代表的数字。

MNIST部分图片

MNIST是一个很有用的数据集,在接下来的时间里,我们会针对它不断提高我们神经网络的复杂度进而提高我们的网络的准确率。

全连接层

全连接层(full-connected layer),顾名思义,是将前面层的节点全部连接然后通过自己之后传入下一层。

前面讲到我们需要创建图,然后喂数据来运行。传入的数据被我们称为输入层。在处理MNIST数据集的时候,我们把每个像素都作为输入的数据,然后分批导入图片。

MNIST的每张图片的分辨率都为28*28,那么输入层一共有784个节点(即每个像素都是一个节点)。之所以这样设置,是因为每个像素都包含了图片的信息,它们共同决定了这张图片的数字。

然后我们设置全连接层的形状(shape)为[784, 10]。因为我们只有一层全连接层,它接受输入层的784个节点然后输出十个节点(十个分类)。如下图所示,X代表图片的某个像素,经过全连接层层后输出十个值,最大值即是网络的结果。

全连接层示意图

代码解析

导入需要的包

相信掌握python的人对于"import ... as ..."的用法不会陌生。Tensorflow可以通过第三句导入MNIST数据集,命名为input_data。

构建tensorflow的图

这里就在构建一个图了。tf.placehoder是创建一个占位符,用来接受输入的数据。在这里我们创建x,y来分别接受传入的图片和对应的标签。需要注意的是,tf.float32是tensorflow里面的float类型,而后面的[None, 784]代表了占位符的形状。前文提到,我们将784个像素作为输入,但我们一次性输入100张图片,所以输入会是一个[batch_size, 784]的矩阵。用None表示数量可以产生变化。

tf.Variable就是创建一个变量。权重的参数都应该设置为变量,因为它在训练的时候需要被更新,在测试的时候又能需要不产生变化。这里的有W和b,tf.zeros把他们初始化成形状为[784, 10]和[10]但值全为0的矩阵。

定义需要的变量

prediction是预测值。我们网络最后会导出一个[batch_size, 10]的张量出来,利用softmax我们可以得到预测值。详细的地方会在后面解释,但softmax主要还是选择数据经过网络之后十个值之中最大的那个作为预测值。

softmax函数

tf.argmax可以取张量里面一维的最大值。那么取出每一张图片标签和预测值里面的分类,再判断是否相等就可以得到准确率(correct_prediction)。

tf.reduce_mean把准确率平均就能求出平均准确率。tf.cast使得准确率做成浮点数不会平均的时候省略小数部分。

loss是损失值。由于神经网络得到的分类并不一定正确,所以不正确的估计我们会传递回去作为一个损失激励权重更新。而如何确定loss的大小就是用损失函数来决定。这里的损失函数是将y减去网络的预测值然后平方取平均。

train_step节点代表利用梯度下降法来降低loss值。换句话说,它告诉我们需要求loss对权重的梯度来更新权重。这方面涉及到权重的更新方法,会在后面详细介绍。

init代表初始化所有变量的操作。这又要重新提一下,我们到这里也只是画好了一个图。我们在图里面放了很多节点,但到这里它都没产生任何值。

图的结构

运行构建好的图

with tf.Session() as sess代表之后我们开始运行。首先我们都会开始sess.run(init)来运行init这个操作,即现在才开始初始化变量的操作。epoch代表迭代的次数。迭代代表跑完一整个数据集。

feed_dict是代表你喂的数据的字典。将batch_xs, batch_ys都放置在对应的占位符x, y上,此时每次运行x, y都是我们得到的新的批次的数据。接着是运行准确率的节点,调用的是测试集的图片。

我们会得到这样的数据:

打印出来的记录

你会发现准确率到一定的值就上升不了了。这是因为我们的网络过于简陋。在接下来的课程我们会加入卷积层,池化层,正则化等部分来改善识别的能力。

但是下一篇文章我们会继续深入这个网络来讲权重更新的细节。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180729G1J13A00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券