机器学习:你需要多少训练数据?

从谷歌的机器学习代码中得知,目前需要一万亿个训练样本

训练数据的特性和数量是决定一个模型性能好坏的最主要因素。一旦你对一个模型输入比较全面的训练数据,通常针对这些训练数据,模型也会产生相应的结果。但是,问题是你需要多少训练数据合适呢?这恰恰取决于你正在执行的任务、最终想通过模型实现的性能、现有的输入特征、训练数据中含有的噪声、已经提取的特征中含有的噪声以及模型的复杂性等等诸多因素。所以,发现所有这些变量相互之间有何联系,如何工作的方法即是通过在数量不一的训练样本上训练模型,并且绘制出模型关于各个训练样本集的学习曲线图。你必须已经具有特性比较明显、数量适合的训练数据,才能通过模型的训练学习出感兴趣、性能比较突出的学习曲线图。要实现上述的目的,你不禁会问,当你刚刚着手训练一个模型的时候,你应该怎样做,或者是在你训练模型的过程中,你什么时候能察觉到模型的训练数据过少,并且想要估量出在整个模型训练过程中存在什么样的问题。

所以,针对上述这些问题,代替绝对精准的回答,我们给出一个推测出的,比较实用的拇指规则。其大致过程是:自动生成大量的关于逻辑回归的问题。对于每个生成的逻辑回归问题,学习出训练样本的数量与训练模型性能之间的存在的某种关系。基于一系列的问题观察训练样本的数量与训练模型性能之间的联系,从而得到一个简单的规则——拇指规则。

我不能确定我的模型需要多少训练样本,我将建立一个模型来推测出所需训练样本的数量

这里是生成一系列关于逻辑回归问题和研究基于数量渐变的训练样本在模型上训练效果的代码。通过调用谷歌的开源工具箱Tensorflow执行代码。代码的运行过程中不需要应用到任何软件和硬件,并且我能够在我的笔记本上运行整个实验。随着代码的运行,会得到下面的学习曲线图,如图(1)所示

图(1)中,x轴表示训练样本数量与模型参数数量的比值。y轴是模型的f-score值。图中不同颜色的曲线对应于带有不同参数数量的训练模型。例如,红色曲线表示一个具有128个参数的模型随着训练样本数量128 X 1,128 X 2 等等这样变化时,f-score值的变化情况。

得到的第一个观察结果即是:f-score值不随着参数尺度的变化而变化。通过这一观察结果,我们可以认为给定的模型是线性的,并且令人高兴的是模型中的一些隐含层没有混入非线性。当然,更大的模型需要更多的训练样本,但是若训练样本数量与模型参数数量的比值是给定的,你会获得相同的模型性能。第二个观察结果即是:当训练样本数量与模型参数数量之比为10:1时,f-score值在0.85上下浮动,我们可以称此时的训练模型是一个具有良好性能的模型。通过以上的观察结果可以得出一个10倍规则法——即是要训练出一个性能良好的模型,所需训练样本数量应是模型参数数量的10倍。

因而,借由10倍规则法,将估量训练样本数量的问题转换为只要知道模型中参数数量就可以训练出一个性能良好的模型问题。基于这一点这引发了一些争论:

(1)对于线性模型 ,例如逻辑回归模型。基于每个特征,模型都分配了相应的参数,从而使得参数的数量和输入的特征数量相等,然而这里可能存在一些问题:你的特征可能是稀疏的,所以,计数的特征数量并不是直接的。

译者注:我觉得这句话的意思是,稀疏特征,例如稀疏特征的编码是01001001对于模型的训练能够起到作用的特征是少数的,而不起作用的特征占大多数。依照上述线性规则,若模型对于每个特征分配相应的参数,也就是说对于无用的特征也分配了相应的参数,再根据10倍规则法,获取是模型参数数量10倍的训练样本集,此时的训练样本数量对于最佳的训练模型来说可能是超量的,所以,此时用10倍规则法得到的训练样本集未必能够真实地得出好的训练模型。

(2)由于规范化和特征选择技术,训练模型中真实输入的特征的数量少于原始特征数量。

译者注:我觉得这两点即是在解释上述利用10倍规则法来得到性能良好模型的理论是有一定的局限性,这个理论是相对于输入特征是完备且每个特征对于模型的训练都有一定的贡献程度的。但是对于(1)、(2)这两点所说的稀疏特征和特征降维技术的使用导致特征数量减少的情况,利用10倍规则法能够得到性能良好的模型的理论还有待进一步讨论。

解决上述(1)、(2)问题的一个办法即是:在提取特征时,你不仅要用到有类别标签的数据还要用到不带类别标签的数据来估计特征的数量。例如给定一个文本语料库,在标记数据进行训练之前,你可以通过统计每个单词出现的次数,来生成一个关于单词频率直方图,以此来理解你的特征空间。根据单词频率直方图,你可以去掉长尾词,来获得真实的、主要的特征数量,之后你可以运用10倍规则法来估测在得到性能良好的模型时,你所需要的训练样本数量。

与像逻辑回归这样的线性模型相比,神经网络模型提出了一组不同的问题。为了得到神经网络中参数的数量你需要:

(1)如果你的输入特征是稀疏的,计算嵌入层中(我觉得就是隐含层)参数数量。

(2)计算神经网络模型中的边数。

根本问题是在神经网络中参数之间的关系不再是线性的。所以基于逻辑回归模型的学习经验总结不再适用于神经网络模型。在像诸如神经网络这样的模型中,你可以将基于10倍规则法获取的训练样本数量作为在模型训练中输入的训练样本量的一个下界。

译者注:是在神经网络这样非线性模型中,要想获得一个性能良好的训练模型,所需训练数据最少为模型参数的10倍,实际上所需的训练数据应该比这个还多。

尽管会存在以上的争论,但是我的10倍规则法在大多数问题中都起到了作用。然而,带着对10倍规则法的疑问,你可以在开源工具箱Tensorflow中插入你自己的模型以及对训练样本的数量和模型的训练效果之间的关系作出一个假设,并通过模拟实验来研究模型的训练效果,如果你在运行过程中有了其它见解,欢迎随时分享。

它是一个简单的规则,但是有时候它是一个模型

译者总结

这篇文献主要是探讨了如何通过设置合理的训练样本量来得到一个性能良好的模型。作者在这里向我们介绍了一种可以合理设置训练样本量的10倍规则法——即是训练样本数量是模型参数数量的10倍。以此为基础,引出了两个特例:线性模型如逻辑回归模型、神经网络模型,来得到利用这种方法进行模型训练的过程中可能产生的困惑或者不适用的情况,并且针对逻辑回归线性模型和神经网络模型如何进行改进以及怎么结合10倍规则法获得一个性能良好的训练模型给出了相应的建议。

在我平常所做的模型训练的实验中,我曾经也经常遇到不知如何选取训练样本数量的问题,根据读过的论文的经验,来设置训练数据的数量,不断进行尝试,之前并不知道有这种方法的存在,看了这篇论文获得了一定的启发,训练数据的多少以及特征的贡献程度对一个模型进行分类或者回归至关重要。

额外补充

关于F-score值的介绍

准确率与召回率(Precision&Recall)

准确率和召回率是广泛用于信息检索和统计学分类领域的两个度量值,用来评价结果的质量。

其中精度是检索出相关文档数与检索出文档总数的比率,衡量的是检索系统的查准率;

召回率是指检索出的相关文档数和文库中所有的相关文档数的比率,衡量的是检索系统的查全率。

一般来说,Precision就是检索出来的条目(比如:文档、网页等)有多少是准确的,Recall就是所有准确的条目有多少被检索出来来。

正确率、召回率和F值是在众多训练模型中选出目标的重要指标。

1. 正确率=提取出的正确信息条数/提取出的信息条数

2. 召回率=提取出的正确信息条数/样本中的信息条数

两者取值在0和1之间,越接近数值1,查准率或查全率就越高。

3. F值=正确率*召回率*2/(正确率+召回率)

即F值即是正确率与召回率的平均值,且F值越好,说明模型的性能越好。

关于google开源工具箱Tensorflow

Tensorflow是一个基于流行数据进行数值计算的开源库,类似于我们在进行SVM训练时用的libSVM工具箱一样。

原文发布于微信公众号 - 新智元(AI_era)

原文发表时间:2015-12-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

前Twitter资深工程师详解YOLO 2与YOLO 9000目标检测系统

AI研习社按:YOLO是Joseph Redmon和Ali Farhadi等人于2015年提出的第一个基于单个神经网络的目标检测系统。在今年CVPR上,Jose...

3766
来自专栏人工智能头条

从CNN视角看在自然语言处理上的应用

1473
来自专栏专知

【干货】监督学习与无监督学习简介

【导读】本文是一篇入门级的概念介绍文章,主要带大家了解一下监督学习和无监督学习,理解这两类机器学习算法的不同,以及偏差和方差详细阐述。这两类方法是机器学习领域中...

3668
来自专栏CDA数据分析师

机器学习新手必看十大算法

编译 机器之心 原文链接:https://towardsdatascience.com/a-tour-of-the-top-10-algorithms-for...

4136
来自专栏机器之心

教程 | 单级式目标检测方法概述:YOLO与SSD

在这篇文章中,我将概述用于基于卷积神经网络(CNN)的目标检测的深度学习技术。目标检测是很有价值的,可用于理解图像内容、描述图像中的事物以及确定目标在图像中的位...

591
来自专栏机器之心

教程 | 拟合目标函数后验分布的调参利器:贝叶斯优化

3465
来自专栏机器学习算法与Python学习

机器学习(23)之GBDT详解

关键字全网搜索最新排名 【机器学习算法】:排名第一 【机器学习】:排名第一 【Python】:排名第三 【算法】:排名第四 前言 在(机器学习(20)之Adab...

3257
来自专栏GAN&CV

详解机器学习中的梯度消失、爆炸原因及其解决方法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/d...

1064
来自专栏深度学习自然语言处理

【深度学习】你不了解的细节问题(四)

方法:我们生成两个 12 维高斯混合。高斯具有相同的协方差矩阵,但在每个维度都有一个由 1 隔开的均值。该数据集由 500 个高斯组成,其中 400 个用于训练...

935
来自专栏机器人网

卷积神经网络概念与原理

受Hubel和Wiesel对猫视觉皮层电生理研究启发,有人提出卷积神经网络(CNN),Yann Lecun 最早将CNN用于手写数字识别并一直保持...

782

扫描关注云+社区