fine-tuning的二三事

日常的应用中,我们会很经常遇到一个问题:

如何应用强大的model(比如ResNet)去训练我们自己的数据?

考虑到这样的几个事实:

  1. 通常我们自己的数据集都不会大(<1w)
  2. 从头开始训练耗时

解决方法就是fine-tuning.

方式

参考CS231的资料,有三种方式

  • ConvNet as fixed feature extractor. 其实这里有两种做法: (1) 使用最后一个fc layer之前的fc layer获得的特征,学习个线性分类器(比如SVM) (2) 重新训练最后一个fc layer
  • Fine-tuning the ConvNet. 固定前几层的参数,只对最后几层进行fine-tuning
  • Pretrained models. 这个其实和第二种是一个意思,不过比较极端,使用整个pre-trained的model作为初始化,然后fine-tuning整个网络而不是某些层

选择

考虑两个问题:

  • 你的数据集大小
  • 你的数据集和ImageNet(假设在ImageNet上训练的)的相似性

分为四种情况,解决方法基于的原则就是:

NN中的低层特征是比较generic的,比如说线、边缘的信息,高层特征是Dataset Specific的,基于此,如果你的数据集和ImageNet差异比较大,这个时候你应该尽可能的少用pre-trained model的高层特征.

  1. 数据集小(比如<5000),相似度高 这是最常见的情况,可以仅重新训练最后一层(fc layer)
  2. 数据集大(比如>10000),相似度高 fine-tuning后几层,保持前面几层不变或者干脆直接使用pre-trained model作为初始化,fine-tuning整个网络
  3. 数据集小,相似度低 小数据集没有办法进行多层或者整个网络的fine-tuning,建议保持前几层不动,fine-tuning后几层(效果可能也不会很好)
  4. 数据集大,相似度低 虽然相似度低,但是数据集大,可以和2一样处理

从上面我们可以看出,数据集大有优势,否则最好是数据集和原始的相似度比较高;如果出现数据集小同时相似度低的情况,这个时候去fine-tuning后几层未必会有比较好的效果.

Caffe中如何进行fine-tuning

Caffe做fine-tuning相对tensorflow很简单,只需要简单修改下配置文件就行了.

此处假设你的数据集比较小,同时相似度比较高,仅需重新训练最后一层(fc)的情况.

(1) 降低solver中lr和stepsize

这个很明显,因为相似度比较高我们可以期望原始获得的feature和需要的是很接近的,此时需要降低学习率(lr)和迭代次数(stepsize).

(2) 修改最后一层fc的名字,设置好lr_mult

应为需要训练最后一层,我们把之前的层的学习率设置的很低(比如0.001),或者你干脆设置为0,最后一层设置一定的学习率(比如0.01),所以需要乘以10.

(3) 训练

其实就已经改好了,是不是很简单,按照之前标准化的训练测试就好了

知乎上fine-tuning的介绍上有更加详细的介绍,可以移步去看.

参考

(1) NodYoung的博客

(2) CS231的transfer-learning

(3) 知乎上关于caffe下做fine-tuning的介绍

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏null的专栏

简单易学的机器学习算法——在线顺序极限学习机OS-ELM

   这篇文章主要是前面整理的,就直接上图了。 ? ? ? ? ? ? ? 实验: # coding:UTF-8 ################# # OS_...

7815
来自专栏专知

100+中文词向量,总有一款适合你

1724
来自专栏求索之路

cs231n之KNN算法

1.环境搭建以及前置条件 1.前置环境: 1.mac 2.pycharm 3.python3 4.Anaconda 2.环境搭建: 1.官网下载并安装Ana...

2959
来自专栏数说工作室

文本相似比较

大家好,我是数说君,这篇文章是想跟大家讨教一下。 如果有两段简单文本,如何比较它们的相似度?这里我们就假设是英文,不存在中文的分词问题,文本就类似于: text...

34614
来自专栏深度学习之tensorflow实战篇

ggolot2 画ROC曲线

为了进一步了解ggplot2的使用,利用ROC曲线进行说明学习。 ####获取画图数据(data.frame格式)##### library(ggplot2) ...

2896
来自专栏AI黑科技工具箱

1.试水:可定制的数据预处理与如此简单的数据增强(下)

上一部分我们讲了MXNet中NDArray模块实际上有很多可以继续玩的地方,不限于卷积,包括循环神经网络RNN、线性上采样、池化操作等,都可以直接用NDArra...

2243
来自专栏码云1024

NumPy 介绍与安装

NumPy 是一个 Python 包。 它代表 “Numeric Python”。 它是一个由多维数组对象和用于处理数组的例程集合组成的库。

3385
来自专栏TensorFlow从0到N

TensorFlow从0到1 - 12 - TensorFlow构建3层NN玩转MNIST

上一篇 11 74行Python实现手写体数字识别展示了74行Python代码完成MNIST手写体数字识别,识别率轻松达到95%。这算不上一个好成绩,不过我并...

4205
来自专栏杨建荣的学习笔记

Python之火,可以燎原

Python的优势之一是简洁。同样的功能,Python代码往往只有C、C++和Java代码的1/5-1/3。比如,实现一个Hello World!, Pytho...

582
来自专栏人工智能

将图像转换位mnist数据格式

利用mnist数据对数字符号进行识别基本上算是深度学习的Hello World了。在我学习这个“hello world”的过程中,总感觉缺点什么,于是发现无论是...

26110

扫码关注云+社区