首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Mini-VGG实现CIFAR10数据集分类

概要

本篇博客主要讲解了CIFAR10数据集的预处理,Mini-VGG的tensorflow和keras实现,并实现了CIFAR数据集的分类。

整个项目的github地址为:Mini-VGG-CIFAR10(https://github.com/Daipuwei/Mini-VGG-CIFAR10) ,如果喜欢集的点个Star

一、 Cifar10数据集说明

为了实现VGG16网络对CIFAR10数据集的分类,我们首先得对CIFAR10进行一个详细介绍

Cifar10数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。其中,有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。

下面这幅图就是列举了10各类,每一类展示了随机的10张图片:

该数据是由以下三个人收集而来:Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton。第一位是AlexNet的提出者,第三位就更不用说了——深度学习的奠基人。

该数据集的下载网址为:http://www.cs.toronto.edu/~kriz/cifar.html 。这个数据主要有三个下载版本:Python、Matlab和二进制文件(适合于C语言)。由于我主要是利用tensorflow和Keras来实现VGG,因此我下载的是Python版本的数据集。从网站上可以看出,无论下载那个版本的数据集文件都不是挺大,足够学习跑跑程序用。

下面开始导入Cifar10数据集。将官网上下载的数据集打开之后,文件结构如下图所示。主要包含了5个data_batch文件data_batch_1至data_batch_5、1个test_batch文件和1个batches的meta文件。

从Cifar10数据集官网:http://www.cs.toronto.edu/~kriz/cifar.html

上的介绍来看,5个data_batch文件和test_batch文件是利用pickel序列化之后的文件因此在导入 Cifar10数据集必须利用pickel进行解压数据,之后将数据还原。5个data_batch文件和test_batch文件分别代表5个训练集批次和测试集,因此我们首先利用pickel编写解压函数:

在下面项目中训练Mini-VGG时候使用Keras官方的ImageDataGenerator来构造数据生成器,因此首先需要将官方CIFAR10数据集存储转化成适应ImageDataGenerator接口的数据集格式。在这里我们将CIFAR官方中data_batch1至data_batch5作为训练集,test_batch作为验证集,即训练集有5万张图片,验证集有1万张图片。

同时,6万张图像需要利用opencv库重新写入内存这会涉及大量I/O操作,因此为了加快图像写入内存速度,采用了异步多进程方式来实现CIFAR10数据集的写入内存。CIFAR10数据集格式转化脚本如下:

结果如下:

二、训练阶段Mini-VGG的keras实现

2.1 Mini-VGG的网络架构

由于CIFAR10数据集中所有图片的分辨率为32 * 32,VGG16的下采样率为32,那么使用VGG16来实现CIFAR10数据集的分类任务,那么CIFAR10数据集的图像在经过VGG16的卷积模块作用下提取得到特征维度为1 * 1 * 1024。那么这将导致大量特征丢失,反而不利于图像分类。因此为了技能提取得到的特征又能使得特征图不为1 * 1 * 1024,在本次项目中我们对VGG16的结构进行有所删减,形成Mini-VGG。

Mini-VGG的网络架构为:

第一层:INPUT =>

第二层:CONV => ReLU => BN => CONV => ReLU=>BN =>MAXPOOL => DROPOUT =>

第三层:CONV =>ReLU => BN => CONV =>ReLU =>BN => MAXPOOL => DROPOUT =>

第四层:FC => ReLU => BN => DROPOUT =>

第五层:FC => SOFTMAX

2.2 Mini-VGG训练

Mini-VGG的keras实现与利用数据集生成器进行训练的代码如下,由于训练过程中涉及较多参数,为了在Mini-VGG类代码编写过程中指定过多参数,首先实现参数配置类config用来保存训练过程中所有相关参数,并实现将所有参数保存到本地txt文件函数,方便训练过后查看每次训练的相关细节。参数配置类config的定义如下:

接下来是训练阶段Mini-VGG类的定义如下:

那么Mini-VGG的训练脚本如下,在训练过程中为了增加模型的鲁棒性,对于训练集和测试集都进行相关的数据增强,包括旋转、裁剪、水平翻转、垂直翻转、亮度变化等。

经过100个epoch的训练之后,Mini-VGG的训练与验证损失、训练与验证精度的走势图如下图所示。由于在训练过程中设置了早停回调函数,因此100个epoch的训练在不到50个epoch训练周期就结束了。从下损失和精度的走势图可以得知,在验证集上虽然损失在上下波动,但是精度收敛到了85%附近。因此,Mini-VGG在对CIFAR10数据集经过训练之后,拥有较高的分类性能。

三、测试阶段Mini-VGG的keras实现

在训练结束之后,下一步就是准确评估Mini-VGG在数据集上的性能。测试阶段的Mini-VGG类的定义如下:

在一个数据集上评估Mini-VGG性能的脚本如下:

Mini-VGG在CIFAR10的验证集上的评估结果如下:

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20201029A0BW6400?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券