前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MNIST机器学习入门

MNIST机器学习入门

作者头像
py3study
发布2020-01-03 13:35:07
4580
发布2020-01-03 13:35:07
举报
文章被收录于专栏:python3python3

当我们开始学习编程的时候,第一件事往往是学习打印"Hello World"。就好比编程入门有Hello World,机器学习入门有MNIST。

MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片。它也包含每一张图片对应的标签,告诉我们这个是数字几。

详细内容请参考:http://wiki.jikexueyuan.com/p...

文章末尾会给出相关python代码,运行环境是python3.6+anaconda+tensorflow,具体环境搭建本文不做阐述。

一、MNIST简介

官网链接:http://yann.lecun.com/exdb/mn...

这个MNIST数据库是一个手写数字的数据库,它提供了六万的训练集和一万的测试集。

它的图片是被规范处理过的,是一张被放在中间部位的28px*28px的灰度图。

总共4个文件:

train-images-idx3-ubyte: training set images train-labels-idx1-ubyte: training set labels t10k-images-idx3-ubyte: test set images t10k-labels-idx1-ubyte: test set labels

图片都被转成二进制放到了文件里面,

所以,每一个文件头部几个字节都记录着这些图片的信息,然后才是储存的图片信息。

二、tensorflow手写数字识别步骤

1、 将要识别的图片转为灰度图,并且转化为28*28矩阵

2、 将28*28的矩阵转换成1维矩阵

3、 用一个1*10的向量代表标签,因为数字是0~9,如数字1对应的矩阵就是:[0,1,0,0,0,0,0,0,0,0]

4、 softmax回归预测图片是哪个数字的概率。

这里顺带说一下还有一个回归:logistic,因为这里我们表示的状态不只两种,因此需要使用softmax

三、标签介绍(有监督学习/无监督学习)

监督学习:利用一组已知类别的样本调整分类器的参数,使其达到所要求性能的过程,也称为监督训练或有教师学习举个例子,MNIST自带了训练图片和训练标签,每张图片都有一个对应的标签,比如这张图片是1,标签也就是1,用他们训练程序,之后程序也就能识别测试集中的图片了,比如给定一张2的图片,它能预测出他是2

无监督学习:其中很重要的一类叫聚类举个例子,如果MNIST中只有训练图片,没有标签,我们的程序能够根据图片的不同特征,将他们分类,但是并不知道他们具体是几,这个其实就是“聚类”

标签的表示

在这里标签的表示方式有些特殊,它也是使用了一个一维数组,而不是单纯的数字,上面也说了,他是一个一位数组,0表示方法[1,0,0,0,0,0,0,0,0,0],1表示[0,1,0,0,0,0,0,0,0,0],………,

主要原因其实是这样的,因为softmax回归处理后会生成一个1*10的数组,数组[0,0]的数字表示预测的这张图片是0的概率,[0,1]则表示这张图片表示是1的概率……以此类推,这个数组表示的就是这张图片是哪个数字的概率(已经归一化),

因此,实际上,概率最大的那个数字就是我们所预测的值。两者对应来看,标准的标签就是表示图片对应数字的概率为100%,而表示其它数字的概率为0,举个例子,0表示[1,0,0,0,0,0,0,0,0,0],可以理解为它表示0的概率为100%,而表示别的数字的概率为0.

softmax回归

这是一个分类器,可以认为是Logistic回归的扩展,Logistic大家应该都听说过,就是生物学上的S型曲线,它只能分两类,用0和1表示,这个用来表示答题对错之类只有两种状态的问题时足够了,但是像这里的MNIST要把它分成10类,就必须用softmax来进行分类了。

P(y=0)=p0,P(y=1)=p1,p(y=2)=p2……P(y=9)=p9.这些表示预测为数字i的概率,(跟上面标签的格式正好对应起来了),它们的和为1,即 ∑(pi)=1。

tensorflow实现了这个函数,我们直接调用这个softmax函数即可,对于原理,可以参考下面的引文,这里只说一下我们这个MNIST demo要用softmax做什么。

(注:每一个神经元都可以接收来自网络中其他神经元的一个或多个输入信号,神经元与神经元之间都对应着连接权值,所有的输入加权和决定该神经元是处于激活还是抑制状态。感知器网络的输出只能取值0或1,不具备可导性。而基于敏感度的训练算法要求其输出函数必须处处可导,于是引入了常见的S型可导函数,即在每个神经元的输出之前先经过S型激活函数的处理。)

交叉熵

通俗一点就是,方差大家都知道吧,用它可以衡量预测值和实际值的相差程度,交叉熵其实也是一样的作用,那为什么不用方差呢,因为看sigmoid函数的图像就会发现,它的两侧几乎就是平的,导致它的方差在大部分情况下很小,这样在训练参数的时候收敛地就会很慢,交叉熵就是用来解决这个问题的,它的公式是 −∑y′log(y) ,其中,y是我们预测的概率分布,y’是实际的分布。

梯度下降

上面那步也说了,有个交叉熵,根据大伙对方差的理解,值越小,自然就越好,因此我们也要训练使得交叉熵最小的参数,这里梯度下降法就派上用场了,这个解释见上一篇系列文章吧,什么叫训练参数呢,可以想象一下,我们先用实际的值在二位坐标上画一条线,然后我们希望我们预测出来的那些值要尽可能地贴近这条线,我们假设生成我们这条线的公式ax+ax^2+bx^3+…..,我们需要生成这些系数,要求得这些系数,我们就需要各种点代入,然后才能求出,所以其实训练参数跟求参数是个类似的过程。

预测

训练结束以后我们就可以用这个模型去预测新的图片了,大概意思就是输入对应的值就能获取相应的结果。

代码如下:

代码语言:javascript
复制
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import tensorflow.examples.tutorials.mnist.input_data
"""MNIST机器学习"""
"""下载并读取数据源"""
mnist = read_data_sets("MNIST_data/", one_hot=True)
"""x是784维占位符,基础图像28x28=784"""
x = tf.placeholder(tf.float32,[None,784])
"""W表示证据值向量,因为总数据量为0~9,每一维对应不同的数字,因此为10"""
W = tf.Variable(tf.zeros([784,10]))
"""b同理,为10"""
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W)+b)
"""初始化交叉熵的值"""
y_ = tf.placeholder("float",[None,10])
"""计算交叉熵"""
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
"""梯度下降算法"""
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
"""开始训练模型"""
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys})
"""评估模型"""
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
"""正确率"""
k = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print(k)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019/09/25 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档