在Keras+TF环境中,用迁移学习和微调做专属图像识别系统

图1:CompCars数据集的示例图像,整个数据集包含163家汽车制造商,1713种车型
王小新 编译自 Deep Learning Sandbox 量子位 出品 | 公众号 QbitAI

量子位曾经编译过Greg Chu的一篇文章,介绍了如何用Keras+TF,来实现ImageNet数据集日常对象的识别

但是,你要研究的物体,往往不在那个列表中。我们可能想要区分出不同型号的太阳镜、认出不同的鞋子、识别各种面部表情、说出不同汽车的型号、在X光影像下判定肺部疾病的类型,这时候该怎么办?

Greg Chu,博客Deep Learning Sandbox的作者,又写了一篇文章,教你在Keras + TensorFlow环境中,用迁移学习(transfer learning)和微调(fine-tuning)定制你专属的图像识别系统,来辨识特定的研究对象。

为什么要使用迁移学习和微调?

一般来说,从头开始训练一个卷积神经网络,不仅需要大规模的数据集,而且会占用大量的计算资源。比如,为了得到ImageNet ILSVRC模型,Google使用了120万张图像,在装有多个GPU的服务器上运行2-3周才完成训练。

在实际应用中,深度学习相关的研究人员和从业者通常运用迁移学习和微调方法,将ImageNet等数据集上训练的现有模型底部特征提取层网络权重传递给新的分类网络。这种做法并不是个例。

这种做法的效果很好。Razavian等人2014年发表的论文*表明,从ImageNet ILSVRC的训练模型中,简单地提取网络权重的初级特征,应用在多种图像分类任务中,都取得了与ImageNet网络相同或几乎相同的分类效果。

* CNN Features off-the-shelf: an Astounding Baseline for Recognition:https://arxiv.org/pdf/1403.6382.pdf

接下来我们介绍一下这两种方法:

迁移学习:在ImageNet上得到一个预训练好的ConvNet网络,删除网络顶部的全连接层,然后将ConvNet网络的剩余部分作为新数据集的特征提取层。这也就是说,我们使用了ImageNet提取到的图像特征,为新数据集训练分类器。

微调:更换或者重新训练ConvNet网络顶部的分类器,还可以通过反向传播算法调整预训练网络的权重。

该选择哪种方法?

有两个主要因素,将影响到所选择的方法:

1. 你的数据集大小;

2. 新数据集与预训练数据集的相似性,通常与ImageNet数据集相比。

内容相似性较高

内容相似性较低

小型数据集

迁移学习:高级特征+分类器

迁移学习:低级特征+分类器

大型数据集

微调

微调

上表指出了在如下4个场景下,该如何从这两种方法中做选择:

新数据集相比于原数据集在样本量上更小,在内容上相似:如果数据过小,考虑到过拟合,这使用微调则效果不大好。因为新数据类似于原数据,我们希望网络中高级特征也与此数据集相关。因此,最好的思路可能是在ConvNet网络上重新训练一个线性分类器。

新数据集相比于原数据集在样本量上更小,且内容非常不同:由于数据较小,只训练一个线性分类器可能更好。但是数据集不同,从网络顶部开始训练分类器不是最好的选择,这里包含了原有数据集的高级特征。所以,一般是从ConvNet网络前部的激活函数开始,重新训练一个线性分类器。

新数据集相比于原数据集在样本量上较大,在内容上相似:由于我们有更多的数据,所以在我们试图微调整个网络,那我们有信心不会导致过拟合。

新数据集相比于原数据集在样本量上较大,但内容非常不同:由于数据集很大,我们可以尝试从头开始训练一个深度网络。然而,在实际应用中,用一个预训练模型的网络权重来初始化新网络的权重,仍然是不错的方法。在这种情况下,我们有足够的数据和信心对整个网络进行微调。

另外,在新数据集样本量较大时,你也可以尝试从头开始训练一个网络。

数据增强

数据增强方法能大大增加训练数据集的样本量和增大网络模型的泛化能力。实际上,在数据比赛中,每个获胜者的ConvNet网络一定会使用数据增强方法。在本质上,数据增强是通过数据转换来人为地增加数据集样本量的过程。

大多数深度学习框架具有一些基本函数,可以直接实现常用的数据转换。为了建立特定的图像识别系统,我们的任务是去确定对现有数据集有意义的转换方法。比如,不能对X射线图像旋转超过45度,因为这意味着在图像采集过程中出现错误。

图2:通过水平翻转和随机裁剪进行数据增强

常用转换方法:像素颜色抖动、旋转、剪切、随机裁剪、水平翻转、镜头拉伸和镜头校正等。

迁移学习和微调方法实现

数据准备

图3:Kaggle猫狗大赛的示例图像

我们将使用Kaggle猫狗大赛中提供的数据集,将训练集目录和验证集目录设置如下:

代码1

网络实现

让我们开始定义generators:

代码2

在上篇文章中,我们已经强调了在图像识别中预处理环节的重要性。从keras.applications.inception_v3模块中引出参数preprocess_input,进而设置preprocessing_function = preprocess_input

为了实现数据增强,还定义了旋转、移动、剪切、缩放和翻转操作的参数范围。

接下来,我们从keras.applications模块中引出InceptionV3网络。

代码3

设置了标志位include_top = False,去除ImageNet网络的全连接层权重,因为这是针对ImageNet竞赛的1000种日常对象预先训练好的网络权重。因此,我们将添加一个新的全连接层,并进行初始化。

代码4

全局平均初始化函数GlobalAveragePooling2DMxNxC张量转换后输出为1xC张量,其中C是图像的通道数。

然后我们添加一个维度为1024的全连接层Dense,同时加上一个softmax函数,得到[0,1]之间的输出值。

在这个项目中,我将演示如何实现迁移学习和微调。当然你可以在以后的项目中自由选用。

1. 迁移学习:除去倒数第二层,固定所有其他层的参数,并重新训练最后一层全连接层。

2. 微调:固定用来提取低级特征的底部卷积层,并重新训练更多的网络层。

这样做,将确保更稳定和全局一致的训练网络。因为如果不固定相关层,随机初始化网络权重会导致较大的梯度更新,进一步可能会破坏卷积层中的学习权重。我们应用迁移学习,训练得到稳定的最后全连接层后,可以再通过微调的方法训练更多的网络层。

迁移学习

代码5

微调

代码6

在微调过程中,最重要的是与网络从头开始训练时所使用的速率相比(lr = 0.0001),要降低学习率,否则优化过程可能不稳定,Loss函数可能会发散。

网络训练

现在我们开始训练,使用函数fit_generator同时实现迁移学习和微调。

代码7

我们将使用AWS上的EC2 g2.2xlarge实例进行网络训练。

我们可以使用对象history,绘制训练准确率和损失曲线。

代码8

模型预测

现在我们通过keras.model保存训练好的网络模型,通过修改predict.py中的predict函数后,只需要输入本地图像文件的路径或是图像的URL链接即可实现模型预测。

代码9

完工

作为例子,我将猫狗大赛数据集中的24000张图像作为训练集,1000张图像作为验证集。从结果中,可以看出训练迭代2次后,准确率已经相当高了。

图4:经过2次迭代后的输出日志

测试

代码10

图5:猫的图片和类别预测
图6:狗的图片和类别预测

将上述代码组合起来,你就创建了一个猫狗识别系统。

该项目的完整程序请查看GitHub链接: https://github.com/DeepLearningSandbox/DeepLearningSandbox/tree/master/transfer_learning

其他相关链接:

原文:https://deeplearningsandbox.com/how-to-use-transfer-learning-and-fine-tuning-in-keras-and-tensorflow-to-build-an-image-recognition-94b0b02444f2

Kaggle猫狗数据集:https://www.kaggle.com/c/dogs-vs-cats/data

训练好的模型:https://drive.google.com/file/d/0B9-cM_0P8MFGRmM1TU14b3ptVGM/view

原文发布于微信公众号 - 量子位(QbitAI)

原文发表时间:2017-05-03

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能

主流机器学习算法简介与其优缺点分析

机器学习算法的分类是棘手的,有几种合理的分类,他们可以分为生成/识别,参数/非参数,监督/无监督等。

2.1K4
来自专栏机器之心

入门 | 单样本学习:使用孪生神经网络进行人脸识别

3959
来自专栏SeanCheney的专栏

《Scikit-Learn与TensorFlow机器学习实用指南》 第10章 人工神经网络

鸟类启发我们飞翔,牛蒡植物启发了尼龙绳,大自然也激发了许多其他发明。从逻辑上看,大脑是如何构建智能机器的灵感。这是启发人工神经网络(ANN)的关键思想。然而,尽...

1763
来自专栏机器之心

302页吴恩达Deeplearning.ai课程笔记,详记基础知识与作业代码

5178
来自专栏杨熹的专栏

attention 机制入门

在下面这两篇文章中都有提到 attention 机制: 使聊天机器人的对话更有营养 如何自动生成文章摘要 今天来看看 attention 是什么。 下面这篇...

4308
来自专栏数据派THU

计算机视觉怎么给图像分类?KNN、SVM、BP神经网络、CNN、迁移学习供你选(附开源代码)

原文:Medium 作者:Shiyu Mou 来源:机器人圈 本文长度为4600字,建议阅读6分钟 本文为你介绍图像分类的5种技术,总结并归纳算法、实现方式,并...

71610
来自专栏AI科技大本营的专栏

深度学习系列:卷积神经网络结构变化——可变形卷积网络deformable convolutional

作者 | 大饼博士X 上一篇我们介绍了:深度学习方法(十二):卷积神经网络结构变化——Spatial Transformer Networks,STN创造性地...

46710
来自专栏机器之心

谷歌云大会教程:没有博士学位如何玩转TensorFlow和深度学习(附资源)

机器之心原创 作者:吴攀、李亚洲 当地时间 3 月 8 日-10 日,Google Cloud NEXT '17 大会在美国旧金山举行。谷歌在本次大会上正式宣布...

44111
来自专栏专知

主流机器学习算法简介与其优缺点分析

机器学习算法的分类是棘手的,有几种合理的分类,他们可以分为生成/识别,参数/非参数,监督/无监督等。 例如,Scikit-Learn的文档页面通过学习机制对算法...

3383
来自专栏企鹅号快讯

机器学习算法分类与其优缺点分析

机器学习算法的分类是棘手的,有几种合理的分类,他们可以分为生成/识别,参数/非参数,监督/无监督等。 例如,Scikit-Learn的文档页面通过学习机制对算法...

2237

扫码关注云+社区

领取腾讯云代金券