专栏首页share ai happinessMNIST数据集介绍及计算

MNIST数据集介绍及计算

最近也是考试多,没来得及更新文章。废话不多说,理论讲太多没啥感觉,不清楚的可以翻到前面的文章仔细看看。

MNIST数据集

MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片,

其中每一张图片都代表0~9中的一个数字。

怎么通过输入数据经过神经网络参数传到最后的过程?

下载:

官方网站 http://yann.lecun.com/exdb/mnist/

一共4个文件,训练集、训练集标签、测试集、测试集标签

文件名称

大小

内容

train-images-idx3-ubyte.gz

9,681 kb

55000张训练集,5000张验证集

train-labels-idx1-ubyte.gz

29 kb

训练集图片对应的标签

t10k-images-idx3-ubyte.gz

1,611 kb

10000张测试集

t10k-labels-idx1-ubyte.gz

5 kb

测试集图片对应的标签

导入Mnist数据集

MNIST数据集在机器学习领域非常常用的,一般拿出一个模型都会在这里进行验证,所以说TensorFlow想让用户方便实验,本身就集成了这个数据集,不用额外的去下载。

创建一个mnist_cs.py文件。

怎么导入mnist数据集

# 从tensorflow里面加载MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data

# 载入MNIST数据集,如果之前没有下载过,则会自动下载到相应路径
mnist = input_data.read_data_sets(‘/path/MNIST_data/’,
         one_hot=True)
# 打印 Training data size: 55000,将60000数据分成训练集和验证集
print (‘training_data_size:’, mnist.train.num_examples)

# 打印 Example training data: [0. 0. 0. … 0.380 0.376 … 0.]
print (‘Example training data:’, mnist.train.images[0])

打印出来是784维的向量。784是一个28*28矩阵,把它向量化了,一行一行拼在一起。

设置神经网络结构相关的参数

#输入层的节点数。对于MNIST数据集,这个等于图片的总像素=28*28
INPUT_NODE = 784
#输出层的节点数。在MNIST数据集中有0~9这10个数字类别
OUTPUT_NODE = 10 
#神经网络隐藏节点数,这个是自己定的
LAYER1_NODE = 500

定义获取变量函数

可以对参数进行约束,regularizer把这个参数也作为最后的损失函数losses是所有约束项的集合,最后跟目标函数一起优化。

def get_weight_variable(shape, regularizer):
    # shape是变量的大小,regularizer是正则化函数。
    # tf.truncated_normal_initializer是正态分布初始化函数
   weights = tf.get_variable(“weights”, shape, initializer=          
             tf.truncated_normal_initializer(stddev=0.1))
#tf.add_to_collections 将当前变量的正则化损失加入名字为losses的集合
if regularizer != None:
  tf.add_to_collections(‘losses’,regularizer(weights))
return weights

定义神经网络向前传播过程

第一层前馈计算

def inference(input_tensor, regularizer):
   #声明第一层神经网络的命名空间’layer1’及相关变量,并完成前向传播过程 
    with tf.variable_scope(‘layer1’): 
     weights = get_weight_variable(
      [INPUT_NODE,LAYER1_NODE],regularizer)
     biases = tf.get_variable(“biases”,[LAYER1_NODE],
              initializer=tf.constant_initializer(0.0))
     output1 = tf.matmul(input_tensor, weights)+biases
     layer1 = tf.nn.relu(ouput1) #使用relu激活函数  
   #声明第二层神经网络的名命空间’layer2’及相关变量,并完成前向传播过程

input_tensor输入图片

def inference(input_tensor, regularizer):
   #声明第一层神经网络的名命空间’layer1’及相关变量,并完成前向传播过程
     +
   #声明第二层神经网络的名命空间’layer2’及相关变量,并完成前向传播过程
   with tf.variable_scope(‘layer2’): 
     weights = get_weight_variable(
        [LAYER1_NODE, OUTPUT_NODE], regularizer)
     biases = tf.get_variable(“biases”,[OUTPUT_NODE],
              initializer=tf.constant_initializer(0.0))
     layer2 = tf.matmul(layer1, weights) + biases
     #返回前向传播的结果
   return layer2  

最后一层不用激活,直接返回,是为了归一化,放到交叉熵损失函数里。

本文分享自微信公众号 - 1001次重燃(smile765999),作者:木野归郎

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-11-08

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Spark 在大数据中的地位 - 中级教程

    Spark最初由美国加州伯克利大学的AMP实验室于2009年开发,是基于内存计算的大数据并行计算框架,可用于构建大型的、低延迟的数据分析应用程序。

    木野归郎
  • Elasticsearch 加班不睡觉(一)

    在实际MySQL业务中,一般会先验证sql有没有问题,如果没有问题,再写业务代码。实际ES业务中,也一样,先DSL确认没有问题,再写业务代码。

    木野归郎
  • 项目实战中Hive注释乱码解决方案

    下面这些都是我在工作中总结出来的,希望对大家有帮助,如果有其他的问题或者解决方法可以留言给我。

    木野归郎
  • AOV网络拓扑排序

    这个算法,主要是为输出一个无环图的拓扑序列 算法思想: 主要依赖一个栈,用来存放没有入度的节点,每次读取栈顶元素,并将栈顶元素的后继节点入度减一,如果再次出现入...

    用户1154259
  • MyBatis设计思想(4)——缓存模块

    相信大家对于缓存都不陌生,MyBatis也提供了缓存的功能,在执行查询语句时首先尝试从缓存获取,避免频繁与数据库交互,大大提升了查询效率。MyBatis有所谓的...

    张申傲
  • Hbase Region Split compaction 过程分析以及调优

    Hbase以高并发写入而闻名,而Compact和Split功能贯穿了hbase的整个写入过程,而只有掌握了Compact和Split内部逻辑以及控制参数才能根据...

    liubang01
  • 到底哪种类型的错误信息会阻止business transaction的保存

    当试图在CRM WebUI保存一个business transaction比如Opportunity时,可能会遇到各种各样的错误消息。有的错误消息会阻止Busi...

    Jerry Wang
  • 一、简单使用二、 并行循环的中断和跳出三、并行循环中为数组/集合添加项四、返回集合运算结果/含有局部变量的并行循环五、PLinq(Linq的并行计算)

    沿用微软的写法,System.Threading.Tasks.::.Parallel类,提供对并行循环和区域的支持。 我们会用到的方法有For,...

    vv彭
  • 4.0中的并行计算和多线程详解(一)

    转自:https://www.cnblogs.com/sorex/archive/2010/09/16/1828214.html

    vv彭
  • MyBatis-Plus长文图解笔记

    DROP TABLE IF EXISTS user; CREATE TABLE user ( id BIGINT(20) NOT NULL COMMENT...

    崔笑颜

扫码关注云+社区

领取腾讯云代金券