前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >机器学习算法究竟需要试验多少次,才能有效反映模型性能?

机器学习算法究竟需要试验多少次,才能有效反映模型性能?

作者头像
AI研习社
发布2018-03-28 15:43:21
1.5K0
发布2018-03-28 15:43:21
举报
文章被收录于专栏:AI研习社AI研习社AI研习社

编者按:本文作者 Jason Brownlee 为澳大利亚知名机器学习专家,对时间序列预测尤有心得。原文发布于其博客。AI 研习社编译。文中相关链接详见文末“阅读原文”。

Jason Brownlee

许多随机机器学习算法存在同样的问题:相同的算法、相同的数据,得到的计算结果却每次都不同。这意味着在进行随机算法检验或者算法比较的时候,必须重复试验很多次,然后用它们的平均值来评价模型。

那么对于给定问题,随机机器学习算法需要试验多少次,才足以客观有效的反映模型性能?

一般建议重复30次以上甚至100次左右。有人甚至重复几千次,完全无视边际递减效应。

对于衡量随机机器学习算法性能所需的重复试验次数,在本教程中,我将教会大家如何用统计学方法来正确预估。

教程概述

本教程分以下4部分:

  1. 数据生成
  2. 基本分析
  3. 重复次数的影响分析
  4. 标准误差计算

本教程使用Python语言,版本 2或者3均可,为顺利运行示例代码,请务必安装SciPy 、NumPy、Pandas和Matplotlib库。

下面正式开始我们的教程

1. 数据生成

第一步是生成可用的数据。

假设我们将一个神经网络模型或其它随机算法,在数据的训练集上重复训练了1000次,并且记录了模型在测试集上的均方根误差(RMSE)。作为本教程后续分析的前提,假设我们所用的数据呈正态分布。

务必查看一下结果的分布形态,通常结果会呈高斯分布(即正态分布)。

我们会预先生成研究用的样本总体,这么做对后续研究非常有帮助,因为程序生成的样本总体其均值和标准差就确定下来,而这在实际应用中常常是无法得知的。

我们用均值=60,标准差=10作为参数生成试验数据。

下面是生成1000个随机数的代码,将结果保存为results.csv文件.

代码中我们用seed()作为随机数生成器种子函数,来确保每次运行代码后得到的数据都一致。使用normal()函数生成正态分布随机数,用savetxt()函数将数据保存为ASCII格式。

运行这段代码后,我们得到一个名为results.csv的文件,里面保存了1000个随机数,它们代表了随机算法重复运行的模拟结果。

下面是该文件的最后十行数据。

6.160564991742511864e+01 5.879850024371251038e+01 6.385602292344325548e+01 6.718290735754342791e+01 7.291188902850875309e+01 5.883555851728335995e+01 3.722702003339634302e+01 5.930375460544870947e+01 6.353870426882840405e+01 5.813044983467250404e+01

现在咱们先把如何得到这批数据的事放一边,继续往下进行。

2. 基本分析

得到样本总体之后,我们先对其进行简单的统计分析。

下面三种是非常简单有效的方法:

  1. 计算统计信息,比如均值、标准差和百分位数。
  2. 绘制箱线图来查看数据散布程度
  3. 绘制直方图来查看数据分布情况

通过下面的代码进行简单的统计分析,首先加载results.csv数据文件,然后进行统计计算,并绘图显示。

可以看出,算法的平均性能约为60.3,标准差约为9.8。

假定数据表示的是类似均方根误差一样的最小值,从统计结果看,最大值为99.5,而最小值为29.4。

下面的箱线图中展示了数据的散布程度,其中箱形部分是样本中段(上下四分位之间)数据(约占样本的50%),圆点代表异常值,绿线表示中位数。

由图可知,结果围绕中值分布合理。

最后生成的是数据的直方图,图中显示出了正态分布的贝尔曲线(钟形曲线),这意味着我们在进行数据分析工作时,可以使用标准的统计分析工具。

由图可知,数据以60为对称轴,左右几乎没有偏斜。

3. 重复次数的影响分析

之前我们生成了1000个结果数据。对于问题的研究来说可能多了,也可能不够。

该如何判断呢?

第一个想法就是画出试验重复次数和这些试验结果均值之间的曲线图。我们希望随着重复次数的增加,结果的均值能很快稳定。绘制成曲线后,看起来起始段波动较大且短,而中后部平稳且长。

利用下面的代码绘制出该曲线。

由图可以看出,前200次数据均值波动较大, 600次后,均值趋于稳定,曲线波动较小。

为了更好的观察曲线,将其放大,只显示前500次重复试验结果。

同时将1000次试验结果的均值线叠加上,以便找到两者之间的偏差关系。

图中橙色直线就是1000重复试验结果的均值线。

同时也能看到重复100次时,结果与均值较近,重复次数达到400时,结果更理想,但是提升不明显。

是不是很棒?不过会不会还有更好的办法呢?

4. 计算标准误差

标准误差用来计算样本均值偏离总体均值的多少。它和标准差不同,标准差描述了样本观察值的平均变化量。标准误差能够根据样本均值的误差量或者误差散布来估计总体均值。

标准误差可以通过下式计算:

standard_error = sample_standard_deviation / sqrt(number of repeats)

即标准误差等于样本的标准差除以重复次数的均方根。

我们希望标准误差会随着试验次数的增加而减小。通过下面的代码,计算每个重复试验次数对应的样本均值的标准误差,并绘制标准误差图。

运行代码后,会绘制出标准误差与重复次数的关系曲线。

和预期的一样,随着重复试验次数的增加,标准误差快速减小。标准误差下降到一定程度后,趋于稳定,通常把1~2个单位内的值,称为可接受误差。

标准误差的单位和样本数据的单位一致。

在上图中添加纵坐标为0.5和1的辅助线,帮助我们找到可接受的标准误差值。代码如下:

友情提醒,图中出现的两条红色辅助线,分别代表标准误差等于0.5和1。

由图可知,如果试验重复次数等于100次左右,标误差开始小于1,如果试验重复次数等于300~350次左右,标准误差小于0.5。随着重复试验次数的增加,标准误差趋于稳定,变化较小。再次提醒大家记住,标准误差可以衡量样本均值偏离总体均值的多少。

我们也可以使用标准误差来作为均值的置信区间。比如,用总体均值的95%作为置信区间的上下界。这种方法只适合试验重复次数大于20的情况。

置信区间定义如下:

样本均值 +/- (标准误差*1.96)

下面计算置信区间,并将其作为误差线添加到重复试验次数对应的样本均值上。这是计算代码。

下图创建了带置信区间的样本均值曲线。

其中红色直线表示总体的均值(在教程开始根据给定的均值和标准差生成了总体,所以总体的均值已知),重复1000次或更多后,可以用样本均值代替总体均值。

图中误差线包裹着均值线。而且样本均值夸大或高估了总体均值,不过还是落在总体均值的95%置信区间内。

95%置信区间的含义是做100次重复试验,有95次包含了总体均值的真值,另外5次没有包括。

图中可以看出,随着重复次数的增加,由于标准误差的减小,95%置信区间也逐渐变窄。

放大上图后,这种趋势在20到200之间时尤其明显。

这是由上述代码生成的样本均值和误差线随试验次数变化的曲线。此图能更好的反映样本均值与总体均值的偏差。

扩展阅读

实际上,既涉及使用随机算法的计算试验方法又涉及统计学的参考资料非常少。

我个人认为1995年科恩的书是两者结合最好的:

Empirical Methods for Artificial Intelligence(人工智能实证方法),Cohen(科恩),1995

如果你对这篇教程感兴趣,我强烈推荐此书。

另外,维基百科上还有几篇文章可能对你有帮助:

Standard Error

Confidence Interval

68–95–99.7 rule

如果你还有其他好的相关资料,可以在评论区与大家交流。谢谢。

小结

在这篇教程里,我们提供了一种合理选择试验重复次数的方法,这有助于我们评价随机机器学习算法的正确性。

下面是几种重复次数选择的方法:

  • 简单粗暴的直接用30、100或者1000次。
  • 绘制样本均值和重复次数的关系曲线,并根据拐点进行选择。
  • 绘制标准误差和重复次数的关系曲线,并根据误差阈值进行选择。
  • 绘制样本置信区间和重复次数的关系曲线,并根据误差散布进行选择。

延伸阅读:DeepMind 弹性权重巩固算法让 AI 拥有“记忆” ,将成机器高效学习的敲门砖

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-05-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 教程概述
  • 1. 数据生成
  • 2. 基本分析
  • 4. 计算标准误差
  • 扩展阅读
  • 小结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档