专栏首页量子位在Keras+TF环境中,用迁移学习和微调做专属图像识别系统

在Keras+TF环境中,用迁移学习和微调做专属图像识别系统

图1:CompCars数据集的示例图像,整个数据集包含163家汽车制造商,1713种车型
王小新 编译自 Deep Learning Sandbox 量子位 出品 | 公众号 QbitAI

量子位曾经编译过Greg Chu的一篇文章,介绍了如何用Keras+TF,来实现ImageNet数据集日常对象的识别

但是,你要研究的物体,往往不在那个列表中。我们可能想要区分出不同型号的太阳镜、认出不同的鞋子、识别各种面部表情、说出不同汽车的型号、在X光影像下判定肺部疾病的类型,这时候该怎么办?

Greg Chu,博客Deep Learning Sandbox的作者,又写了一篇文章,教你在Keras + TensorFlow环境中,用迁移学习(transfer learning)和微调(fine-tuning)定制你专属的图像识别系统,来辨识特定的研究对象。

为什么要使用迁移学习和微调?

一般来说,从头开始训练一个卷积神经网络,不仅需要大规模的数据集,而且会占用大量的计算资源。比如,为了得到ImageNet ILSVRC模型,Google使用了120万张图像,在装有多个GPU的服务器上运行2-3周才完成训练。

在实际应用中,深度学习相关的研究人员和从业者通常运用迁移学习和微调方法,将ImageNet等数据集上训练的现有模型底部特征提取层网络权重传递给新的分类网络。这种做法并不是个例。

这种做法的效果很好。Razavian等人2014年发表的论文*表明,从ImageNet ILSVRC的训练模型中,简单地提取网络权重的初级特征,应用在多种图像分类任务中,都取得了与ImageNet网络相同或几乎相同的分类效果。

* CNN Features off-the-shelf: an Astounding Baseline for Recognition:https://arxiv.org/pdf/1403.6382.pdf

接下来我们介绍一下这两种方法:

迁移学习:在ImageNet上得到一个预训练好的ConvNet网络,删除网络顶部的全连接层,然后将ConvNet网络的剩余部分作为新数据集的特征提取层。这也就是说,我们使用了ImageNet提取到的图像特征,为新数据集训练分类器。

微调:更换或者重新训练ConvNet网络顶部的分类器,还可以通过反向传播算法调整预训练网络的权重。

该选择哪种方法?

有两个主要因素,将影响到所选择的方法:

1. 你的数据集大小;

2. 新数据集与预训练数据集的相似性,通常与ImageNet数据集相比。

内容相似性较高

内容相似性较低

小型数据集

迁移学习:高级特征+分类器

迁移学习:低级特征+分类器

大型数据集

微调

微调

上表指出了在如下4个场景下,该如何从这两种方法中做选择:

新数据集相比于原数据集在样本量上更小,在内容上相似:如果数据过小,考虑到过拟合,这使用微调则效果不大好。因为新数据类似于原数据,我们希望网络中高级特征也与此数据集相关。因此,最好的思路可能是在ConvNet网络上重新训练一个线性分类器。

新数据集相比于原数据集在样本量上更小,且内容非常不同:由于数据较小,只训练一个线性分类器可能更好。但是数据集不同,从网络顶部开始训练分类器不是最好的选择,这里包含了原有数据集的高级特征。所以,一般是从ConvNet网络前部的激活函数开始,重新训练一个线性分类器。

新数据集相比于原数据集在样本量上较大,在内容上相似:由于我们有更多的数据,所以在我们试图微调整个网络,那我们有信心不会导致过拟合。

新数据集相比于原数据集在样本量上较大,但内容非常不同:由于数据集很大,我们可以尝试从头开始训练一个深度网络。然而,在实际应用中,用一个预训练模型的网络权重来初始化新网络的权重,仍然是不错的方法。在这种情况下,我们有足够的数据和信心对整个网络进行微调。

另外,在新数据集样本量较大时,你也可以尝试从头开始训练一个网络。

数据增强

数据增强方法能大大增加训练数据集的样本量和增大网络模型的泛化能力。实际上,在数据比赛中,每个获胜者的ConvNet网络一定会使用数据增强方法。在本质上,数据增强是通过数据转换来人为地增加数据集样本量的过程。

大多数深度学习框架具有一些基本函数,可以直接实现常用的数据转换。为了建立特定的图像识别系统,我们的任务是去确定对现有数据集有意义的转换方法。比如,不能对X射线图像旋转超过45度,因为这意味着在图像采集过程中出现错误。

图2:通过水平翻转和随机裁剪进行数据增强

常用转换方法:像素颜色抖动、旋转、剪切、随机裁剪、水平翻转、镜头拉伸和镜头校正等。

迁移学习和微调方法实现

数据准备

图3:Kaggle猫狗大赛的示例图像

我们将使用Kaggle猫狗大赛中提供的数据集,将训练集目录和验证集目录设置如下:

代码1

网络实现

让我们开始定义generators:

代码2

在上篇文章中,我们已经强调了在图像识别中预处理环节的重要性。从keras.applications.inception_v3模块中引出参数preprocess_input,进而设置preprocessing_function = preprocess_input

为了实现数据增强,还定义了旋转、移动、剪切、缩放和翻转操作的参数范围。

接下来,我们从keras.applications模块中引出InceptionV3网络。

代码3

设置了标志位include_top = False,去除ImageNet网络的全连接层权重,因为这是针对ImageNet竞赛的1000种日常对象预先训练好的网络权重。因此,我们将添加一个新的全连接层,并进行初始化。

代码4

全局平均初始化函数GlobalAveragePooling2DMxNxC张量转换后输出为1xC张量,其中C是图像的通道数。

然后我们添加一个维度为1024的全连接层Dense,同时加上一个softmax函数,得到[0,1]之间的输出值。

在这个项目中,我将演示如何实现迁移学习和微调。当然你可以在以后的项目中自由选用。

1. 迁移学习:除去倒数第二层,固定所有其他层的参数,并重新训练最后一层全连接层。

2. 微调:固定用来提取低级特征的底部卷积层,并重新训练更多的网络层。

这样做,将确保更稳定和全局一致的训练网络。因为如果不固定相关层,随机初始化网络权重会导致较大的梯度更新,进一步可能会破坏卷积层中的学习权重。我们应用迁移学习,训练得到稳定的最后全连接层后,可以再通过微调的方法训练更多的网络层。

迁移学习

代码5

微调

代码6

在微调过程中,最重要的是与网络从头开始训练时所使用的速率相比(lr = 0.0001),要降低学习率,否则优化过程可能不稳定,Loss函数可能会发散。

网络训练

现在我们开始训练,使用函数fit_generator同时实现迁移学习和微调。

代码7

我们将使用AWS上的EC2 g2.2xlarge实例进行网络训练。

我们可以使用对象history,绘制训练准确率和损失曲线。

代码8

模型预测

现在我们通过keras.model保存训练好的网络模型,通过修改predict.py中的predict函数后,只需要输入本地图像文件的路径或是图像的URL链接即可实现模型预测。

代码9

完工

作为例子,我将猫狗大赛数据集中的24000张图像作为训练集,1000张图像作为验证集。从结果中,可以看出训练迭代2次后,准确率已经相当高了。

图4:经过2次迭代后的输出日志

测试

代码10

图5:猫的图片和类别预测
图6:狗的图片和类别预测

将上述代码组合起来,你就创建了一个猫狗识别系统。

该项目的完整程序请查看GitHub链接: https://github.com/DeepLearningSandbox/DeepLearningSandbox/tree/master/transfer_learning

其他相关链接:

原文:https://deeplearningsandbox.com/how-to-use-transfer-learning-and-fine-tuning-in-keras-and-tensorflow-to-build-an-image-recognition-94b0b02444f2

Kaggle猫狗数据集:https://www.kaggle.com/c/dogs-vs-cats/data

训练好的模型:https://drive.google.com/file/d/0B9-cM_0P8MFGRmM1TU14b3ptVGM/view

本文分享自微信公众号 - 量子位(QbitAI),作者:好学上进

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2017-05-03

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • MIT课程全面解读2019深度学习最前沿 | 附视频+PPT

    人类公元纪年2019年伊始,深度学习技术也同样处在一个新的“开端”,宜review、宜展望。

    量子位
  • AI医疗来袭,患者、机构、数据科学家分别有哪些新机会?

    原作:Jeremy Howard 安妮 编译自 Hacker Noon 量子位 出品 | 公众号 QbitAI 本文作者Jeremy Howard,澳大利亚数据...

    量子位
  • 数据集查找神器!100个大型机器学习数据集都汇总在这了 | 资源

    网上各种数据集鱼龙混杂,质量也参差不齐,简直让人挑花了眼。想要获取大型数据集,还要挨个跑到各数据集的网站,两个字:麻烦。

    量子位
  • 大数据24小时 | 众企业开疆拓土布局大数据 贵州豪掷万金求人才

    HCR慧辰资讯收购瑞斡咨询,布局大数据应用层 自去年8月HCR以大数据商业应用第一股登陆新三板后,近日又与瑞斡咨询上海有限公司达成收购协议,将瑞斡正式纳入旗下。...

    数据猿
  • 终于有人把云计算、物联网和大数据讲明白了

    根据美国国家标准与技术研究院(National Institute of Standards and Technology,NIST)的定义,云计算是指能够针对...

    华章科技
  • 2016中国国际大数据大会召开

    <数据猿导读> 9月27日,2016中国国际大数据大会在京盛大开幕,本届大会主题为“数聚新动能数创大未来”。本届大会汇集了大数据产业链各环节共2000多位代表出...

    数据猿
  • 大数据投融资周报(4月16日—4月22日:共5起)

    <数据猿导读> 在本周(4月16日——4月22日),大数据领域共发生5起投融资事件。其中,承德市政府在京签订了14个大数据产业项目,总投资达205亿元;Orac...

    数据猿
  • 百度启动最大规模AI公开数据集计划;Uber承认数据遭窃 | DT数读

    过去一周,国际、国内的大数据相关公司都有哪些值得关注的新闻?数据行业都有哪些新观点和新鲜事?DT君为你盘点解读。

    DT数据侠
  • 数据科学家、数据分析师、数据挖掘工程师、数据工程师,你分的清楚吗?

    数据科学家(Data scientist)的叫法来自国外,广义上它是对从事数据分析和数据挖掘从业人员的一个泛称,它只是一个头衔,并不是一个职位。狭义上,数据科学...

    小莹莹
  • 拥抱大数据时代

    image.png 推荐语: 田溯宁,被称为“中国宽带先生”,是中国最早一批的互联网弄潮者。身为企业家的他,却极为低调。幸运且巧合的是我第一次亲眼见...

    腾讯研究院

扫码关注云+社区

领取腾讯云代金券