TensorFlow入门1-minist

模型公式

                         y = softmax(Wx+b)

交叉熵loss函数,可以参考似然函数,y 是我们预测的概率分布, y' 是实际的分布(我们输入的one-hot vector)

loss = -Σy'ilogyi

import tensorflow.examples.tutorials.mnist.input_data
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
mnist = read_data_sets('MNIST_data/', one_hot=True)
import tensorflow as tf

# 这里的None表示此张量的第一个维度可以是任何长度的
# x不是一个特定的值,而是一个占位符placeholder
x = tf.placeholder(tf.float32, [None, 784])

# W参数矩阵784行,10列
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# x 是个noneX784矩阵, y是个noneX10
# 这里为什么要用xW,而不是Wx,因为矩阵+b向量运算,会将b向量每个元素加到xW每一列上
# softmax 按照行来计算,一行算出来正好是对应y
y = tf.nn.softmax(tf.matmul(x,W)+b)
y_ = tf.placeholder('float', [None, 10])

#计算交叉熵
cross_entroy = -tf.reduce_sum(y_*tf.log(y))

train_step = tf.train.GradientDescentOptimizer(
    0.01).minimize(cross_entroy)

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(200)
    sess.run(train_step, feed_dict = {x:batch_xs, y_:batch_ys})

tf.argmax函数用法

testArgmax=tf.argmax([[12,34,3],[78,21,45]],0)
init = tf.global_variables_initializer()
sess.run(init)
sess.run(testArgmax)

输出(第二个参数为0,取出每一列最大值的索引)

array([1, 0, 1])
testArgmax=tf.argmax([[12,34,3],[78,21,45]],1)
init = tf.global_variables_initializer()
sess.run(init)
sess.run(testArgmax)

输出(第二个参数为1,取出每一行最大值的索引)

array([1, 0])

取出每一行最大值索引与标准比较是否相等,[True,False...]

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

tf.cast函数将[T,F...]转化为[1,0....],tf.reduce_mean计算1占有多少

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

算出准确率大约0.9185,这里面还有很多东西可以调整,比如学习率、迭代次数、每次随机取出的样本个数,这个回头以后再来调参。

函数测试小例子

testArgmax1=tf.argmax([[12,34,3],[32,40,45]],0)
testArgmax2=tf.argmax([[12,34,3],[78,21,45]],0)
accurate = tf.equal(testArgmax1, testArgmax2)
testcast = tf.cast(accurate, 'float')
accuracy = tf.reduce_mean(testcast)
init = tf.global_variables_initializer()
sess.run(init)
sess.run([testcast,accuracy])

输出

[array([1., 0., 1.], dtype=float32), 0.6666667]

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏python读书笔记

《python算法教程》Day10 - 平面最近点对问题平面最小点对问题介绍代码演示

今天是《python算法教程》的第10篇读书笔记。笔记的主要内容是使用python实现求最小点对的时间复杂度为O(nlogn)的算法。 平面最小点对问题介绍 在...

705120
来自专栏mathor

matlab—基本操作与矩阵输入

还有一个月就美赛了,本系列文章适用于完全没有任何matlab基础,但是有别的编程语言基础的人看,我会结合自己的理解,有的放矢的讲,不会掺杂很多废话,各位读者轻喷...

12910
来自专栏杂七杂八

matlab中的函数介绍(max,min,unidrnd,norm)

遇到不知道的函数时,可以使用help 函数名来查看帮助 1 求矩阵A的最大值的函数有3种调用格式,分别是: max(A):返回一个行向量,向量的第i个元...

41250
来自专栏Python小屋

Python标准库random用法精要

random标准库主要提供了伪随机数生成函数和相关的类,同时也提供了SystemRandom类(也可以直接使用os.urandom()函数)来支持生成加密级别要...

31160
来自专栏数据结构与算法

27:单词翻转

27:单词翻转 总时间限制: 1000ms 内存限制: 65536kB描述 输入一个句子(一行),将句子中的每一个单词翻转后输出。 输入只有一行,为一个...

43570
来自专栏用户画像

特征处理

版权声明:本文为博主-姜兴琪原创文章,未经博主允许不得转载。 https://blog.csdn.net/jxq0816/article/details...

10520
来自专栏C语言及其他语言

【每日一题】

笨小猴的词汇量很小,所以每次做英语选择题的时候都很头疼。但是他找到了一种方法,经试验证明,用这种方法去选择选项的时候选对的几率非常大! 这种方法的具体描述如下:...

10920
来自专栏个人分享

旋转数组的最小数字

把一个数组最开始的若干个元素搬到数组的末尾,我们称之为数组的旋转。输入一个非递减序列的一个旋转,输出旋转数组的最小元素。例如数组{3,4,5,1,2}为{1,2...

9540
来自专栏李智的专栏

Deep learning基于theano的keras学习笔记(1)-Sequential模型

《统计学习方法》中指出,机器学习的三个要素是模型,策略和优算法,这当然也适用于深度学习,而我个人觉得keras训练也是基于这三个要素的,先建立深度模型,然后选用...

10310
来自专栏desperate633

LintCode 寻找缺失的数题目分析方法二 交换法

给出一个包含 0 .. N 中 N 个数的序列,找出0 .. N 中没有出现在序列中的那个数。

8430

扫码关注云+社区

领取腾讯云代金券