前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR-10 数据集介绍

CIFAR-10 数据集介绍

作者头像
用户6021899
发布2019-12-10 16:40:51
5.7K0
发布2019-12-10 16:40:51
举报

CIFAR-10和CIFAR-100是带有标签的数据集,都出自于规模更大的一个数据集,它有八千万张小图片(http://groups.csail.mit.edu/vision/TinyImages/)。CIFAR-10和CIFAR-100的共同主页是:http://www.cs.toronto.edu/~kriz/cifar.html

CIFAR-10数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。

下面这幅图就是列举了这10个分类,每一类展示了随机的10张图片:

该数据集有有如下三种版本:

python版本下载并解压后包含以下文件:

其中的html文件是数据集的官网网页。其它文件均是由cPickle库产生的序列化后的对象。各个data_batch_x 和test_batch 都是包含以下元素的字典:

  • data

类型为uint8s 大小为10000x3072 的numpy 数组。数组的每一行代表一张32X32X3的彩色图片(32X32X3 =3072), 3072个元素中的前1024个元素代表红色通道的亮度值,中间的1024个元素代表绿色通道的亮度值,后面的1024个元素代表蓝色通道的亮度值。

  • label

10000个元素的列表。列表的每个元素范围为0~9的整数,代表10个分类。

batches.meta 包含一个字典对象:

  • label_names

它给出了0~9每个数字对应的分类名。例如label_names[0] == "airplane",label_names[1] == "automobile"等等。

下面的代码可以将CIFAR-10 数据集解析到Numpy数组

代码语言:javascript
复制
import numpy as np
from matplotlib import pyplot as plt
import os

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
    
def loadDataset(dataset_folder): #训练集
    dataset = np.zeros((10000*5, 3*32*32), dtype = np.int32) #训练集 先用0填充,每个元素都是4byte integer
    labels = np.zeros((10000*5), dtype =np.int32)
    for i in range(5):
        d = unpickle(os.path.join(dataset_folder,"data_batch_%d"%(i+1)))
        #每个文件含1万张图片的数据
        for j in range(len(d[b'labels'])):  #每张图片,shape: (3072,)
            dataset[ 10000*i+j] = d[b'data'][j]
            labels[ 10000*i+j] = d[b'labels'][j]
    reshaped = np.reshape(dataset,(10000*5, 3, 32, 32))
    #交换轴
    swapaxesed = np.swapaxes(reshaped,1,2)# 按图片,行,颜色通道 ,列 排列
    swapaxesed = np.swapaxes(swapaxesed,2,3) # 按图片,行,列,颜色通道 排列
    
    return swapaxesed, labels
    
def loadValidset(path): #验证集
    valid_dataset = np.zeros((10000, 3*32*32), dtype = np.int32) #训练集 先用0填充,每个元素都是4byte integer
    valid_labels = np.zeros((10000), dtype =np.int32)
    d = unpickle(path)
    #测试含1万张图片的数据
    for j in range(len(d[b'labels'])):  #每张图片,shape: (3072,)
        valid_dataset[j] = d[b'data'][j]
        valid_labels[j] = d[b'labels'][j]
    reshaped = np.reshape(valid_dataset,(10000, 3, 32, 32))
    #交换轴
    swapaxesed = np.swapaxes(reshaped,1,2)# 按图片,行,颜色通道 ,列 排列
    swapaxesed = np.swapaxes(swapaxesed,2,3) # 按图片,行,列,颜色通道 排列
    return swapaxesed, valid_labels
    
    
dataset_folder = r"E:\Python36\my tensorflow\DataSet\CIFAR-10\cifar-10-batches-py"
train_dataset, train_labels = loadDataset(dataset_folder)
valid_dataset, valid_labels = loadValidset(os.path.join(dataset_folder,"test_batch"))
#图片识别时才需要
label_names =unpickle(os.path.join(dataset_folder, "batches.meta"))
names = label_names[b'label_names']#分类结果字节字符串


if __name__ == "__main__":
    #测试数据集是否加载成功
    print(min(train_labels)) #标签编码从0开始
    for index in range(100,200):
        a = valid_dataset[index]
        print(a.shape)
        print(a)
        plt.imshow(a)
        print(valid_labels[index])
        plt.title(names[int(valid_labels[index])],fontsize =8)
        plt.xticks([]); plt.yticks([])
        plt.show()

下面我给出由matplotlib绘制的data_batch_1中的前50张图片和它们对应的label:

有些图片真的不怎么好辨认……

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-12-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档