DL开源框架Caffe | 模型微调 (finetune)的场景、问题、技巧以及解决方案

前言

什么是模型的微调?

  使用别人训练好的网络模型进行训练,前提是必须和别人用同一个网络,因为参数是根据网络而来的。当然最后一层是可以修改的,因为我们的数据可能并没有1000类,而只有几类。把最后一层的输出类别和层的名称改一下就可以了。用别人的参数、修改后的网络和自己的数据进行训练,使得参数适应自己的数据,这样一个过程,通常称之为微调(fine tuning).

微调时候网络参数是否更新?

  更新,finetune的过程相当于继续训练,跟直接训练的区别是初始化的时候:    a. 直接训练是按照网络定义指定的方式初始化(如高斯随机初始化)   b. finetune是用你已经有的参数文件来初始化(就是之前训练好的caffemodel)

**第一部分:Caffe命令行解析** —————

一、训练模型代码

  脚本:

./build/tools/caffe train -solver models/finetune/solver.prototxt -weights models/vgg_face_caffe/VGG_FACE.caffemodel -gpu 0

  BAT命令:  

..\..\bin\caffe.exe train --solver=.\solver.prototxt -weights .\test.caffemodel
pause

二、caffe命令全解析

http://www.cnblogs.com/denny402/p/5076285.html

第二部分:微调参数调整示例

一、各类模型finetune示例

Caffe finetune Resnet-50

http://blog.csdn.net/tangwenbo124/article/details/56070322

Caffe finetune googlenet

http://blog.csdn.net/sinat_30071459/article/details/51679995

Caffe finetune FCN

http://blog.csdn.net/zy3381/article/details/50458331

Caffe finetune Alexnet

二、参数调整注意

  • 首先修改名字,这样预训练模型赋值的时候这里就会因为名字不匹配从而重新训练,也就达成了我们适应新任务的目的;
  • 调整学习速率,因为最后一层是重新学习,因此需要有更快的学习速率相比较其他层,因此我们将,weight和bias的学习速率加快10倍,目的是让非微调层学习更快;
  • finetune时将最后的全连接层的名字全部修改,需要根据自己数据集的类别数重新设置fc8层的output数;
  • 数据集的类别号从0开始,中间要连续,否则会造成意外的错误
  • 数据集记得打乱,不然很可能不收敛;
  • 如果出现不收敛的问题,可以把solver里的lr设的小一点,一般从0.01开始,如果出现loss=nan了就不断往小调整;
  • 可以把accuracy和loss的曲线画出来,方便设定stepsize,一般在accuracy和loss都趋于平缓的时候就可以减小lr了;
  • finetune时应该用自己的数据集生成的均值文件(是否正确?);

第三部分:fine-tune的选择经验

  在fine-tune时,究竟该选择哪种方式的Transfer Learning?需要考虑的因素有许多,其中最重要的两条是新数据库的规模和它与预训练数据库的相似程度,根据这两条因素的不同配置,存在四种场景:   新数据库小,和预训练数据库相似。因为数据库比较小,fine-tune的话可能会产生过拟合,比较好的做法是用预训练的网络作为特征提取器,然后训练线性分类器用在新的任务上。   新数据库比较大,和预训练数据库相似。这种情况下,不用担心过拟合,可以放心地微调整个网络。   新数据库小,和预训练数据库不相似。这时,既不能微调,用预训练网络去掉最后一层作为特征提取器也不合适,可行的方案是用预训练网络的前面几层的激活值作为特征,然后训练线性分类器。   新数据库大,和预训练数据库不相似。这时可以从头开始训练,也可以在预训练的基础上进行微调。

  综述:做freeze操作时,通常还会根据数据集在不同情况进行有选择的性的finetune。如small datasets时,可以freeze前面conv layer-> fc4086来提取cnn在imagenet上的多类泛化特征来辅助作为分类的feature,再对如这边revise的fc-20->softmax进行training。以此类推,如果是medium datasets则freeze到一半的conv。个人理解这样做的很大原因在于lower level layer具有更强泛化的basic feature,同时记得考量你的数据来选择。

第四部分:如何针对上述不同的方式进行网络参数固定

比如有4个全连接层A->B->C->D:   a. 你希望C层的参数不会改变,C前面的AB层的参数也不会改变,这种情况也就是D层的梯度不往前反向传播到D层的输入blob(也就是C层的输出blob 没有得到梯度),你可以通过设置D层的lr_mult: 0,layer的梯度就不会反向传播啦,前面的所有layer的参数也就不会改变了。   b. 你希望C层的参数不会改变,但是C前面的AB层的参数会改变,这种情况,只是固定了C层的参数,C层得到的梯度依然会反向传播给前面的B层。只需要将对应的参数blob的学习率调整为0:   在layer里面加上param { lr_mult: 0 }就可以了,比如全连接层里面:

layer {
    type: "InnerProduct"
    param { # 对应第1个参数blob的配置,也就是全连接层的参数矩阵的配置
         lr_mult: 0 # 学习率为0,其他参数可以看caffe.proto里面的ParamSpec这个类型
    }
    param { # 对应第2个参数blob的配置,也就是全连接层的偏置项的配置
        lr_mult: 0 # 学习率为0
    }
}

第五部分:Caffe fine-tune常见问题

一、按照网上的教程微调alexnet,为什么loss一直是87.3365?

  解决办法:检查数据集的标签是否是从0开始,base_lr调低了一个数量级,batch_size调高一倍。   出现的原因:87.3365是个很特殊的数字,NAN经过SoftmaxWithLoss就产生了这个数字,所以就是你的FC8输出全是NAN;   具体分析: http://blog.csdn.net/jkfdqjjy/article/details/52268565?locationNum=14

二、Loss下降了,但是准确率没有明显变化?

  解决办法:训练前首先shuffle,其次学习率是否合适。

三、Data augmentation 的技巧总结:

转自小白在闭关 https://www.zhihu.com/question/35339639

图像亮度、饱和度、对比度的变化; PCA Jittering Random resize Random crop Horizontal/vertical filp 旋转仿射变换 加高斯噪声、模糊处理 Label shuffle:类别不平衡数据的扩增,参见海康威视ILSVRC2016的report

四、如何通过loss曲线判断网络训练的情况:

  单独的 loss 曲线能提供的信息很少的,一般会结合测试机上的 accuracy 曲线来判断是否过拟合;   关键是要看你在测试集上的acc如何;   如果你的 learning_rate_policy 是 step 或者其他变化类型的话, loss 曲线可以帮助你选择一个比较合适的 stepsize;

五、finetune_net.bin不能用之后,用新的方法做finetune会出问题,怎么解决?

  给最后那个InnerProduct层换个名字。

第六部分:参考资料

1.http://caffe.berkeleyvision.org/gathered/examples/finetune_flickr_style.html 2.https://www.zhihu.com/question/54775243 3.http://blog.csdn.net/u012526120/article/details/49496617 4.https://zhidao.baidu.com/question/363059557656952932.html

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏跟着阿笨一起玩NET

.NET3.5 GDI+ 图形操作1

      前言: 本文章抄袭自本人刚刚买的《ASP.NET 3.5从入门到精通》这本书,此书介绍在 http://www.china-pub.com/4499...

442
来自专栏ATYUN订阅号

【学术】实践教程:使用神经网络对犬种进行分类

几天前,我注意到由Kaggle主办的犬种识别挑战赛。我们的目标是建立一个模型,能够通过“观察”图像来进行犬种分类。我开始考虑可能的方法来建立一个模型来对犬种进行...

3595
来自专栏大数据杂谈

【Excel系列】Excel数据分析:时间序列预测

移动平均 18.1 移动平均工具的功能 “移动平均”分析工具可以基于特定的过去某段时期中变量的平均值,对未来值进行预测。移动平均值提供了由所有历史数据的简单的平...

3729
来自专栏新智元

PyTorch 最新版发布:API 变动,增加新特征,多项运算和加载速度提升

【新智元导读】PyTorch 发布了最新版,API 有一些变动,增加了一系列新的特征,多项运算或加载速度提升,而且修改了大量bug。官方文档也提供了一些示例。 ...

5147
来自专栏ATYUN订阅号

一个简单而强大的深度学习库—PyTorch

AiTechYun 编辑:yuxiangyu 每过一段时间,总会有一个python库被开发出来,改变深度学习领域。而PyTorch就是这样一个库。 在过去的几周...

4206
来自专栏人工智能LeadAI

宠物狗图片分类之迁移学习代码笔记

本文主要是总结之前零零散散抽出时间做的百度西交大狗狗图片分类竞赛题目 竞赛.目前本人已经彻底排到了50名后面,,,也没有想到什么办法去调优,并且平时也忙没时间再...

701
来自专栏挖坑填坑

Angular练习之animations动画二

引入动画模块>创建动画对象>在动画载体上使用。我觉得其核心的内容在创建动画对象上,今天我们就来练习创建不同的动画对象trigger

652
来自专栏专知

【专知国庆特刊-PyTorch手把手深度学习教程系列01】一文带你入门优雅的PyTorch

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视...

7417
来自专栏AI黑科技工具箱

1.试水:可定制的数据预处理与如此简单的数据增强(下)

上一部分我们讲了MXNet中NDArray模块实际上有很多可以继续玩的地方,不限于卷积,包括循环神经网络RNN、线性上采样、池化操作等,都可以直接用NDArra...

2853
来自专栏瓜大三哥

face++人脸识别

该系统主要分为: 1.数据库:500万张图片和2000个人,而且删除了LFW中特有的人,其分布如下(网上搜集的图片有一个长尾效应:就是随着图片数量的增加不利于网...

2789

扫码关注云+社区