以github上yscbm的代码为例进行讲解,代码链接:https://github.com/yscbm/tensorflow/blob/master/common/extract_cifar10.py
首先导入必要的模块
import gzip
import numpy as np
import os
import tensorflow as tf
我们定义一些变量,因为针对的是cifar10数据集,所以变量的值都是固定的,为什么定义这些变量呢,因为变量的名字可以很直观的告诉我们这个数字的代表什么,试想如果代码里面全是些数字,我们会不会看糊涂了呢,我们知道cifar10数据集下载下来你会发现有data_batch_1.bin,data_batch_2.bin….data_batch_5.bin五个作为训练,test_batch.bin作为测试,每一个文件都是10000张图片,因此50000张用于训练,10000张用于测试
LABEL_SIZE = 1
IMAGE_SIZE = 32
NUM_CHANNELS = 3
PIXEL_DEPTH = 255
NUM_CLASSES = 10