python mnist数据导入以及处理

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/haluoluo211/article/details/81042529

在使用机器学习以及深度学习的时,常用的示例是使用mnist数据进行分类,本文简要的实现下mnis数据的导入以及处理,问题来源stackoverflow

直接上代码了,注释很清楚了:

import cPickle
import gzip
import numpy as np
import matplotlib.pyplot as plt

def load_data():
    path = '../../data/mnist.pkl.gz'
    f = gzip.open(path, 'rb')
    training_data, validation_data, test_data = cPickle.load(f)
    f.close()

    X_train, y_train = training_data[0], training_data[1]
    print X_train.shape, y_train.shape
    # (50000L, 784L) (50000L,)

    # get the first image and it's label
    img1_arr, img1_label = X_train[0], y_train[0]
    print img1_arr.shape, img1_label
    # (784L,) , 5

    # reshape first image(1 D vector) to 2D dimension image
    img1_2d = np.reshape(img1_arr, (28, 28))
    # show it
    plt.subplot(111)
    plt.imshow(img1_2d, cmap=plt.get_cmap('gray'))
    plt.show()

输出结果如下:

对label进行向量化:

def vectorized_result(label):
    e = np.zeros((10, 1))
    e[label] = 1.0
    return e

print vectorized_result(img1_label)
# output as below:
[[ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 1.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]]

我们也可以使用简单的for循环来将上述的784为输入向量转化为28*28维向量给CNN使用:

def load_data_v2():
    path = '../../data/mnist.pkl.gz'
    f = gzip.open(path, 'rb')
    training_data, validation_data, test_data = cPickle.load(f)
    f.close()

    X_train, y_train = training_data[0], training_data[1]
    print X_train.shape, y_train.shape
    # (50000L, 784L) (50000L,)

    X_train = np.array([np.reshape(item, (28, 28)) for item in X_train])
    y_train = np.array([vectorized_result(item) for item in y_train])

    print X_train.shape, y_train.shape
    # (50000L, 28L, 28L) (50000L, 10L, 1L)

来源自己的stack overflow回答。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能头条

实战Google深度学习框架:TensorFlow计算加速

2965
来自专栏人工智能LeadAI

YOLO:实时目标检测

一瞥(You Only Look Once, YOLO),是检测Pascal VOC(http://host.robots.ox.ac.uk:8080/pasc...

9977
来自专栏ml

mxnet运行时遇到问题及解决方法

1.训练好模型之后,进行预测时出现这种错误: 1 mxnet.base.MXNetError: [15:05:50] src/ndarray/ndarray.c...

5824
来自专栏北京马哥教育

python数据可视化----绘制各种图形

1.环境 系统:windows10 python版本:python3.6.1 使用的库:matplotlib,numpy 2.numpy库产生随机数几种方法 n...

4677
来自专栏GAN&CV

迁移学习之--tensorflow选择性加载权重

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/d...

1904
来自专栏数据小魔方

时间管理的工具——甘特图(Gantt chart)

今天跟大家分享一种用作时间管理的工具——甘特图(Gantt Chart)。 ▽▼▽ 这种图表的制作理念非常简单,就是通过设定项目开始时间和持续时间,利用堆积条形...

6497
来自专栏数据小魔方

饼图的两个变体——双饼图、饼柱图

今天给大家讲解图表中饼图的两个变体——双饼图、饼柱图 饼图的两个变体 ▽ 一 双饼图 通常如果一个数据系列要做对比 数据量较少并且数据之间差异不大的话还好 但是...

3404
来自专栏WOLFRAM

三维图形绘制指定区域的方法

1483
来自专栏深度学习与数据挖掘实战

【深度学习】资料整理

1002
来自专栏机器人网

深度学习三要素:数据、模型、计算

数据来源:主要通过对初始数据图片进行人工标注和机器标注。数据样本非常的重要,好的样本等于成功了一半。

912

扫码关注云+社区

领取腾讯云代金券