前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow2.0 实战强化专栏(二):CIFAR-10项目

TensorFlow2.0 实战强化专栏(二):CIFAR-10项目

作者头像
磐创AI
发布2020-03-04 16:09:25
1K0
发布2020-03-04 16:09:25
举报
作者 | 小猴锅

出品 | 磐创AI团队

CIFAR-10项目

Alex Krizhevsky,Vinod Nair和Geoffrey Hinton收集了8000万个小尺寸图像数据集,CIFAR-10和CIFAR-100分别是这个数据集的一个子集(http://www.cs.toronto.edu/~kriz/cifar.html)。CIFAR-10数据集由10个类别共60000张彩色图片组成,其中每张图片的大小为32X32,每个类别分别6000张。

图1 cifar-10数据集中部分样本可视化

我们首先下载CIFAR-10数据集(http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz),解压之后如图2所示。其中“data_batch_1”至“data_batch_5”是训练文件,每个文件分别有10000个训练样本,共计50000个训练样本,“test_batch”是测试文件,包含了10000个测试样本。

图2 CIFAR-10数据集文件

  • 数据预处理

我们先导入需要用到的包:

代码语言:javascript
复制
1  import tensorflow as tf
2  import numpy as np
3  import pickle
4  import os

由于这些数据文件是使用“cPickle”进行存储的,因此我们需要定义一个函数来加载这些数据文件:

代码语言:javascript
复制
5  def get_pickled_data(data_path):
6      data_x = []
7      data_y = []
8      with open(data_path, mode='rb') as file:
9          data = pickle.load(file, encoding='bytes')
10          x = data[b'data']
11          y = data[b'labels']
12          # 将3*32*32的数组变换为32*32*3
13          x = x.reshape(10000, 3, 32, 32)\
14              .transpose(0, 2, 3, 1).astype('float')
15          y = np.array(y)
16          data_x.extend(x)
17          data_y.extend(y)
18      return data_x, data_y

接下来我们定义一个“prepare_data”函数用来获取训练和测试数据:

代码语言:javascript
复制
19  def prepare_data(path):
20      x_train = []
21      y_train = []
22      x_test = []
23      y_test = []
24      for i in range(5):
25          # train_data_path为训练数据的路径
26          train_data_path = os.path.join(path, ('data_batch_'+str(i + 1)))
27          data_x, data_y = get_pickled_data(train_data_path)
28          x_train += data_x
29          y_train += data_y
30      # 将50000个list型的数据样本转换为ndarray型
31      x_train = np.array(x_train)
32  
33      # test_data_path为测试文件的路径
34      test_data_path = os.path.join(path, 'test_batch')
35      x_test, y_test = get_pickled_data(test_data_path)
36      x_test = np.array(x_test)
37  
38      return x_train, y_train, x_test, y_test
  • 模型搭建

在这个项目里我们将使用RasNet模型,RasNet我们简单的介绍过,它是一个残差网络,一定程度上解决了网络过深后出现的退化问题(论文地址:https://arxiv.org/abs/1512.03385)。ResNet的基本结构是如图3所示的“残差块(residual block)”,右侧是针对50层以上网络的优化结构。

图3 残差块(residual block)

图4所示是一个34层的ResNet的网络结构,ResNet的提出者以VGG-19模型(图4左)为参考,设计了一个34层的网络(图4中),并进一步构造了34层的ResNet(图4右),34层是按有参数更新的层来计算的,图4所示的34层ResNet中有参数更新的层包括第1层卷积层,中间残差部分的32个卷积层,以及最后的一个全连接层。

如图4所示,ResNet中主要使用的是3X3的卷积核,并遵守着两个简单的设计原则:(1)对于每一层卷积层,如果输出的特征图尺寸相同,那么这些层就使用相同数量的滤波器;(2)如果输出的特征图尺寸减半了,那么卷积核的数量加增加一倍,以便保持每一层的时间复杂度。

ResNet的第一层是66个7X7的卷积核,滑动步长为2;接着是一个步长为2的池化层;再接着是16个残差块,共32个卷积层,根据卷积层中卷积核数量的不同可以分为4个部分,每个部分的衔接处特征图的尺寸都缩小了一半,因此卷积核的数量也相应地增加了一倍;残差部分之后是一个池化层,采用平均池化;最后是一个全连接层,并用softmax作为激活函数,得到分类结果。

图4 ResNet34的网络结构

接下来我们先定义残差块:

代码语言:javascript
复制
1  class residual_lock(tf.keras.layers.Layer):
2      def __init__(self, filters, strides=1):
3          super(residual_lock, self).__init__()
4          self.conv1 = tf.keras.layers.Conv2D(filters=filters,
5                                              kernel_size=(3, 3),
6                                              strides=strides,
7                                              padding="same")
8          # 规范化层:加速收敛,控制过拟合
9          self.bn1 = tf.keras.layers.BatchNormalization()
10          self.conv2 = tf.keras.layers.Conv2D(filters=filters,
11                                              kernel_size=(3, 3),
12                                              strides=1,
13                                              padding="same")
14          # 规范化层:加速收敛,控制过拟合
15          self.bn2 = tf.keras.layers.BatchNormalization()
16          # 残差块的第一个卷积层中,卷积核的滑动步长为2时,输出特征图大小减半,
17          # 需要对残差块的输入使用步长为2的卷积来进行下采样,从而匹配维度
18          if strides != 1:
19              self.downsample = tf.keras.Sequential()
20  self.downsample.add(tf.keras.layers.Conv2D(filters=filters, kernel_size=(1, 1), strides=strides))
21  self.downsample.add(tf.keras.layers.BatchNormalization())
22          else:
23              self.downsample = lambda x: x
24  
25      def call(self, inputs, training=None):
26          # 匹配维度
27          identity = self.downsample(inputs)
28  
29          conv1 = self.conv1(inputs)
30          bn1 = self.bn1(conv1)
31          relu = tf.nn.relu(bn1)
32          conv2 = self.conv2(relu)
33          bn2 = self.bn2(conv2)
34  
35          output = tf.nn.relu(tf.keras.layers.add([identity, bn2]))
36  
37          return output

接着我们定义一个函数用来组合残差块:

代码语言:javascript
复制
38  def build_blocks(filters, blocks, strides=1):
39      """组合相同特征图大小的残差块"""
40      res_block = tf.keras.Sequential()
41      # 添加第一个残差块,每部分的第一个残差块的第一个卷积层,其滑动步长为2
42      res_block.add(residual_lock(filters, strides=strides))
43  
44      # 添加后续残差块
45      for _ in range(1, blocks):
46          res_block.add(residual_lock(filters, strides=1))
47  
48      return res_block

定义好残差块和组合组合残差块的函数后,我们就可以实现具体的ResNet模型了:

代码语言:javascript
复制
49  class ResNet(tf.keras.Model):
50      """ResNet模型"""
51      def __init__(self, num_classes=10):
52          super(ResNet, self).__init__()
53  
54          self.preprocess = tf.keras.Sequential([
55              tf.keras.layers.Conv2D(filters=64,
56                                     kernel_size=(7, 7),
57                                     strides=2,
58                                     padding='same'),
59              # 规范化层:加速收敛,控制过拟合
60              tf.keras.layers.BatchNormalization(),
61  tf.keras.layers.Activation(tf.keras.activations.relu),
62              # 最大池化:池化操作后,特征图大小减半
63              tf.keras.layers.MaxPool2D(pool_size=(3, 3),strides=2)
64          ])
65  
66          # 组合四个部分的残差块
67          self.blocks_1 = build_blocks(filters=64, blocks=3)
68  self.blocks_2 = build_blocks(filters=128, blocks=4,strides=2)
69          self.blocks_3 = build_blocks(filters=256, blocks=6, strides=2)
70          self.blocks_4 = build_blocks(filters=512, blocks=3, strides=2)
71  
72          # 平均池化
73          self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
74          # 最后的全连接层,使用softmax作为激活函数
75  self.fc=tf.keras.layers.Dense(units=num_classes,activation=tf.keras.activations.softmax)
76  
77      def call(self, inputs, training=None):
78          preprocess = self.preprocess(inputs)
79          blocks_1 = self.blocks_1(preprocess)
80          blocks2 = self.blocks_2(blocks_1)
81          blocks3 = self.blocks_3(blocks2)
82          blocks4 = self.blocks_4(blocks3)
83          avg_pool = self.avg_pool(blocks4)
84          out = self.fc(avg_pool)
85  
86          return out

这里ResNet模型的实现完全依照图4中34层的ResNet模型结构。

  • 模型训练

最后我们实现模型的训练部分:

代码语言:javascript
复制
87  if __name__ == '__main__':
88      model = ResNet()
89      model.build(input_shape=(None, 32, 32, 3))
90      model.summary()
91  
92      # 数据集路径
93      path = "./cifar-10-batches-py"
94  
95      # 数据载入
96      x_train, y_train, x_test, y_test = prepare_data(path)
97      # 将类标进行one-hot编码
98      y_train = tf.keras.utils.to_categorical(y_train, 10)
99      y_test = tf.keras.utils.to_categorical(y_test, 10)
100  
101      model.compile(loss='categorical_crossentropy',
102                    optimizer=tf.keras.optimizers.Adam(),
103                    metrics=['accuracy'])
104  
105      # 动态设置学习率
106      lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
107          monitor='val_accuracy',
108          factor=0.2, patience=5,
109          min_lr=0.5e-6)
110      callbacks = [lr_reducer]
111  
112      # 训练模型
113      model.fit(x_train, y_train,
114                batch_size=50, epochs=20,
115                verbose=1, callbacks=callbacks,
116                validation_data=(x_test, y_test),
117                shuffle=True)

在第106行代码中我们设置了动态学习率,并通过“callbacks”传递给模型。“tf.keras.callbacks.ReduceLROnPlateau”函数可以用来动态调整学习率,参数“monitor”是我们要监测的指标,“factor”是调整学习率时的参数(新的学习率=旧的学习率*factor),“patience”个回合后如果“monitor”指定的指标没有变化,则对学习率进行调整,“min_lr”限定了学习率的下限。

训练过程的Accuracy和Loss的变化如下:

图5 ResNet34训练过程中Accuracy和Loss的变化

(橙色为训练集,蓝色为验证集)

最终在验证集上的准确率为76.12%,有过拟合的现象,准确率还有提升的空间。有兴趣进一步提升分类效果的读者可以尝试如下方法:

1) 数据集增强:通过旋转、平移等操作来扩充数据集;

2) 参数微调:包括训练的回合数、学习率等;

3) 修改模型:可以尝试在ResNet32的基础上修改模型的结构,或者替换其它网络模型;

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-02-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 磐创AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • CIFAR-10项目
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档