TensorFlow-3: 用 feed-forward neural network 识别数字

今天继续看 TensorFlow Mechanics 101:

https://www.tensorflow.org/get_started/mnist/mechanics

完整版教程可以看中文版tutorial:

http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_tf.html

这一节讲了使用 MNIST 数据集训练并评估一个简易前馈神经网络(feed-forward neural network)

input,output 和前两节是一样的:即划分数据集并预测图片的 label

data_sets.train    55000个图像和标签(labels),作为主要训练集。
data_sets.validation    5000个图像和标签,用于迭代验证训练准确度。
data_sets.test    10000个图像和标签,用于最终测试训练准确度(trained accuracy)。

主要有两个代码:

mnist.py

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py

构建一个全连接网络,由 2 个隐藏层,1 个 softmax_linearv输出构成

定义损失函数,用cross entropyv

定义训练时的优化器,用 GradientDescentOptimizer

定义评价函数,用tf.nn.in_top_k

fully_connected_feed.py

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py

  • placeholder_inputs 传入 batch size,得到 image 和 label 的两个placeholder
  • 定义生成 feed_dict的函数,key 是 placeholders,value 是 data
  • 定义 do_eval 函数,每隔 1000 个训练步骤,就对模型进行以下评估,分别作用于训练集、验证集和测试集
  • 训练时:
  • 导入数据
  • 得到 image 和 label 两个 placeholder
  • 传入 mnist.inference 定义的 NN, 得到 predictions
  • 将 predictions 传给 mnist.loss 计算 loss
  • loss 传给mnist.training 进行优化训练
  • 再用 mnist.evaluation 评价预测值和实际值

代码中涉及到下面几个函数:

with tf.Graph().as_default():

即所有已经构建的操作都要与默认的 tf.Graph 全局实例关联起来,tf.Graph 实例是一系列可以作为整体执行的操作

summary = tf.summary.merge_all():

为了释放 TensorBoard 所使用的 events file,所有的即时数据都要在图表构建时合并至一个操作 op 中,每次运行 summary 时,都会向 events file 中写入最新的即时数据

summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph):

用于写入包含了图表本身和即时数据具体值的 events file。

saver = tf.train.Saver():

就是向训练文件夹中写入包含了当前所有可训练变量值 checkpoint file

with tf.name_scope('hidden1'):

主要用于管理一个图里面的各种 op,返回的是一个以 scope_name命名的 context manager,一个 graph 会维护一个 name_space的堆,实现一种层次化的管理,避免各个 op 之间命名冲突。例如,如果额外使用 tf.get_variable() 定义的变量是不会被tf.name_scope() 当中的名字所影响的

tf.nn.in_top_k(logits, labels, 1):

意思是在 K 个最有可能的预测中如果可以发现 true,就将输出标记为 correct。本文 K 为 1,也就是只有在预测是 true 时,才判定它是 correct。

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Petrichor的专栏

tf.nn.max_pool

801
来自专栏李智的专栏

Deep learning基于theano的keras学习笔记(0)-keras常用的代码

这里不推荐使用pickle或cPickle来保存Keras模型。 1. 一般使用model.save(filepath)将Keras模型和权重保存在一个HD...

871
来自专栏开发与安全

算法:Solutions for the Maximum Subsequence Sum Problem

The maximum subarray problem is the task of finding the contiguous subarray wit...

1928
来自专栏程序员的知识天地

Python学习,这有可能是最详细的PIL库基本概念文章了

PIL有如下几个模块:Image模块、ImageChops模块、ImageCrackCode模块、ImageDraw模块、ImageEnhance模块、Imag...

813
来自专栏数值分析与有限元编程

有限元 | 梁单元有限元程序算例

之前发过一个梁单元有限元分析程序。在好友测试时发现一个问题,就是程序中的real型变量默认为kind=4,我们姑且称为单精度型。这样限制了程序的使用,在一些问题...

2928
来自专栏云时之间

深度学习与TensorFlow:VGG论文复现

上一篇文章我们介绍了下VGG这一个经典的深度学习模型,今天便让我们通过使用VGG开源的VGG16模型去复现一下该论文.

1.5K3
来自专栏PaddlePaddle

【场景文字识别】场景文字识别

1. STR任务简介 许多场景图像中包含着丰富的文本信息,对理解图像信息有着重要作用,能够极大地帮助人们认知和理解场景图像的内容。场景文字识别是在图像背景复杂、...

4127
来自专栏贾志刚-OpenCV学堂

tensorflow Object Detection API使用预训练模型mask r-cnn实现对象检测

Mask R-CNN是何凯明大神在2017年整出来的新网络模型,在原有的R-CNN基础上实现了区域ROI的像素级别分割。关于Mask R-CNN模型本身的介绍与...

8202
来自专栏机器学习算法工程师

Tensorflow实战:Discuz验证码识别

本文将使用深度学习框架 Tensorflow 训练出一个用于破解 Discuz 验证码的模型。

4.2K9
来自专栏HT

电信网络拓扑图自动布局之总线

在前面《电信网络拓扑图自动布局》一文中,我们大体介绍了 HT for Web 电信网络拓扑图自动布局的相关知识,但是都没有深入地描述各种自动布局的用法,我们今天...

2378

扫码关注云+社区