首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

深度学习第21讲:迁移学习的基本原理和实践

作为一门实验性学科,深度学习通常需要反复的实验和结果论证。在现在和将来,是否有海量的数据资源和强大的计算资源,这是决定学界和业界深度学习和人工智能发展的关键因素。通常情况下,获取海量的数据资源对于企业而言并非易事,尤其是对于像医疗等特定领域,要想做一个基于深度学习的医学影像的自动化诊断系统,大量且高质量的打标数据非常关键。但通常而言,莫说高质量,就是想获取大量的影像数据就已困难重重。

那怎么办呢?是不是获取不了海量的数据研究就一定进行不下去了呢?当然不是。因为我们有迁移学习。那究竟什么是迁移学习呢?顾名思义,迁移学习就是利用数据、任务或模型之间的相似性,将在旧的领域学习过或训练好的模型,应用于新的领域这样的一个过程。从这段定义里面,我们可以窥见迁移学习的关键点所在,即新的任务与旧的任务在数据、任务和模型之间的相似性。因而这也引出了迁移的应用场景。

迁移学习的使用场景

迁移学习到底在什么情况下使用呢?是不是我模型训练不好就可以用迁移学习进行改进呢?当然不是。正如前文所言,使用迁移学习的主要原因在于数据资源的可获得性和训练任务的成本。当我们有海量的数据资源时,自然不需要迁移学习,机器学习系统很容易从海量数据中学习到一个鲁棒性很强的模型。但通常情况下,我们需要研究的领域可获得的数据极为有限,仅靠有限的数据量进行学习,所习得的模型必然是不稳健、效果差的,通常情况下很容易造成过拟合,在少量的训练样本上精度极高,但是泛化效果极差。另一个原因在于训练成本,即所依赖的计算资源和耗费的训练时间。通常情况下,很少有人从头开始训练一整个深度卷积网络,一个是上面提到的数据量的问题,另一个就是时间成本和计算资源的问题,从头开始训练一个卷积网络通常需要较长时间且依赖于强大的 GPU 计算资源,对于一门实验性极强的领域而言,花费好几天乃至一周的时间去训练一个自己心里都没谱的深度神经网络通常是不能忍受的。

所以,迁移学习的使用场景如下:假设有两个任务系统 A 和 B,任务 A 拥有海量的数据资源且已训练好,但并不是我们的目标任务,任务 B 是我们的目标任务,但数据量少且极为珍贵,这种场景便是典型的迁移学习的应用场景。那究竟什么时候使用迁移学习是有效的呢?对此笔者不敢武断地下结论。但必须如前文所言,新的任务系统和旧的任务系统必须在数据、任务和模型等方面存在一定的相似性,你将一个训练好的语音识别系统迁移到放射科的图像识别系统上,恐怕结果不会太妙。所以,要判断一个迁移学习应用是否有效,最基本的原则还是要遵守,即任务 A 和任务 B 在输入上有一定的相似性,即两个任务的输入属于同一性质,要么同是图像、要么同是语音或其他,这便是前文所说到的任务系统的相似性的含义之一。

深度卷积网络的可迁移性

还有一个值得探讨的问题在于,深度卷积网络的可迁移性在于什么呢?为什么说两个任务具有同等性质的输入旧具备可迁移性?一切都还得从卷积神经网络的基本原理说起。由之前的学习我们知道,卷积神经网络具备良好的层次结构,通常而言,普通的卷积神经网络都具备卷积-池化-卷积-池化-全连接这样的层次结构,在深度可观时,卷积神经网络可以提取图像各个 level 的特征。当我们要从图像中识别一张人脸的时候,通常在一开始我们会检测到图像的横的、竖的等边缘特征,然后会检测到脸部的一些曲线特征,再进一步会检测到脸部的鼻子、眼睛和嘴巴等具备明显识别要素的特征等等。

这便揭示了深度卷积网络可迁移性的基本原理和卷积网络训练过程的基本事实。具备良好层次的深度卷积网络通常都是在最初的前几层学习到图像的通用特征(general feature),但随着网络层次的加深,卷积网络便逐渐开始检测到图像的特定的特征,两个任务系统的输入越相近,深度卷积网络检测到的通用特征越多,迁移学习的效果越好。所以,这也引出了笔者最后一个问题,怎样在实际操作中使用迁移学习?

迁移学习的使用方法

通常而言,迁移学习有两种使用套路。第一种便是常说的 finetune,即微调,简单而言就是将别人训练好的网络拿来进行简单修改用于自己的学习任务。在实际操作中,通常用预训练的网络权值对自己网络的权值进行初始化,以代替原先的随机初始化。第二种称为 fixed feature extractor,即将预训练的网络作为新任务的特征提取器,在实际操作中通常将网络的前几层进行冻结,只训练最后的全连接层,这时候预训练网络便是一个特征提取器。

keras 为我们提供了经典网络在 ImageNet 上为我们训练好的预训练模型,预训练模型的基本信息如下表所示:

笔者以 VGG16 网络预训练为例对手写数字数据集 mnist 进行迁移学习任务,试验代码如下:

fromkeras.modelsimportModel

fromkeras.layersimportDense, Flatten, Dropout

fromkerasimportdatasets

fromkeras.applications.vgg16importVGG16

fromkeras.optimizersimportSGD

fromkeras.datasetsimportmnist

importnumpyasnpimportcv2

# 查看 VGG16 预训练模型的基本信息

model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(224,224,3))model = Flatten(name='Flatten')(model_vgg.output)model= Dense(10, activation='softmax')(model)model_vgg_mnist = Model(inputs=model_vgg.input, outputs=model, name='vgg16')model_vgg_mnist.summary()

冻结预训练模型的卷积和池化层,仅修改全连接层:

model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(224,224,3))

forlayersinmodel_vgg.layers: layers.trainable =Falsemodel = Flatten()(model_vgg.output)model = Dense(10, activation='softmax')(model)model_vgg_mnist_pretrain = Model(inputs=model_vgg.input, outputs=model, name='vgg16_pretrain')sgd = SGD(lr=0.05, decay=1e-5)model_vgg_mnist_pretrain.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

然后转换 mnist 训练数据的输入大小以适应 VGG16 的输入:

X_train = [cv2.cvtColor(cv2.resize(i, (img, img)), cv2.COLOR_GRAY2BGR)foriinX_train]X_train = np.concatenate([arr[np.newaxis]forarrinX_train]).astype('float32')X_test = [cv2.cvtColor(cv2.resize(i, (img, img)), cv2.COLOR_GRAY2BGR)foriinX_test ]X_test = np.concatenate([arr[np.newaxis]forarrinX_test] ).astype('float32')

稍加处理后进行训练:

X_train /= X_train/255X_test /= X_test/255

np.where(X_train[] !=)

deftrain_y(y): y_one = np.zeros(10) y_one[y] =1returny_oney_train_one = np.array([train_y(y_train[i])foriinrange(len(y_train))])y_test_one = np.array([train_y(y_test [i])foriinrange(len(y_test ))])model_vgg_mnist_pretrain.fit(X_train, y_train_one, validation_data=(X_test, y_test_one), epochs=10, batch_size=128)

限于笔者 windows 单机的计算能力,笔者仅进行了10轮的简单训练作为示例。当训练次数足够时,基于 VGG16 预训练网络在 mnist 数据集上的迁移学习效果会不错的。

参考资料:

https://www.deeplearning.ai/

王晋东 迁移学习手册

Yosinski J, Clune J, Bengio Y, et al. How transferable are features in deep neural networks?[J]. Eprint Arxiv, 2014, 27:3320-3328.

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180807B1UYUS00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券