MNIST__数字识别__SOFTMAX

本次MNIST的手写数字识别未采用input_data.py文件,想尝试一下用原始的数据集来运行这个DEMO。

需要注意的一点是,源码中的图片标签采用的的ONE-HOT编码,而数据集中的标签用的是具体的数字。

例如:图片上的数字和标签的值是5,其对应的ONT-HOT编码为[0,0,0,0,0,1,0,0,0,0](分别对应数值【0,1,2,3,4,5,6,7,8,9】) ,也就是长度为10的一维数组的第6个元素为1,其余的全为0。

源码结构:

1.读取MNIST

2.创建占位符(用读取的数据填充这些空占位符)

3.选用交叉熵作为损失函数

4.使用梯度下降法(步长0.02),来使损失函数最小

5.初始化变量

6.开始计算

7.输出识别率

源码:

import tensorflow as tf
import numpy as np
import struct
#  解析IDX文件格式的MNIST数据集,需要用struct模块对二进制文件进行读取操作
#  struct模块中最重要的三个函数是pack() , unpack() 和calcsize()
#  calculate 英 [ˈkælkjuleɪt]  vt.计算  

#  按照给定的格式(fmt)解析字节流string,返回解析出来的tuple
#  tuple = unpack(fmt, string)
#  format  英 [ˈfɔ:mæt] 格式;使格式化 (format在代码中简化为fmt)
#  tuple 英 [tʌpl] 美 [tʌpl]   n.元组,数组

#  按照给定的格式化字符串,把数据封装成字符串(实际上是类似于c结构体的字节流)
#  string = struct.pack(fmt, v1, v2, ...)

#  计算给定的格式(fmt)占用多少字节的内存
#  offset = calcsize(fmt)

import matplotlib.pyplot as plt
#  matplotlib.pyplot是一个命令型函数集合,功能齐全的绘图模块

#------------------------------------ 1 ------------------------------------------
def images_load(filename):
    #   def image_load(filename)用于读取图片数据
    #   file_name表示要访问的文件名 
    
    with open(filename, 'rb') as contents:
    #rb表示该文件以只读方式打开,使用with open()as 的好处在于:读取文件内容后会
    #自动关闭文件,无需手动关闭。
    
        data_buffers = contents.read()
    #   从一个打开的文件读取数据 
    #   buffer   英 [ˈbʌfə(r)]  
    #   n. 缓冲器; 起缓冲作用的人或物; [化] 缓冲液,缓冲剂; [计] 缓冲区
    #   vt. 缓冲       (个人感觉,了解单词的意思,代码会变的亲切一些)      
   
        magic,num,rows,cols = struct.unpack_from('>IIII',data_buffers, 0)
    #   读取图片文件前4个整型数字  
    
        bits = num * rows * cols
    #  整个images数据大小为60000*28*28        
    
        images = struct.unpack_from('>' + str(bits) + 'B', data_buffers, struct.calcsize('>IIII'))
    #   读取images数据
    
        images = np.reshape(images, [num, rows * cols])
    #   转换为[60000,784]型数组
    
    return images
#--------------------------------------- 2 --------------------------------------
def labels_load(filename):
    
    contants = open(filename, 'rb')
    #这里用open()打开文件,读取结束后要用close()关闭    
    
    data_buffers = contants.read()
    
    magic,num = struct.unpack_from('>II', data_buffers, 0) 
    #   读取label文件前2个整形数字,label的长度为num
    #   magic翻译成“魔数”,用于校验下载的文件是否属于MNIST数据集
   
    labels = struct.unpack_from('>' + str(num) + "B", data_buffers, struct.calcsize('>II'))
    #  读取labels数据
        
    contants.close()
    #  关闭文件    
    
    labels = np.reshape(labels, [num])
    #  转换为一维数组
    
    return labels   
#---------------------------------------- 3 ----------------------------------------
#读取训练和测试文件
filename_train_images = 'E:\\MNIST\\train-images.idx3-ubyte'
filename_train_labels = 'E:\\MNIST\\train-labels.idx1-ubyte'
filename_test_images = 'E:\\MNIST\\t10k-images.idx3-ubyte'
filename_test_labels = 'E:\\MNIST\\t10k-labels.idx1-ubyte'
train_images=images_load(filename_train_images)
train_labels=labels_load(filename_train_labels)
test_images=images_load(filename_test_images)
test_labels=labels_load(filename_test_labels)

#------------------------------------- 4 ------------------------------------------

x = tf.placeholder("float", [None, 784]) #输入占位符(每张手写数字有28X28个像素点)
y_ = tf.placeholder("float", [None,10]) #输入占位符(用one-hot编码表示标签的值)

w = tf.Variable(tf.zeros([784,10])) #权重
b = tf.Variable(tf.zeros([10])) #偏置
y = tf.nn.softmax(tf.matmul(x,w) + b) 
# 输入矩阵x与权重矩阵w相乘,加上偏置矩阵b,然后求softmax(sigmoid函数升级版,可以分成多类)
# softmax会将xW+b分成10类,对应数字0-9

cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# 计算交叉熵

train_step = tf.train.GradientDescentOptimizer(0.02).minimize(cross_entropy)
# 使用梯度下降法(步长0.02),来使偏差和最小

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 初始化变量

def train_num(n_t):
    xst=train_images[:n_t,:]  
    zst=train_labels[:n_t]  
    yst=np.zeros((n_t,10))
    for i in range(0,n_t-1):
        yst[i][zst[i]]=1
    return xst,yst
#训练图片的数量,标签转换为ONE-HOT编码

def test_num(n_t):
    xst=test_images[:n_t,:]  
    zst=test_labels[:n_t]  
    yst=np.zeros((n_t,10))
    for i in range(0,n_t-1):
        yst[i][zst[i]]=1
    return xst,yst
#测试图片的数量,标签转换为ONE-HOT编码
#======================================= 5 ===========================
xs_t,ys_t=test_num(10000)  
#测试图片10000张

xs,ys=train_num(1300)
#用1300张图片进行训练

sess.run(train_step, feed_dict={x:xs,y_:ys })  
correct_prediction_1 = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy_1 = tf.reduce_mean(tf.cast(correct_prediction_1, "float"))
# 计算训练精度  

print(sess.run(accuracy_1, feed_dict={x: xs_t, y_: ys_t})) 
#输出识别的准确率

#=========================================================

print('GOOD WORK')
#  点个赞
0.7147
GOOD WORK
#运行结果 0.7147,看起来很糟糕……

将训练数据的值由1300提高到60000,结果是0.6803,居然降低了。好吧,总感觉哪里不太对,可又说不上来~

参考资料:

ONE-HOT使用体会 : https://blog.csdn.net/lanhaier0591/article/details/78702558

训练Tensorflow识别手写数字 : https://www.cnblogs.com/tengge/p/6363586.html

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器学习实践二三事

Tensorflow实现word2vec

大名鼎鼎的word2vec,相关原理就不讲了,已经有很多篇优秀的博客分析这个了. 如果要看背后的数学原理的话,可以看看这个: https://wenku.b...

55570
来自专栏null的专栏

挑战数据结构和算法面试题——最大间隔

题目来自伯乐在线,欢迎有不同答案的同学来一起讨论。 ? 分析: 本题首先需要理解清楚最大间隔的最小: 最初的间隔为:[1,1,4,1],此时最大间隔为4 删...

31430
来自专栏简书专栏

基于tensorflow+CNN的搜狐新闻文本分类

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 CNN是convolutional neural netwo...

32520
来自专栏后端技术探索

一致性hash算法清晰详解!

consistent hashing 算法早在 1997 年就在论文 Consistent hashing and random trees 中被提出,目前在 ...

11420
来自专栏冷冷

利用iText 组件导出PDF

maven依赖:       <dependency>    <groupId>com.itextpdf</groupId>    <artifactId>...

28350
来自专栏liuchengxu

详解 MNIST 数据集

MNIST 数据集已经是一个被"嚼烂"了的数据集, 很多教程都会对它"下手", 几乎成为一个 "典范". 不过有些人可能对它还不是很了解, 下面来介绍一下.

20020
来自专栏小小挖掘机

TensorFlow 和 NumPy 的 Broadcasting 机制探秘

在使用Tensorflow的过程中,我们经常遇到数组形状不同的情况,但有时候发现二者还能进行加减乘除的运算,在这背后,其实是Tensorflow的broadca...

13520
来自专栏智能算法

深度学习三人行(第2期)---- TensorFlow爱之再体验

上一期,我们一起学习了TensorFlow的基础知识,以及其在线性回归上的初体验,该期我们继续学习TensorFlow方面的相关知识。学习的路上,我们多多交流,...

365100
来自专栏小樱的经验随笔

qsc oj 22 哗啦啦村的刁难(3)(随机数,神题)

哗啦啦村的刁难(3) 发布时间: 2017年2月28日 20:00   最后更新: 2017年2月28日 20:01   时间限制: 1000ms   内存限制...

29690
来自专栏Petrichor的专栏

leetcode: 36. Valid Sudoku

18830

扫码关注云+社区

领取腾讯云代金券