首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

博客 CIFAR10 数据预处理

本系列文章已由作者授权在AI研习社发布。

欢迎关注我的AI研习社博客:

http://www.gair.link/page/center/myPage/5104751,

或订阅我的 CSDN:

https://blog.csdn.net/Kuo_Jun_Lin

Brief 概述

在上一章中我们使用了 MNIST 手写数字数据集,套入一个非常简单的线性模型中,得到了大约 90% 左右的正确率,用意在于熟悉神经网络节点的架构和框架的使用方法,接下来这章将把前一章的数据集和方法全面提升一个档次,使用的是 CIFAR10 与 CNN 卷积神经网络的架构,同时也可以做为探讨深层神经网络如 VGG19,GoogleNet,与 ResNet 的敲门砖。

CNN 卷积神经网络假设大家已经有一个大致的了解,它不像线性回归的方法,从每个像素着手发现归类到不同标签的规则,而是使用卷积核逐步扫描整张图片的方式抽取出图像特征,经过逐个卷积和在逐层维度上特征的抽离处理,最终把他们与全连阶层相连,通往标签的归类,但是说着简单,其实操作上还有许多细节需要注意如下面几点:

借鉴上一次的代码运行过程,首先第一件事就是减少「类」中函数的冗长定义,因为每次呼叫类的方法时,其实类中的内容都会被重新刷新一遍,多次反复下来就是一个有负担的计算量。

图片不再是简单的手写数字,CIFAR10 有背景与对应标签的图案,因此为了更好的训练,图片需要做预处理,随机的:旋转角度,灰阶度,对比度,图像尺寸调整,明暗度,色调,与裁切,都是可以尝试的手法。

下面我们将探讨一个数据集在多个维度的比较,并尝试出最好正确率的排列组合。

p.s. 卷积神经网络搭建开始之前必须先确定自身电脑内存是否 >= 8G,虽然这个网络在 CNN 算法中非常简单,但如果从最开始的神经网络加总,一共也会有几十万个参数的量,需要注意电脑是否能够承载。

Code[1]

import sys

import tensorflowastf

print(tf.__version__)

1.10.1

CIFAR10 Dataset

它是一个内涵六万张图片的数据集,比起 MNIST,它的通道数是三个,用来表示其彩色的画面,并且图像尺寸是 32*32,其中分成训练集五万张与测试集一万张,制作人在打包数据的时候分了几个文档如下图:

其内部排列方式为一个大的字典,图片数据对应 'data' 字典键,标签数据对应 'labels' 字典键,而单张图片数据排布方式为一个一维列表:[...1024 red ... ...1024 green... ...1024 blue...],读取的方式可以直接点击官网网址。

为了使自己能够更加熟悉数据集内部结构的解析,同时 CIFAR10 官网只告知了打开它们数据集的方法,我们需要如使用 MNIST 的情况一样开始自己定义我们所需要的函数,不外乎数据读取,数据格式转换,one_hot 等大类,步骤如下:

1. Define functions without being iterated with class

定义的函数分别在如下陈列:

time_counter(): 是一个装饰器,功能是用来计时一个函数启动的时间

one_hot(): 用来把标签转换成 one hot 形式,方便后面神经网络归类匹配使用

get_random_batch(): 随机抽取样本做为一个簇后,方便小批量训练

Code[2]

importtime

importnumpyasnp

# To set a decorator used to count the time a func spent.

deftime_counter(func):

# In order to count many func's time, arguments should be *args and **kwargs

defwrapper(*args, **kwargs):

t1 = time.time()

result = func(*args, **kwargs)

t2 = time.time() - t1

print('Took sec to run "" func'.format(t2, func.__name__))

returnresult

returnwrapper

# To convert the number labels into one hot mode respectively.

defone_hot(labels, class_num=10):

convert = np.eye(class_num, dtype=float)[labels]

returnconvert

# To get a random batch so that we can easily put data to train a model.

defget_random_batch(data, batch_size=32):

random = np.random.randint(, len(data), size=batch_size)

returndata[random]

2. Define a class used to well organized take apart the dataset

由于数据是呈现 5 个批次储存,其中的函数设定我希望把他们融合成一块,后面处理和调用也表方便,并且其图片大小为 32x32 的尺寸,并不至于大到没办法一次容纳,因此设置的函数方法如下陈列:

load_binary_data(): 把二进制数据读取出来,并依照字典键的要求给出一个 numpy 数组的结果,方便后面数据处理

merge_batches(): 把全部批次的训练集数据全部融合起来成为一个大的数组

set_validation(): 设置一个验证集在训练集的比例,如果有不同的模型搭建可能会用到此功能

format_images(): 把一个 1D 向量表示的数据转换成卷积方法需要用到的 4D 格式(Batch, Height, Width, Channels)

Code[3]

# pickle is the module to open cifar10 dataset

import pickle

import os, sys

# This class is used to refer the arranged content of CIFAR10 dataset

classCIFAR10:

# The unchangeable variables should be set here.

image_size =32

image_channels =3

def__init__(self, val_ratio=.1, data_dir='cifar-10-batches-py'):

# Validation set can also be set if it is necessary for other purposes

self.val_ratio = val_ratio

self.data_dir = data_dir

# Get the overall images data "without formatting"!

self.img_train =self.merge_batches('data')

self.img_train_main,self.img_train_val =self.set_validation(self.img_train)

self.lab_train =self.merge_batches('labels')

self.lab_train_main,self.lab_train_val =self.set_validation(self.lab_train)

self.img_test =self.load_binary_data('test_batch','data') /255.0

self.lab_test =self.load_binary_data('test_batch','labels').astype(np.int)

# The data format is binary mode and we should load them with pickle module

# which is introduced at the official web page.

defload_binary_data(self, file_name, dic_key):

path = os.path.join(self.data_dir, file_name)

with open(path,'rb') asfile:

dic = pickle.load(file, encoding='bytes')

# Those binary data are all contained by a dictionary also with

# binary type of dictionary key. The returned list should also be

# converted into np.array so that it can be indexed conveniently.

try:

dic_key = dic_key.encode(encoding='utf-8')

returnnp.array(dic[dic_key])

except:

print('dic_key argument accepts only 4 keys as follow:\n',

'1.batch_label ; 2.labels ; 3.data ; 4.filenames')

# There are five separated images dataset and we will want to

# depose of them all at once.

defmerge_batches(self, dic_key):

merge = []

foriinrange(5):

filename ='data_batch_{}'.format(i+1)

data =self.load_binary_data(filename, dic_key)

merge.append(data)

np_merge = np.array(merge)

ifdic_key =='data':

length =self.image_size *self.image_size *self.image_channels

np_merge = np_merge.reshape(5*len(data), length)

returnnp.array(np_merge) /255.0

else:

np_merge = np_merge.reshape(5*len(data))

returnnp.array(np_merge).astype(np.int)

defset_validation(self, data):

val_set = round(len(data) *self.val_ratio)

val_data = data[:val_set]

main_data = data[val_set:]

return[main_data, val_data]

# The 1D array representing an image should be converted to the format

# that is as same as the regular image format (H, W, C)

defformat_images(self, images_flat):

# The format of original data has (10000, 3072) shape matrix

# with conjoint red 1024, green 1024, blue 1024.

images = images_flat.reshape([-1,self.image_channels,

self.image_size,self.image_size])

# when depositing images, channels should stay at the last dimension.

images = images.transpose([,2,3,1])

returnimages

@property

defget_class_names(self):

path = os.path.join(self.data_dir,'batches.meta')

with open(path,'rb') asfile:

dic = pickle.load(file, encoding='bytes')

class_names = [w.decode('utf-8')forwindic[b'label_names']]

fornum, labelinenumerate(class_names):

print('{}: {}'.format(num, label))

returnclass_names

@property

defnum_per_batch(self):

path = os.path.join(self.data_dir,'batches.meta')

with open(path,'rb') asfile:

dic = pickle.load(file, encoding='bytes')

returndic[b'num_cases_per_batch']

path = input('The directory of CIFAR10 dataset: ')

cifar = CIFAR1(data_dir=path)

cifar.get_class_names

print("Number per batch: {}".format(cifar.num_per_batch))

The directory of CIFAR10 dataset:/Users/kcl/Documents/Python_Projects/cifar-10-batches-py

: airplane

1: automobile

2: bird

3:cat

4: deer

5: dog

6: frog

7: horse

8: ship

9: truck

Number per batch:10000

3. Print Images and Labels respectively

为了验证导入的数据集是否与标签匹配,避免在模型训练前数据集基础就已经歪得一塌糊涂,结合了上面定义的 .format_images() 方法与 get_random_batch() 函数套入以下定义的绘图函数中,随机抽样查看数据匹配的完整性,代码如下:

Code[4]

importmatplotlib.pyplotasplt

images_flat_train = cifar.img_train

images_train = cifar.format_images(images_flat_train)

labels_train = cifar.lab_train

images_flat_test = cifar.img_test

images_test = cifar.format_images(images_flat_test)

labels_test = cifar.lab_test

# To define a universal purpose oriented plotting function here.

# It should not only be able to plot correct images, but also is

# capable of plotting the predicted labels.

defplot_images(images, labels, lab_names, size=[3,3],

pred_labels=False, random=True, smooth=True):

fig, axes = plt.subplots(size[], size[1])

fig.subplots_adjust(hspace=0.6, wspace=0.6)

forn, axinenumerate(axes.flat):

# To decide if the printed images should be smooth or not.

ifsmooth:

interpolation ='spline16'

else:

interpolation ='nearest'

# To decide if the images should be randomly picked up.

ifrandom:

i = np.random.randint(, len(labels), size=None, dtype=np.int)

else:

i = n

ax.imshow(images[i], interpolation=interpolation)

ifpred_labelsisFalse:

xlabel ='T: {}'.format(lab_names[labels[i]])

else:

xlabel ='T: \nP:'.format(lab_names[labels[i]],

lab_names[pred_labels[i]])

ax.set_xlabel(xlabel)

ax.set_xticks([])

ax.set_yticks([])

plt.show()

plot_images(images_train, labels_train, cifar.get_class_names, size=[3,5])

: airplane

1: automobile

2: bird

3: cat

4: deer

5: dog

6: frog

7: horse

8: ship

9: truck

Data Preprocessing 数据预处理

如同概述部分提及的图像预处理步骤,接下来要使用下面 Tensorflow 所提供的方法来实现图像的随机改动:

p.s. 还有很多 Tensorflow 框架支持的图像处理方法,点击此查看官网

对输入数据使用上面函数方法做改动就如同给数据集加了几个维度的数据,而丰富的数据集正是神经网络能够达到更高归类准确率的基本要素,同时还可减少过拟合的结果发生,换个角度思考这些产生的数据,它们就如同数据的噪声,为过拟合可能发生的情况提供了一道保险。

不过当使用此方法在训练的时候,产生数据的过程会添加计算的负担,进而造成时间上的消耗,是我们应用此方法的时候一个重要的考虑要点。

结合上述方法定义的函数代码如下:

Code[5]

def image_preprocessing(single_img, crop=[28,28], crop_only=False):

H, W = cifar.image_size, cifar.image_size

height, width = crop

single_img =tf.random_crop(single_img, size=[height, width,3])

single_img =tf.image.random_flip_left_right(single_img)

single_img =tf.image.random_flip_up_down(single_img)

single_img =tf.image.random_contrast(single_img, lower=0.5, upper=1.0)

single_img =tf.image.random_hue(single_img, max_delta=0.03)

single_img =tf.image.random_brightness(single_img, max_delta=0.2)

single_img =tf.image.random_saturation(single_img, lower=0.5, upper=1.5)

single_img =tf.minimum(single_img,1.0)

single_img =tf.maximum(single_img,0.0)

single_img =tf.image.resize_image_with_crop_or_pad(

single_img, target_height=H, target_width=W)

returnsingle_img

此函数的逻辑为下面陈列的几点说明:

调整我们要随机位置裁切的尺寸大小后

对裁切下来的图像开始随意颠倒,变化色调等等

把超出 RGB 三个单元最大值和最小值的部分抹平

把缩小尺寸的裁切团重新 padding 回到原本未裁切的大小,目的是使用数据流图时测试机不需要预处理图像就能够测试,此一做法更为合理

上面定义的函数必须强调的是,它只处理 "单一张" 图片,如果关联到批量处理,例如我们习惯于把一整批图像数据用 4D 张量的方式表示,格式分别为 (张数,图高,图宽,颜色阶数),则可以使用 tf.map_fn 配合 lambda 的方式一次随机处理整批图像数据,并且每张图像数据的调整系数本身都不尽相同,最后面即为详细的搭配使用代码与说明。

A glimpse to the Preprocessed Images

为了确定我们处理的数据完整性与效果,下面尝试使用我们定义好的函数来随机打印预处理图片集的结果,步骤如下:

导入数据集,并使用定义的类方法呼叫训练图像

使用 Tensorflow 框架的构建方法,把导入的数据集放入我们预先定义好的函数中

启动 tf 会话 .Session() 功能

sess.run() 了上个函数的运算结果后,才把这里的运算结果放入绘图函数中

等待时间约为一分半钟,预处理好后即自行打印

Code[6]

importtensorflowastf

lab_train = cifar.lab_train

format_imgs = cifar.format_images(cifar.img_train)

# We can put every single element of a list into the argument which

# is belonging to tf.map_fn()'s fn by using lambda expression so it can

# iterate all elements to the preset function "image_preprocessing".

format_imgs = tf.map_fn(lambdaimg: image_preprocessing(img, crop=[24,24]), format_imgs)

sess = tf.Session()

format_imgs = sess.run(format_imgs)

plot_images(format_imgs, lab_train, cifar.get_class_names, size=[3,4])

: airplane

1: automobile

2: bird

3: cat

4: deer

5: dog

6: frog

7: horse

8: ship

9: truck

文章回顾

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181012A1Q46100?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券