专栏首页贾志刚-OpenCV学堂tensorflow中实现神经网络训练手写数字数据集mnist

tensorflow中实现神经网络训练手写数字数据集mnist

tensorflow中实现神经网络训练手写数字数据集mnist

一:网络结构

基于tensorflow实现一个简单的三层神经网络,并使用它训练mnist数据集,神经网络三层分别为:

输入层:

像素数据输入28x28=784 个输入节点

隐藏层:

30个神经元节点

输出层:

10个神经元节点,对应 0 ~ 9 十个数字

图示结构如下:

网络结构的代码实现:

hidden_nodes = 30
x = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y = tf.placeholder(shape=[None, 10], dtype=tf.float32)

w1 = tf.Variable(tf.truncated_normal(shape=[784, hidden_nodes]), dtype=tf.float32)
b1 = tf.Variable(tf.truncated_normal(shape=[1, hidden_nodes]), dtype=tf.float32)

w2 = tf.Variable(tf.truncated_normal(shape=[hidden_nodes, 10]), dtype=tf.float32)
b2 = tf.Variable(tf.truncated_normal(shape=[1, 10]), dtype=tf.float32)

# layer hidden
nn_1 = tf.add(tf.matmul(x, w1), b1)
h1 = tf.nn.sigmoid(nn_1)

# layer output
nn_2 = tf.add(tf.matmul(h1, w2), b2)
out = tf.nn.sigmoid(nn_2)

# loss function
error = tf.square(tf.subtract(y, out))
loss = tf.reduce_sum(error)

# back prop
step = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
init = tf.global_variables_initializer()

二:数据读取与训练

读取mnist数据集

from tensorflow.examples.tutorials.mnist import inputdata
mnist = inputdata.readdatasets("MNISTdata/", onehot=True)

如果不行,就下载下来,放到本地即可

执行训练的代码如下

# accurate  model
acc_mat = tf.equal(tf.argmax(out, 1), tf.argmax(y, 1))
acc = tf.reduce_sum(tf.cast(acc_mat, tf.float32))
with tf.Session() as sess:
    sess.run(init)
    for i in range(20000):
        batch_xs, batch_ys = mnist.train.next_batch(10)
        sess.run(step, feed_dict={x: batch_xs, y: batch_ys})
        if i % 1000 == 0:
            x_input = mnist.test.images[:1000]
            y_input = mnist.test.labels[:1000]
            curr_acc = sess.run(acc, feed_dict={x: x_input, y: y_input})
            print("current acc : ", curr_acc)

训练结果:

测试集上对1000张手写数字图像测试正确识别921张,准确率高达92.1%。说明传统的人工神经网络表现还是不错的,这个还是在没有优化的情况下,通过修改批量数大小,修改学习率,添加隐藏层节点数与dropout正则化,可以更进一步提高识别率。

本文分享自微信公众号 - OpenCV学堂(CVSCHOOL)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-07-07

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • TensorFlow中的feed与fetch

    TensorFlow中的feed与fetch 一:占位符(placeholder)与feed 当我们构建一个模型的时候,有时候我们需要在运行时候输入一些初始数...

    OpenCV学堂
  • 干货 | Tensorflow设计简单分类网络实现猫狗图像分类训练与测试

    第一层:32个feature map 5x5卷积、步长为2、最大值池化 局部相应归一化处理(LRN) 第二层:64个feature map 3x3卷积、步长为...

    OpenCV学堂
  • 教程 | 基于LSTM实现手写数字识别

    基于tensorflow,如何实现一个简单的循环神经网络,完成手写数字识别,附完整演示代码。

    OpenCV学堂
  • 深度学习之 TensorFlow(五):mnist 的 Alexnet 实现

    希希里之海
  • TensorFlow2.X学习笔记(1)--TensorFlow核心概念

    TensorFlow™ 是一个采用 数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(e...

    MiChong
  • TensorFlow2.X学习笔记(2)--TensorFlow的层次结构介绍

    MiChong
  • TensorFlow2.X学习笔记(3)--TensorFlow低阶API之张量

    TensorFlow提供的方法比numpy更全面,运算速度更快,如果需要的话,还可以使用GPU进行加速。

    MiChong
  • TensorFlow2.X学习笔记(4)--TensorFlow低阶API之AutoGraph相关研究

    而Autograph机制可以将动态图转换成静态计算图,兼收执行效率和编码效率之利。

    MiChong
  • win10 tensorflow笔记2 MNIST机器学习入门

    这里跟官方有两处不同 1:第1行代码原文是import input_data这里的input_data是无法直接导入的。需要给出具体路径from tensor...

    我是木木酱呀
  • 【tensorflow2.0】张量的结构操作

    张量数学运算主要有:标量运算,向量运算,矩阵运算。另外我们会介绍张量运算的广播机制。

    绝命生

扫码关注云+社区

领取腾讯云代金券