前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >机器学习入门 8-5 学习曲线

机器学习入门 8-5 学习曲线

作者头像
触摸壹缕阳光
发布2019-12-26 14:10:49
1.2K0
发布2019-12-26 14:10:49
举报
文章被收录于专栏:AI机器学习与深度学习算法

本系列是《玩转机器学习教程》一个整理的视频笔记。上一小节介绍了模型复杂度曲线,通过这种直观的曲线,可以比较容易的看到模型欠拟合和过拟合的地方,进而选出最合适的模型复杂度。本小节介绍另外一个观察模型欠拟合和过拟合的曲线~"学习曲线"。

01

回顾模型复杂度曲线

先来回顾一下上一小节介绍的模型复杂度的曲线:

将数据集划分为训练数据集和测试数据集,其中训练数据集用于训练模型,而测试数据集用于评估模型的泛化能力,训练学习模型的目的是选出泛化能力最强的模型,而这一系列不同的模型是通过模型复杂度体现的,因此简单来说就是选择在测试集上准确率最高时候的模型复杂度。为了能够选择在测试集上准确率最高时候的模型,模型复杂度曲线就应运而生。

模型复杂度曲线是随着模型复杂度的上升,模型在训练数据集以及测试数据集相应的模型准确率就会有一定的变化,通过这种直观的模型复杂度曲线,可以比较容易的看到模型欠拟合以及过拟合的地方,进而找到对于我们的任务来说,最合适的模型复杂度所在的地方。

在上一小节中,提到模型复杂度曲线是一个理论的趋势,当处理不同的数据运用不同的模型时,有可能绘制不出这么清晰的模型复杂度曲线,而现在学的kNN算法和多项式回归算法就不太适合绘制这样的模型复杂度曲线,当然这些机器学习算法内在都是符合这样的曲线趋势,在后续介绍决策树的时候会绘制模型复杂度曲线,不过对于过拟合和欠拟合,如果还想使用可视化的方式看到过拟合和欠拟合,还有另外一个曲线,这就是本小节重点介绍的"学习曲线"。

2

学习曲线

学习曲线其实非常简单,可以想象一下,我们在学习知识的时候是不断的将新的内容放入我们的大脑中去消化理解,而对于模型来说,所谓的这些知识就是已知的样本信息,学习曲线描述的就是随着训练样本的逐渐增多,算法训练出的模型的表现能力。

接下来,通过具体的编程实践来绘制学习曲线。

  • Step1:创建数据集。
  • Step2:Train_Test_split默认将数据集划分为75%的训练数据集以及25%的测试数据集。
  • Step3:使用线性模型绘制学习曲线。

学习曲线其实就是对75个训练数据,从1开始每一次都多一个训练样本来训练一个全新的模型,据此来观察得到这个模型在训练数据集和测试数据集表现。在这里使用for循环,循环的范围从最极端的情况只指定一个样本进行训练到最多指定75个样本进行训练,在这75次循环中每一次都创建一个新的模型,使用当前循环的样本数对模型进行训练。

首先尝试使用线性模型来绘制学习曲线,从1到75的每一次循环中都会执行如下步骤,假设此时循环次数为i:

  1. 创建LinearRegression线性回归模型对象;
  2. 选择75个训练样本的前i个样本作为此次循环的训练样本;
  3. 对75个训练样本的前i个样本以及所有的测试样本进行预测;
  4. 计算75个训练样本的前i个样本以及所有的测试样本的均方误差,并将每一次的均方误差值分别保存到train_score以及test_score两个列表中,最终train_score和test_score列表长度都为75,表示的是线性模型随着进行训练的数据越来越多,相应得到的模型在训练数据集和测试数据集上性能的变化;

最后就可以把性能的变化绘制出来:

  1. 对于x轴来说就是每次循环进行训练的样本个数,从1到75;
  2. 对于y值就可以传入train_score,此时的train_score是均方误差,值相对来说比较大,需要将结果缩小一点,因此取均方根误差(RMSE),转换过程其实很简单,只需要使用np.sqrt对train_score列表开根即可;
  3. 由于绘制两根曲线,所以在绘制曲线的时候传入label参数对曲线进行标识,调用plt.legend()来显示图例;

上面绘制的曲线图就是对于创建的样本数据来说,使用线性回归模型得到的学习曲线。接下来就来分析线性回归模型的学习曲线:

  1. 先来看一看大体趋势:
    1. 从趋势上很明显在训练数据集上的误差是逐渐升高的,这也非常好理解,因为我们的训练数据越来越多(每一次循环都增加一个样本),训练样本点越多,越难拟合住所有的数据,因此相应的误差会逐渐的累计,不过整体而言,在刚开始的时候,误差的累计比较快,到了一定程度误差的累计其实是非常小的,此时是比较稳定的;
    2. 而对于测试数据集来说,呈现一种下滑的曲线趋势,也就是当我们使用非常少的样本进行训练的时候,刚开始测试误差非常的大,当训练样本多到一定程度的时候,测试误差就会逐渐的减小,减小到一定程度也不会小太多了,达到一种相对稳定的情况。
    3. 在最终的时候,训练误差和测试误差大体是在一个级别上的,不过测试误差还是要比训练误差高一些,这是因为训练数据拟合的过程,可以把训练数据集拟合的比较好,相应的误差会小一些,但是泛化到测试数据上的时候,误差还是可能会大一些,整体学习曲线呈现这样的趋势,为了方便后续对比其他的算法,将前面绘制学习曲线的过程提炼成一个函数。

为了验证封装函数的功能,创建线性回归模型调用绘制算法学习曲线的封装函数:

代码语言:javascript
复制
plot_learning_curve(LinearRegression(), X_train, X_test, y_train, y_test)

两次在相同数据集上绘制的线性回归学习曲线有所不同,主要是因为后续在比较的时候,会在意两根曲线之间的差距,为此在封装绘制学习曲线的函数中对坐标轴显示的范围进行了一定的限定。在本例中,只要将y轴限制在0-4这个范围就好了,知道在刚开始的时候测试数据集上的误差比较大,因此我们可以忽略刚开始测试误差值。

  • Step4:接下来绘制多项式回归的曲线,为了使用多项式回归,需要通过Pipeline管道创建多项式回归对象,使用前面小节封装创建多项式回归的函数。
  • 首先将多项式回归的degree值设置为2。

阶数为2的多项式回归学习曲线如下图所示。

上面就是使用二阶的多项式回归得到的学习曲线,仔细观察一下就会发现,这个学习曲线从整体的趋势来看和使用线性回归得到的学习曲线是一致的,

  1. train这根曲线逐渐上升,上升到一定程度后变得相对比较稳定;
  2. test这根曲线逐渐下降,下降到一定程度也变得比较稳定;

不过仔细观察就会发现,使用二阶多项式回归和线性回归绘制出的学习曲线最大的区别就在于,线性回归稳定的误差大约在1.6、1.7这个位置左右,而对于我们二阶的多项式回归学习曲线,误差稳定在1、0.9左右,二阶多项式回归的学习曲线稳定的位置比较低,这说明使用二阶多项式回归进行数据的拟合,结果比线性回归的拟合结果要好。

  • 将多项式回归的degree的值设置为20。

阶数为20的多项式回归学习曲线如下图所示。

和之前一样,整体的趋势是一致的:

  1. train逐渐上升,当上升到一定程度后相对稳定;
  2. test曲线逐渐下降,当下降到一定程度后相对稳定;

但是从整体上看,上面degree为20的学习曲线和之前的两个学习曲线也有巨大的区别,这个区别在于,train和test这两根曲线在相对比较稳定的时候,他们之间的间距依然是比较大的,这就说明了我们的模型在训练数据集上已经拟合的非常好了,但是在测试数据集上,相应的他的误差依然是很大的,离train的这根曲线比较远,这种情况通常就是过拟合的情况,也就是在训练数据集上表现的很好,但是有了新的数据在测试数据集上表现却不好,模型的泛化能力是不够的。

我们绘制了三种学习曲线,这三种学习曲线分别对应了欠拟合、正合适以及过拟合的情况。

接下来具体的总结比较一下这三张图:

欠拟合和最佳的情况相比较:

相应的train,test这两个曲线趋于稳定的位置,比最佳情况趋于稳定的位置要高一些,说明无论是对测试数据集来说还是训练数据集来说相应的误差都比较大,这是因为本身模型选择的就是不对的,所以即使在训练数据集上误差也是大的。

过拟合和最佳的情况比较:

在训练数据集上,相应的误差不大,和最佳情况下的误差是差不多的,甚至如果更极端一些,degree取值更高的话,训练数据集的误差会更低,但是问题在于,测试数据集的误差相对来说比较大,并且测试数据集的误差离训练数据集的误差比较远,它们之间稳定时候的误差差距比较大,这就说明了此时我们的模型的泛化能力不够好,对于新的数据来说,误差比较大。

在这一小节,通过另外一种学习曲线的方式进一步深刻的认识了什么叫做过拟合什么叫做欠拟合。这几个小节,一直使用的是train_test_split分离的方式来评估模型的泛化能力,但是使用这种方式其实还有一个小问题,在下一小节会说明这个小问题,进而采用一种更加严谨更加标准且相对而言更加费时的方式来评价模型的学习效果。

祝大家圣诞快乐

长按识别二维码关注:为您提供更多更好的知识

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档