TensorFlow-3: 用 feed-forward neural network 识别数字

今天继续看 TensorFlow Mechanics 101:

https://www.tensorflow.org/get_started/mnist/mechanics

完整版教程可以看中文版tutorial:

http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_tf.html

这一节讲了使用 MNIST 数据集训练并评估一个简易前馈神经网络(feed-forward neural network)

input,output 和前两节是一样的:即划分数据集并预测图片的 label

data_sets.train    55000个图像和标签(labels),作为主要训练集。
data_sets.validation    5000个图像和标签,用于迭代验证训练准确度。
data_sets.test    10000个图像和标签,用于最终测试训练准确度(trained accuracy)。

主要有两个代码:

mnist.py

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py

构建一个全连接网络,由 2 个隐藏层,1 个 softmax_linearv输出构成

定义损失函数,用cross entropyv

定义训练时的优化器,用 GradientDescentOptimizer

定义评价函数,用tf.nn.in_top_k

fully_connected_feed.py

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py

  • placeholder_inputs 传入 batch size,得到 image 和 label 的两个placeholder
  • 定义生成 feed_dict的函数,key 是 placeholders,value 是 data
  • 定义 do_eval 函数,每隔 1000 个训练步骤,就对模型进行以下评估,分别作用于训练集、验证集和测试集
  • 训练时:
  • 导入数据
  • 得到 image 和 label 两个 placeholder
  • 传入 mnist.inference 定义的 NN, 得到 predictions
  • 将 predictions 传给 mnist.loss 计算 loss
  • loss 传给mnist.training 进行优化训练
  • 再用 mnist.evaluation 评价预测值和实际值

代码中涉及到下面几个函数:

with tf.Graph().as_default():

即所有已经构建的操作都要与默认的 tf.Graph 全局实例关联起来,tf.Graph 实例是一系列可以作为整体执行的操作

summary = tf.summary.merge_all():

为了释放 TensorBoard 所使用的 events file,所有的即时数据都要在图表构建时合并至一个操作 op 中,每次运行 summary 时,都会向 events file 中写入最新的即时数据

summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph):

用于写入包含了图表本身和即时数据具体值的 events file。

saver = tf.train.Saver():

就是向训练文件夹中写入包含了当前所有可训练变量值 checkpoint file

with tf.name_scope('hidden1'):

主要用于管理一个图里面的各种 op,返回的是一个以 scope_name命名的 context manager,一个 graph 会维护一个 name_space的堆,实现一种层次化的管理,避免各个 op 之间命名冲突。例如,如果额外使用 tf.get_variable() 定义的变量是不会被tf.name_scope() 当中的名字所影响的

tf.nn.in_top_k(logits, labels, 1):

意思是在 K 个最有可能的预测中如果可以发现 true,就将输出标记为 correct。本文 K 为 1,也就是只有在预测是 true 时,才判定它是 correct。

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏小鹏的专栏

02 The TensorFlow Way(1)

The TensorFlow Way Introduction:          现在我们介绍了TensorFlow如何创建张量,使用变量和占位符,我们将介...

18610
来自专栏我是攻城师

理解算法的复杂度

在计算机科学中,算法的时间复杂度是一个函数,它定性描述该算法的运行时间,时间复杂度常用大O符号表示,不包括这个函数的低阶和首项系数,使用这种方式时,时间的复杂度...

862
来自专栏专知

【干货】计算机视觉实战系列03——用Python做图像处理

【导读】专知成员Hui上一次为大家介绍Matplotlib的使用,包括绘图,绘制点和线,以及图像的轮廓和直方图,这一次为大家详细讲解Numpy工具包中的各种工具...

39610
来自专栏lhyt前端之路

js随机数生成器的扩展0.前言1.扩展+分区2.二进制法3. 总结

给你一个能生成随机整数1-7的函数,就叫他生成器get7吧,用它来生成一个1-11的随机整数,不能使用random,而且要等概率。

411
来自专栏Code_iOS

算法?

建议数据结构和算法分开来学,这里只有算法,没有什么是数据结构!数据结构在这里; --->> 点我

893
来自专栏Python小屋

Python标准库random用法精要

random标准库主要提供了伪随机数生成函数和相关的类,同时也提供了SystemRandom类(也可以直接使用os.urandom()函数)来支持生成加密级别要...

2776
来自专栏me的随笔

JavaScript 随机数

JavaScript内置函数random(seed)可以产生[0,1)之间的随机数,若想要生成其它范围的随机数该如何做呢?

866
来自专栏Python小屋

Python使用系统聚类方法进行数据分类案例一则

首先解释一下为啥最近发的文章中代码都是截图而不是文本,这样做主要是希望大家能对着代码敲一遍而不是直接复制运行得到结果就算了,这样可以加深印象,学到更多东西。当然...

3514
来自专栏ArrayZoneYour的专栏

如何用Python将时间序列转换为监督学习问题

像深度学习这样的机器学习方法可以用于时间序列预测。

5519
来自专栏AI研习社

让 TensorFlow 估算器的推断提速百倍,我是怎么做到的?

TensorFlow 估算器提供了一套中阶 API 用于编写、训练与使用机器学习模型,尤其是深度学习模型。在这篇博文中,我们描述了如何通过使用异步执行来避免每次...

712

扫码关注云+社区