机器学习:你需要多少训练数据?

从谷歌的机器学习代码中得知,目前需要一万亿个训练样本

训练数据的特性和数量是决定一个模型性能好坏的最主要因素。一旦你对一个模型输入比较全面的训练数据,通常针对这些训练数据,模型也会产生相应的结果。但是,问题是你需要多少训练数据合适呢?这恰恰取决于你正在执行的任务、最终想通过模型实现的性能、现有的输入特征、训练数据中含有的噪声、已经提取的特征中含有的噪声以及模型的复杂性等等诸多因素。所以,发现所有这些变量相互之间有何联系,如何工作的方法即是通过在数量不一的训练样本上训练模型,并且绘制出模型关于各个训练样本集的学习曲线图。你必须已经具有特性比较明显、数量适合的训练数据,才能通过模型的训练学习出感兴趣、性能比较突出的学习曲线图。要实现上述的目的,你不禁会问,当你刚刚着手训练一个模型的时候,你应该怎样做,或者是在你训练模型的过程中,你什么时候能察觉到模型的训练数据过少,并且想要估量出在整个模型训练过程中存在什么样的问题。

所以,针对上述这些问题,代替绝对精准的回答,我们给出一个推测出的,比较实用的拇指规则。其大致过程是:自动生成大量的关于逻辑回归的问题。对于每个生成的逻辑回归问题,学习出训练样本的数量与训练模型性能之间的存在的某种关系。基于一系列的问题观察训练样本的数量与训练模型性能之间的联系,从而得到一个简单的规则——拇指规则。

我不能确定我的模型需要多少训练样本,我将建立一个模型来推测出所需训练样本的数量

这里是生成一系列关于逻辑回归问题和研究基于数量渐变的训练样本在模型上训练效果的代码。通过调用谷歌的开源工具箱Tensorflow执行代码。代码的运行过程中不需要应用到任何软件和硬件,并且我能够在我的笔记本上运行整个实验。随着代码的运行,会得到下面的学习曲线图,如图(1)所示

图(1)中,x轴表示训练样本数量与模型参数数量的比值。y轴是模型的f-score值。图中不同颜色的曲线对应于带有不同参数数量的训练模型。例如,红色曲线表示一个具有128个参数的模型随着训练样本数量128 X 1,128 X 2 等等这样变化时,f-score值的变化情况。

得到的第一个观察结果即是:f-score值不随着参数尺度的变化而变化。通过这一观察结果,我们可以认为给定的模型是线性的,并且令人高兴的是模型中的一些隐含层没有混入非线性。当然,更大的模型需要更多的训练样本,但是若训练样本数量与模型参数数量的比值是给定的,你会获得相同的模型性能。第二个观察结果即是:当训练样本数量与模型参数数量之比为10:1时,f-score值在0.85上下浮动,我们可以称此时的训练模型是一个具有良好性能的模型。通过以上的观察结果可以得出一个10倍规则法——即是要训练出一个性能良好的模型,所需训练样本数量应是模型参数数量的10倍。

因而,借由10倍规则法,将估量训练样本数量的问题转换为只要知道模型中参数数量就可以训练出一个性能良好的模型问题。基于这一点这引发了一些争论:

(1)对于线性模型 ,例如逻辑回归模型。基于每个特征,模型都分配了相应的参数,从而使得参数的数量和输入的特征数量相等,然而这里可能存在一些问题:你的特征可能是稀疏的,所以,计数的特征数量并不是直接的。

译者注:我觉得这句话的意思是,稀疏特征,例如稀疏特征的编码是01001001对于模型的训练能够起到作用的特征是少数的,而不起作用的特征占大多数。依照上述线性规则,若模型对于每个特征分配相应的参数,也就是说对于无用的特征也分配了相应的参数,再根据10倍规则法,获取是模型参数数量10倍的训练样本集,此时的训练样本数量对于最佳的训练模型来说可能是超量的,所以,此时用10倍规则法得到的训练样本集未必能够真实地得出好的训练模型。

(2)由于规范化和特征选择技术,训练模型中真实输入的特征的数量少于原始特征数量。

译者注:我觉得这两点即是在解释上述利用10倍规则法来得到性能良好模型的理论是有一定的局限性,这个理论是相对于输入特征是完备且每个特征对于模型的训练都有一定的贡献程度的。但是对于(1)、(2)这两点所说的稀疏特征和特征降维技术的使用导致特征数量减少的情况,利用10倍规则法能够得到性能良好的模型的理论还有待进一步讨论。

解决上述(1)、(2)问题的一个办法即是:在提取特征时,你不仅要用到有类别标签的数据还要用到不带类别标签的数据来估计特征的数量。例如给定一个文本语料库,在标记数据进行训练之前,你可以通过统计每个单词出现的次数,来生成一个关于单词频率直方图,以此来理解你的特征空间。根据单词频率直方图,你可以去掉长尾词,来获得真实的、主要的特征数量,之后你可以运用10倍规则法来估测在得到性能良好的模型时,你所需要的训练样本数量。

与像逻辑回归这样的线性模型相比,神经网络模型提出了一组不同的问题。为了得到神经网络中参数的数量你需要:

(1)如果你的输入特征是稀疏的,计算嵌入层中(我觉得就是隐含层)参数数量。

(2)计算神经网络模型中的边数。

根本问题是在神经网络中参数之间的关系不再是线性的。所以基于逻辑回归模型的学习经验总结不再适用于神经网络模型。在像诸如神经网络这样的模型中,你可以将基于10倍规则法获取的训练样本数量作为在模型训练中输入的训练样本量的一个下界。

译者注:是在神经网络这样非线性模型中,要想获得一个性能良好的训练模型,所需训练数据最少为模型参数的10倍,实际上所需的训练数据应该比这个还多。

尽管会存在以上的争论,但是我的10倍规则法在大多数问题中都起到了作用。然而,带着对10倍规则法的疑问,你可以在开源工具箱Tensorflow中插入你自己的模型以及对训练样本的数量和模型的训练效果之间的关系作出一个假设,并通过模拟实验来研究模型的训练效果,如果你在运行过程中有了其它见解,欢迎随时分享。

它是一个简单的规则,但是有时候它是一个模型

译者总结

这篇文献主要是探讨了如何通过设置合理的训练样本量来得到一个性能良好的模型。作者在这里向我们介绍了一种可以合理设置训练样本量的10倍规则法——即是训练样本数量是模型参数数量的10倍。以此为基础,引出了两个特例:线性模型如逻辑回归模型、神经网络模型,来得到利用这种方法进行模型训练的过程中可能产生的困惑或者不适用的情况,并且针对逻辑回归线性模型和神经网络模型如何进行改进以及怎么结合10倍规则法获得一个性能良好的训练模型给出了相应的建议。

在我平常所做的模型训练的实验中,我曾经也经常遇到不知如何选取训练样本数量的问题,根据读过的论文的经验,来设置训练数据的数量,不断进行尝试,之前并不知道有这种方法的存在,看了这篇论文获得了一定的启发,训练数据的多少以及特征的贡献程度对一个模型进行分类或者回归至关重要。

额外补充

关于F-score值的介绍

准确率与召回率(Precision&Recall)

准确率和召回率是广泛用于信息检索和统计学分类领域的两个度量值,用来评价结果的质量。

其中精度是检索出相关文档数与检索出文档总数的比率,衡量的是检索系统的查准率;

召回率是指检索出的相关文档数和文库中所有的相关文档数的比率,衡量的是检索系统的查全率。

一般来说,Precision就是检索出来的条目(比如:文档、网页等)有多少是准确的,Recall就是所有准确的条目有多少被检索出来来。

正确率、召回率和F值是在众多训练模型中选出目标的重要指标。

1. 正确率=提取出的正确信息条数/提取出的信息条数

2. 召回率=提取出的正确信息条数/样本中的信息条数

两者取值在0和1之间,越接近数值1,查准率或查全率就越高。

3. F值=正确率*召回率*2/(正确率+召回率)

即F值即是正确率与召回率的平均值,且F值越好,说明模型的性能越好。

关于google开源工具箱Tensorflow

Tensorflow是一个基于流行数据进行数值计算的开源库,类似于我们在进行SVM训练时用的libSVM工具箱一样。

原文发布于微信公众号 - 新智元(AI_era)

原文发表时间:2015-12-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

资源 | 源自斯坦福CS229,机器学习备忘录在集结

项目地址:https://github.com/afshinea/stanford-cs-229-machine-learning

631
来自专栏云时之间

对交叉验证的一些补充(转)

交叉验证是一种用来评价一个统计分析的结果是否可以推广到一个独立的数据集上的技术。主要用于预测,即,想要估计一个预测模型的实际应用中的准确度。它是一种统计学上将数...

4089
来自专栏AI科技大本营的专栏

阿里团队最新实践:如何解决大规模分类问题?

【AI科技大本营导读】近年来,深度学习已成为机器学习社区的一个主要研究领域。其中一个主要挑战是这种深层网络模型的结构通常很复杂。对于一般的多类别分类任务,所需的...

691
来自专栏机器之心

无需深度学习框架,如何从零开始用Python构建神经网络

1495
来自专栏机器学习算法与理论

浅谈神经网络

一、神经网络介绍 神经网络是由具有适应性的简单单元组成的广泛并行互联的网络,它的组织能够模拟生物神经系统对真实世界物体作出的交互反应。 神经网络中最基本的成分...

3409
来自专栏决胜机器学习

深层神经网络参数调优(三) ——mini-batch梯度下降与指数加权平均

深层神经网络参数调优(三)——mini-batch梯度下降与指数加权平均 (原创内容,转载请注明来源,谢谢) 一、mini-batch梯度下降 1、概述 之前...

3533

为什么不提倡在训练集上检验模型?

在你开始接触机器学习时,通常你会从读取一个数据集并尝试使用不同的模型开始。你可能会疑惑,为什么不用数据集中的所有数据来训练及评估模型呢?

3327
来自专栏CVer

CS229 机器学习速查表

本文经机器之心(微信公众号:almosthuman2014)授权转载,禁止二次转载

781
来自专栏Spark学习技巧

机器学习之学习率 Learning Rate

1142
来自专栏PPV课数据科学社区

机器学习测试题(上)

人工智能一直助力着科技发展,新兴的机器学习正推动着各领域的进步。如今,机器学习的方法已经无处不在—从手机上的语音助手到商业网站的推荐系统,机器学习正以不容忽视...

29212

扫码关注云+社区