MNIST数据集的格式转换

以前直接用的是sklearn或者TensorFlow提供的mnist数据集,已经转换为矩阵形式的数据格式。但是sklearn体用的数据集合并不全,一共只有3000+图,每个图是8*8的大小,但是原始数据并不是这样的。 MNIST数据集合的原始网址为:http://yann.lecun.com/exdb/mnist/ 进入官网,发现有4个文件,分别对应训练集、测试集的图像和标签:

官网给的数据集合并不是原始的图像数据格式,而是编码后的二进制格式: 图像的编码为:

典型的head+data模式:前16个字节分为4个整型数据,每个4字节,分别代表:数据信息des、图像数量(img_num),图像行数(row)、图像列数(col),之后的数据全部为像素,每row*col个像素构成一张图,每个色素的值为(0-255)。 标签的编码为:

模式和前面的一样,不同的是head只有8字节,分别为des和标签的数量(label_num).之后每一个字节代表一个标签,值为(0-9)。 弄清楚编码后,就可以直接上代码了:

import numpy as np
import struct
mnist_dir = r'./digit/'
def fetch_mnist(mnist_dir,data_type):
    train_data_path = mnist_dir + 'train-images.idx3-ubyte'
    train_label_path = mnist_dir + 'train-labels.idx1-ubyte'
    test_data_path = mnist_dir + 't10k-images.idx3-ubyte'
    test_label_path = mnist_dir + 't10k-labels.idx1-ubyte'

    # train_img
    with open(train_data_path, 'rb') as f:
        data = f.read(16)
        des,img_nums,row,col = struct.unpack_from('>IIII', data, 0)
        train_x = np.zeros((img_nums, row*col))
        for index in range(img_nums):
            data = f.read(784)
            if len(data) == 784:
                train_x[index,:] = np.array(struct.unpack_from('>' + 'B' * (row * col), data, 0)).reshape(1,784)
        f.close()
    # train label
    with open(train_label_path, 'rb') as f:
        data = f.read(8)
        des,label_nums = struct.unpack_from('>II', data, 0)
        train_y = np.zeros((label_nums, 1))
        for index in range(label_nums):
            data = f.read(1)
            train_y[index,:] = np.array(struct.unpack_from('>B', data, 0)).reshape(1,1)
        f.close()

        # test_img
        with open(test_data_path, 'rb') as f:
            data = f.read(16)
            des, img_nums, row, col = struct.unpack_from('>IIII', data, 0)
            test_x = np.zeros((img_nums, row * col))
            for index in range(img_nums):
                data = f.read(784)
                if len(data) == 784:
                    test_x[index, :] = np.array(struct.unpack_from('>' + 'B' * (row * col), data, 0)).reshape(1, 784)
            f.close()
        # test label
        with open(test_label_path, 'rb') as f:
            data = f.read(8)
            des, label_nums = struct.unpack_from('>II', data, 0)
            test_y = np.zeros((label_nums, 1))
            for index in range(label_nums):
                data = f.read(1)
                test_y[index, :] = np.array(struct.unpack_from('>B', data, 0)).reshape(1, 1)
            f.close()
        if data_type == 'train':
            return train_x, train_y
        elif data_type == 'test':
            return test_x, test_y
        elif data_type == 'all':
            return train_x, train_y,test_x, test_y
        else:
            print('type error')

if __name__ == '__main__':
    tr_x, tr_y, te_x, te_y = fetch_mnist(mnist_dir,'all')
    import matplotlib.pyplot as plt # plt 用于显示图片
    img_0 = tr_x[59999,:].reshape(28,28)
    plt.imshow(img_0)
    print(tr_y[59999,:])
    img_1 = te_x[500,:].reshape(28,28)
    plt.imshow(img_1)
    print(te_y[500,:])
    plt.show()

运行结果:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏https://www.cnblogs.com/L

【Caffe篇】--Caffe solver层从初始到应用

672
来自专栏小鹏的专栏

01 TensorFlow入门(1)

tensorflow_cookbook--第1章 TensorFlow入门         Google的TensorFlow引擎具有独特的解决问题的方法。 ...

21610
来自专栏Jack-Cui

Caffe学习笔记(三):cifar10_quick_train_test.prototxt配置文件分析

运行平台: Ubuntu14.04     在上篇笔记中,已经记录了如何进行图片数据格式的转换和生成txt列表清单文件。本篇笔记主要记录如何计算图片数据的均值和...

2528
来自专栏FreeBuf

学点算法搞安全之HMM(下篇)

前言 我们介绍了HMM的基本原理以及常见的基于参数的异常检测实现,这次我们换个思路,把机器当一个刚入行的白帽子,我们训练他学会XSS的攻击语法,然后再让机器从访...

1888
来自专栏IT派

【深度学习入门系列】TensorFlow训练线性回归

作者:董超 来源:腾讯云技术社区「腾云阁」 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能是要看源代码才能理...

3133
来自专栏企鹅号快讯

无人驾驶系列——深度学习笔记:Tensorflow基本概念

随着无人驾驶的火爆,深度学习在无人驾驶中的应用受到广泛关注,我在工作中对此有所接触,因此进行了相关学习和整理,给大家大家可以参考。 ? TensorFlow深度...

2166
来自专栏Jack-Cui

Caffe学习笔记(七):使用训练好的model做预测(mnist)

Python版本: Python2.7 运行平台: Ubuntu14.04 一、前言     在之前的笔记中,已经生成了训练好的mnist.cafffemod...

2735
来自专栏ATYUN订阅号

ChainerCV: 一个用于深度学习的计算机视觉库

ChainerCV是一个基于Chainer用于训练和运行计算机视觉任务的神经网络工具。它涵盖了计算机视觉模型的高质量实现,以及开展计算机视觉研究的必备工具集。 ...

3737
来自专栏深度学习那些事儿

深度学习图像分割(二)——如何制作自己的PASCAL-VOC2012数据集

在之前的那篇文章中:深度学习图像分割(一)——PASCAL-VOC2012数据集(vocdevkit、Vocbenchmark_release)详细介绍 我们大...

5254
来自专栏Petrichor的专栏

图像处理: 如何将 像素值 控制在 值域[0, 255]

在做计算机视觉方向项目的时候,往往需要进行图像处理。但是在此过程中,常常会遇到 对 像素值 进行 变换计算 后,像素值 超出 值域区间 [0, 255] 的情况...

814

扫码关注云+社区