专栏首页锦小年的博客MNIST数据集的格式转换

MNIST数据集的格式转换

以前直接用的是sklearn或者TensorFlow提供的mnist数据集,已经转换为矩阵形式的数据格式。但是sklearn体用的数据集合并不全,一共只有3000+图,每个图是8*8的大小,但是原始数据并不是这样的。 MNIST数据集合的原始网址为:http://yann.lecun.com/exdb/mnist/ 进入官网,发现有4个文件,分别对应训练集、测试集的图像和标签:

官网给的数据集合并不是原始的图像数据格式,而是编码后的二进制格式: 图像的编码为:

典型的head+data模式:前16个字节分为4个整型数据,每个4字节,分别代表:数据信息des、图像数量(img_num),图像行数(row)、图像列数(col),之后的数据全部为像素,每row*col个像素构成一张图,每个色素的值为(0-255)。 标签的编码为:

模式和前面的一样,不同的是head只有8字节,分别为des和标签的数量(label_num).之后每一个字节代表一个标签,值为(0-9)。 弄清楚编码后,就可以直接上代码了:

import numpy as np
import struct
mnist_dir = r'./digit/'
def fetch_mnist(mnist_dir,data_type):
    train_data_path = mnist_dir + 'train-images.idx3-ubyte'
    train_label_path = mnist_dir + 'train-labels.idx1-ubyte'
    test_data_path = mnist_dir + 't10k-images.idx3-ubyte'
    test_label_path = mnist_dir + 't10k-labels.idx1-ubyte'

    # train_img
    with open(train_data_path, 'rb') as f:
        data = f.read(16)
        des,img_nums,row,col = struct.unpack_from('>IIII', data, 0)
        train_x = np.zeros((img_nums, row*col))
        for index in range(img_nums):
            data = f.read(784)
            if len(data) == 784:
                train_x[index,:] = np.array(struct.unpack_from('>' + 'B' * (row * col), data, 0)).reshape(1,784)
        f.close()
    # train label
    with open(train_label_path, 'rb') as f:
        data = f.read(8)
        des,label_nums = struct.unpack_from('>II', data, 0)
        train_y = np.zeros((label_nums, 1))
        for index in range(label_nums):
            data = f.read(1)
            train_y[index,:] = np.array(struct.unpack_from('>B', data, 0)).reshape(1,1)
        f.close()

        # test_img
        with open(test_data_path, 'rb') as f:
            data = f.read(16)
            des, img_nums, row, col = struct.unpack_from('>IIII', data, 0)
            test_x = np.zeros((img_nums, row * col))
            for index in range(img_nums):
                data = f.read(784)
                if len(data) == 784:
                    test_x[index, :] = np.array(struct.unpack_from('>' + 'B' * (row * col), data, 0)).reshape(1, 784)
            f.close()
        # test label
        with open(test_label_path, 'rb') as f:
            data = f.read(8)
            des, label_nums = struct.unpack_from('>II', data, 0)
            test_y = np.zeros((label_nums, 1))
            for index in range(label_nums):
                data = f.read(1)
                test_y[index, :] = np.array(struct.unpack_from('>B', data, 0)).reshape(1, 1)
            f.close()
        if data_type == 'train':
            return train_x, train_y
        elif data_type == 'test':
            return test_x, test_y
        elif data_type == 'all':
            return train_x, train_y,test_x, test_y
        else:
            print('type error')

if __name__ == '__main__':
    tr_x, tr_y, te_x, te_y = fetch_mnist(mnist_dir,'all')
    import matplotlib.pyplot as plt # plt 用于显示图片
    img_0 = tr_x[59999,:].reshape(28,28)
    plt.imshow(img_0)
    print(tr_y[59999,:])
    img_1 = te_x[500,:].reshape(28,28)
    plt.imshow(img_1)
    print(te_y[500,:])
    plt.show()

运行结果:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • TensorFlow v2.x使用说明[2]-模型构建

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    锦小年
  • python学习笔记6.1-类的认识

    前言 都在说,在Python中“无所不对象”,是因为Python是一个面向对象的动态编程语言,因而它具有:多态(Polymorphism),继承(Inherit...

    锦小年
  • python学习笔记4.2-python高级之迭代器

    迭代是Python中最强有力的特性之一,同时对编程人员来说,也是最难理解的一种用法。其实从高层次来看,迭代就是一种处理序列中元素的方式。通过自定义迭代对象可以...

    锦小年
  • 深度学习之二分类问题

    IMDB 数据集包含来自互联网电影数据库(IMDB)的 50 000 条严重两极分化的评论。

    李小白是一只喵
  • 天池-OGeek算法挑战赛Baseline(0.7016)

    此baseline是西安电子科技大学小幸运所提供,借此平台能够帮助更多新入手同学。

    Coggle数据科学
  • 重回机器学习-《python机器学习及实践》读书笔记一

            以前也算比较系统接触过机器学习吧,记得最早的时候是大二,机器学习才刚开始提起,更多的是说统计学习。那个时候,深度学习似乎都还没有听过,看的第一本...

    钱塘小甲子
  • 卷积神经网络的Helloworld例子

    下面是keras官方的卷积神经网络在github上的例子。和原版的唯一区别是:mnist的数据因为在国外(由于翻墙的原因,报错Exception: URL fe...

    马克java社区
  • 搭建浅层神经网络"Hello world"

    下面我们会用 TensorFlow 搭建一个浅层的神经网络来运行 "hello world!" 模型。 以下内容和模块的运算,均在矩池云平台进行。

    矩池云
  • (数据科学学习手札17)线性判别分析的原理简介&Python与R实现

    之前数篇博客我们比较了几种具有代表性的聚类算法,但现实工作中,最多的问题是分类与定性预测,即通过基于已标注类型的数据的各显著特征值,通过大量样本训练出的模型,来...

    Feffery
  • 在lightgbm中使用交叉验证

    防止过拟合的参数: max_depth 树的深度,不要设置的太大; num_leaves 应该小于 2^(max_depth),否则可能会导致过拟合; m...

    生信编程日常

扫码关注云+社区

领取腾讯云代金券