前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >手写数字图像数据集MNIST

手写数字图像数据集MNIST

作者头像
龙哥
发布2022-03-09 15:09:26
7680
发布2022-03-09 15:09:26
举报
文章被收录于专栏:Python绿色通道Python绿色通道

注:本文选自机械工业出版社出版的《从零开始构建深度前馈神经网络(Python+TensorFlow 2.x)》一书,略有改动。经出版社授权刊登于此。

MNIST是经典的手写数字(handwritten digits)图像数据集。其中,训练数据集(training set,简称训练集)包含60 000个样本,测试数据集(test set,简称测试集)包含10 000个样本。

图1展示了MNIST训练集的前15个样本。每幅图像代表一个手写数字,每个方框下方的数字是这个图像对应的标签(label)。

一幅图像及其对应的标签构成了一个输入/输出对,例如,图1左上角的图像与其正下方的5构成了一个输入/输出对,我们把这个输入/输出对称为一个样本(sample/example)。输入通常由特征向量(feature vector)表示。例如,图1左上角的图像的原始数据是一个784维的特征向量。

图1 MNIST训练集的前15个样本

本章将训练一个k-NN模型,其输入是784维的特征向量,输出为相应标签的预测值,即,给定任意一个表示手写数字的784维向量,预测它是0~9中的哪一个。

使用TensorFlow加载MNIST

先来看一段示例代码:

%matplotlib inline

import matplotlib.pyplot as plt

from tensorflow import keras

def ds_imshow(im_data, im_label):

plt.figure(figsize=(10,10))

for i in range(len(im_data)):

plt.subplot(5,5,i+1)

plt.xticks([])

plt.yticks([])

plt.grid(False)

plt.imshow(im_data[i], cmap=plt.cm.binary)

plt.xlabel(im_label[i])

plt.show()

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

ds_imshow(x_train[:15].reshape((15,28,28)), y_train[:15])

上面的代码在导入必要的模块后定义了一个名为ds_imshow()的函数,然后加载数据集,最后将加载的数据作为参数并调用ds_imshow()函数显示图像。

其中,%matplotlib inline需要在新建Notebook后且首次调用plt.show()之前运行,仅需运行一次即可作用于整个Notebook。

ds_imshow()函数将传入的NumPy数组显示为图像,参数im_data用于接收图像数组,每幅图像表示一个样本特征,im_label是与之对应的标签。

keras.datasets.mnist.load_data()方法用于加载数据集,首次运行时需要用十几秒到几分钟的时间进行远程下载,再次使用时将从本地加载。

数组x_train表示训练集中60 000个像素为28×28的手写数字灰度图像,y_train表示与之对应的标签集合;x_test表示测试集中10 000个像素为28×28的手写数字灰度图像,y_test表示与之对应的标签集合。

示例中的最后一行代码是调用ds_imshow()函数将训练集中的前15个样本绘制为图像,并在每幅图像的正下方显示与之对应的标签。例如,y_train[0]为5,表示与之对应的x_train[0]是手写数字5的灰度图像,即位于图2左上角的样本。

Keras默认是将数据集文件(mnist.npz)存储在用户家目录下的.keras\datasets中。在Windows运行窗口中输入以下命令,如图2所示。

%HOMEPATH%\.keras\datasets

回车或单击OK按钮即可以查看该目录。

使用scikit-learn加载MNIST

与keras.datasets.mnist.load_data()方法类似,scikit-learn也提供了加载MNIST数据集的方法,通过以下代码可以导入datasets模块。

from sklearn import datasets

以下两行代码用于加载MNIST数据集,并将数据集中的前15个样本绘制为图像:

mnist = datasets.load_digits()

ds_imshow(mnist.data[:15].reshape((15,8,8)), mnist.target[:15])

程序运行结果如图3所示。

细心的读者可能已经发现了一个问题,MNIST的每个样本的像素是28×28,而代码中却将每个样本调整为(8,8)。这是因为datasets.load_digits()加载的样本像素并非是28×28,而是8×8,所以图像显得比较模糊。

尽管scikit-learn也提供了以下方法用于加载28×28像素版本的MNIST:

from sklearn.datasets import fetch_openml

X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=

False)

但是偶尔会遭遇加载缓慢甚至失败。因此建议读者使用keras.datasets.mnist.load_data()方法加载28×28像素版本的MNIST。

图3 运行结果

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-01-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python绿色通道 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 使用TensorFlow加载MNIST
  • 使用scikit-learn加载MNIST
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档