干货 | 大神支招:机器学习中用来防止过拟合的方法有哪些?

AI 科技评论按:本文作者 qqfly,上海交通大学机器人所博士生,本科毕业于清华大学机械工程系,主要研究方向机器视觉与运动规划。本文整理自知乎回答:机器学习中用来防止过拟合的方法有哪些?

给《机器视觉与应用》课程出大作业的时候,正好涉及到这方面内容,所以简单整理了一下(参考 Hinton 的课程)。按照之前的套路写:

是什么

过拟合(overfitting)是指在模型参数拟合过程中的问题,由于训练数据包含抽样误差,训练时,复杂的模型将抽样误差也考虑在内,将抽样误差也进行了很好的拟合。

具体表现就是最终模型在训练集上效果好;在测试集上效果差。模型泛化能力弱。

为什么

为什么要解决过拟合现象?这是因为我们拟合的模型一般是用来预测未知的结果(不在训练集内),过拟合虽然在训练集上效果好,但是在实际使用时(测试集)效果差。同时,在很多问题上,我们无法穷尽所有状态,不可能将所有情况都包含在训练集上。所以,必须要解决过拟合问题。

为什么在机器学习中比较常见?这是因为机器学习算法为了满足尽可能复杂的任务,其模型的拟合能力一般远远高于问题复杂度,也就是说,机器学习算法有「拟合出正确规则的前提下,进一步拟合噪声」的能力。

而传统的函数拟合问题(如机器人系统辨识),一般都是通过经验、物理、数学等推导出一个含参模型,模型复杂度确定了,只需要调整个别参数即可。模型「无多余能力」拟合噪声。

怎么样

既然过拟合这么讨厌,我们应该怎么防止过拟合呢?最近深度学习比较火,我就以神经网络为例吧:

1. 获取更多数据

这是解决过拟合最有效的方法,只要给足够多的数据,让模型「看见」尽可能多的「例外情况」,它就会不断修正自己,从而得到更好的结果:

如何获取更多数据,可以有以下几个方法:

  • 从数据源头获取更多数据:这个是容易想到的,例如物体分类,我就再多拍几张照片好了;但是,在很多情况下,大幅增加数据本身就不容易;另外,我们不清楚获取多少数据才算够;
  • 根据当前数据集估计数据分布参数,使用该分布产生更多数据:这个一般不用,因为估计分布参数的过程也会代入抽样误差。
  • 数据增强(Data Augmentation):通过一定规则扩充数据。如在物体分类问题里,物体在图像中的位置、姿态、尺度,整体图片明暗度等都不会影响分类结果。我们就可以通过图像平移、翻转、缩放、切割等手段将数据库成倍扩充;

2. 使用合适的模型

前面说了,过拟合主要是有两个原因造成的:数据太少 + 模型太复杂。所以,我们可以通过使用合适复杂度的模型来防止过拟合问题,让其足够拟合真正的规则,同时又不至于拟合太多抽样误差。

(PS:如果能通过物理、数学建模,确定模型复杂度,这是最好的方法,这也就是为什么深度学习这么火的现在,我还坚持说初学者要学掌握传统的建模方法。)

对于神经网络而言,我们可以从以下四个方面来限制网络能力

2.1 网络结构 Architecture

这个很好理解,减少网络的层数、神经元个数等均可以限制网络的拟合能力;

2.2 训练时间 Early stopping

对于每个神经元而言,其激活函数在不同区间的性能是不同的:

当网络权值较小时,神经元的激活函数工作在线性区,此时神经元的拟合能力较弱(类似线性神经元)。

有了上述共识之后,我们就可以解释为什么限制训练时间(early stopping)有用:因为我们在初始化网络的时候一般都是初始为较小的权值。训练时间越长,部分网络权值可能越大。如果我们在合适时间停止训练,就可以将网络的能力限制在一定范围内。

2.3 限制权值 Weight-decay,也叫正则化(regularization)

原理同上,但是这类方法直接将权值的大小加入到 Cost 里,在训练的时候限制权值变大。以 L2 regularization 为例:

训练过程需要降低整体的 Cost,这时候,一方面能降低实际输出与样本之间的误差C0,也能降低权值大小。

2.4 增加噪声 Noise

给网络加噪声也有很多方法:

2.4.1 在输入中加噪声:

噪声会随着网络传播,按照权值的平方放大,并传播到输出层,对误差 Cost 产生影响。推导直接看 Hinton 的 PPT 吧:

在输入中加高斯噪声,会在输出中生成

的干扰项。训练时,减小误差,同时也会对噪声产生的干扰项进行惩罚,达到减小权值的平方的目的,达到与 L2 regularization 类似的效果(对比公式)。

2.4.2 在权值上加噪声

在初始化网络的时候,用 0 均值的高斯分布作为初始化。Alex Graves 的手写识别 RNN 就是用了这个方法

Graves, Alex, et al. "A novel connectionist system for unconstrained handwriting recognition." IEEE transactions on pattern analysis and machine intelligence 31.5 (2009): 855-868.

- It may work better, especially in recurrent networks (Hinton)

2.4.3 对网络的响应加噪声

如在前向传播过程中,让默写神经元的输出变为 binary 或 random。显然,这种有点乱来的做法会打乱网络的训练过程,让训练更慢,但据 Hinton 说,在测试集上效果会有显著提升 (But it does significantly better on the test set!)。

3. 结合多种模型

简而言之,训练多个模型,以每个模型的平均输出作为结果。

从 N 个模型里随机选择一个作为输出的期望误差

,会比所有模型的平均输出的误差

(我不知道公式里的圆括号为什么显示不了)

大概基于这个原理,就可以有很多方法了:

3.1 Bagging

简单理解,就是分段函数的概念:用不同的模型拟合不同部分的训练集。以随机森林(Rand Forests)为例,就是训练了一堆互不关联的决策树。但由于训练神经网络本身就需要耗费较多自由,所以一般不单独使用神经网络做 Bagging。

3.2 Boosting

既然训练复杂神经网络比较慢,那我们就可以只使用简单的神经网络(层数、神经元数限制等)。通过训练一系列简单的神经网络,加权平均其输出。

3.3 Dropout

这是一个很高效的方法。

在训练时,每次随机(如 50% 概率)忽略隐层的某些节点;这样,我们相当于随机从 2^H 个模型中采样选择模型;同时,由于每个网络只见过一个训练数据(每次都是随机的新网络),所以类似 bagging 的做法,这就是我为什么将它分类到「结合多种模型」中;

此外,而不同模型之间权值共享(共同使用这 H 个神经元的连接权值),相当于一种权值正则方法,实际效果比 L2 regularization 更好。

4. 贝叶斯方法

这部分我还没有想好怎么才能讲得清楚,为了不误导初学者,我就先空着,以后如果想清楚了再更新。当然,这也是防止过拟合的一类重要方法。

综上:

原文发布于微信公众号 - AI科技评论(aitechtalk)

原文发表时间:2017-05-16

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏计算机视觉战队

每日一学——神经网络(下)

神经网络结构 灵活地组织层 将神经网络算法以神经元的形式图形化。神经网络被建模成神经元的集合,神经元之间以无环图的形式进行连接。也就是说,一些神经元的输出是另一...

3417
来自专栏目标检测和深度学习

深度 | 像玩乐高一样拆解Faster R-CNN:详解目标检测的实现过程

作者:Matt Simon 机器之心编译 本文详细解释了 Faster R-CNN 的网络架构和工作流,一步步带领读者理解目标检测的工作原理,作者本人也提供了...

3618
来自专栏ATYUN订阅号

27个问题测试你对逻辑回归的理解

逻辑回归可能是最常用的解决所有分类问题的算法。这里有27个问题专门测试你对逻辑回归的理解程度。 ? 1)判断对错:逻辑回归是一种有监督的机器学习算法吗? A)是...

4716
来自专栏企鹅号快讯

如何使用Keras集成多个卷积网络并实现共同预测

在统计学和机器学习领域,集成方法(ensemble method)使用多种学习算法以获得更好的预测性能(相比单独使用其中任何一种算法)。和统计力学中的统计集成(...

3699
来自专栏AI科技评论

干货 | 深度学习时代的目标检测算法

AI 科技评论按:本文作者 Ronald,首发于作者的知乎专栏「炼丹师备忘录」,AI 科技评论获其授权转发。 目前目标检测领域的深度学习方法主要分为两类:two...

5727
来自专栏量子位

CNN超参数优化和可视化技巧详解

王小新 编译自 Towards Data Science 量子位 出品 | 公众号 QbitAI 在深度学习中,有许多不同的深度网络结构,包括卷积神经网络(CN...

4614
来自专栏AI科技大本营的专栏

深度 | 机器学习中的模型评价、模型选择及算法选择

作者:Sebastian Raschka 翻译:reason_W 编辑:周翔 简介 正确使用模型评估、模型选择和算法选择技术无论是对机器学习学术研究还是工业场景...

5494
来自专栏AI星球

吾爱NLP(2)--解析深度学习中的激活函数

由惑而生,所以我打算总结一下深度学习模型中常用的激活函数的一些特性,方便大家日后为模型选择合适的激活函数。   说到激活函数,就不能不提神经网络或者深度学习,...

2122
来自专栏目标检测和深度学习

深度学习时代的目标检测综述

1581
来自专栏机器之心

解读 | 如何从信号分析角度理解卷积神经网络的复杂机制?

机器之心原创 作者:Qintong Wu 参与:Jane W 随着复杂和高效的神经网络架构的出现,卷积神经网络(CNN)的性能已经优于传统的数字图像处理方法,如...

2858

扫码关注云+社区