专栏首页SnailTyantensorflow的基本用法——使用MNIST训练神经网络

tensorflow的基本用法——使用MNIST训练神经网络

本文主要是使用tensorflow和mnist数据集来训练神经网络。

#!/usr/bin/env python
# _*_ coding: utf-8 _*_

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 下载mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


# 定义神经网络模型的评估部分
def compute_accuracy(W, b):
    # 定义测试数据的placeholder
    x = tf.placeholder(tf.float32, [None, 784])
    # 定义测试数据的真实标签的placeholder
    y_ = tf.placeholder(tf.float32, [None, 10])
    # 定义预测值
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    # 判断预测值y和真实值y_中最大数的索引是否一致,y的值为1-10概率
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # 输入测试数据,执行准确率的计算并返回
    return sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

# 定义神经网络模型的训练部分
# 下面定义的神经网络只有一层W*x+b
# 定义输入数据placeholder,不定义输入样本的数目——None,但定义每个样本的大小为784
x = tf.placeholder(tf.float32, [None, 784])
# 定义神经网络层的权重参数
W = tf.Variable(tf.zeros([784, 10]))
# 定义神经网络层的偏置参数
b = tf.Variable(tf.zeros([10]))
# 定义一层神经网络运算,激活函数为softmax
y = tf.nn.softmax(tf.matmul(x, W) + b)
# 定义训练数据真实标签的placeholder
y_ = tf.placeholder(tf.float32, [None, 10])
# 定义损失函数,损失函数为交叉熵,reduction_indices表示沿着tensor的哪个纬度来求和
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# 定义神经网络的训练步骤,使用的是梯度下降法,学习率为0.5
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 初始化所有变量
init = tf.global_variables_initializer()
# 定义Session
sess = tf.Session()
# 执行变量的初始化
sess.run(init)
# 迭代进行训练
for i in range(1000):
    # 取出mnist数据集中的100个数据
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # 执行训练过程并传入真实数据x, y_
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        print compute_accuracy(W, b) 

执行结果如下:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.4075
0.8948
0.9031
0.9074
0.9037
0.9125
0.9158
0.912
0.9181
0.9169

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • tensorflow的基本用法——保存神经网络参数和加载神经网络参数

    Tyan
  • tensorflow的基本用法——dropout的作用

    本文主要是介绍tensorflow中dropout的作用,dropout主要是用来防止过拟合,即提供模型的泛化能力。

    Tyan
  • tensorflow的基本用法

    Tyan
  • 强化学习笔记-Python/OpenAI/TensorFlow/ROS-程序指令

    版权声明:本文为zhangrelay原创文章,有错请轻拍,转载请注明,谢谢... https://...

    zhangrelay
  • 【TensorFlow篇】--Tensorflow框架可视化之Tensorboard

    TensorBoard是tensorFlow中的可视化界面,可以清楚的看到数据的流向以及各种参数的变化,本文基于一个案例讲解TensorBoard的用法。

    LhWorld哥陪你聊算法
  • 机器学习基础

    监督学习:训练时有输入及对应的输出结果的学习方式。目前推荐的学习方式,适合有比较好数据源的场景 非监督学习:训练时只有输入,不知道结果的学习方式。各种数据不完善...

    企鹅号小编
  • tf21: 身份证识别——识别身份证号

    上一篇: 身份证识别——生成身份证号和汉字 代码直接参考,验证码识别 #!/usr/bin/env python2 # -*- coding: utf-8 -*...

    MachineLP
  • tf API 研读3:Building Graphs

    tensorflow是通过计算图的方式建立网络。 比喻说明: 结构:计算图建立的只是一个网络框架。编程时框架中不会出现任何的实际值,所有权重(weight)和偏...

    MachineLP
  • python实现最大似然函数与结果展示

    AI之禅
  • 多任务验证码识别

    使用Alexnet网络进行训练,多任务学习:验证码是根据随机字符生成一幅图片,然后在图片中加入干扰象素,用户必须手动填入,防止有人利用机器人自动批量注册、灌水、...

    瓜大三哥

扫码关注云+社区

领取腾讯云代金券