推荐|Kaggle机器学习之模型融合(stacking)心得

此文道出了本人学习Stacking入门级应用的心路历程。

在经过了几天漫长的查询资料和整理,脑子不好,理解顿悟花了不少时间。在学习过程中感谢@贝尔塔的模型融合方法,以及这篇文章(作者是章凌豪)。对于两位提供的信息,感激不尽。同时还有Kaggle上一些关于ensemble的文章和代码,比如这篇。

本文适用于被stacking折磨的死去活来的新手,在网上为数不多的stacking内容里,我已经假设你早已经看过了上述所提到的那几篇有用的文章了。但是,看完之后内心还是卧槽的。我希望下面的内容能成为,你在学习stacking的曲折道路上的一个小火把,给你提供一些微弱的光亮。

本文以Kaggle的Titanic(泰坦尼克预测)入门比赛来讲解stacking的应用(两层!)。

数据的行数:train.csv有890行,也就是890个人,test.csv有418行(418个人)。

而数据的列数就看你保留了多少个feature了,因人而异。我自己的train保留了 7+1(1是预测列)。

在网上为数不多的stacking内容里,相信你早看过了这张图:

这张图,如果你能一下子就能看懂,那就OK。

如果一下子看不懂,就麻烦了,在接下来的一段时间内,你就会卧槽卧槽地持续懵逼......

因为这张图极具‘误导性’。(注意!我没说这图是错的,尽管它就是错的!!!但是在网上为数不多教学里有张无码图就不错啦,感恩吧,我这个小弱鸡)。

我把图改了一下:

对于每一轮的 5-fold,Model 1都要做满5次的训练和预测。

Titanic 栗子:

Train Data有890行。(请对应图中的上层部分)

每1次的fold,都会生成 713行 小train, 178行 小test。我们用Model 1来训练 713行的小train,然后预测 178行 小test。预测的结果是长度为 178 的预测值。

这样的动作走5次! 长度为178 的预测值 X 5 = 890 预测值,刚好和Train data长度吻合。这个890预测值是Model 1产生的,我们先存着,因为,一会让它将是第二层模型的训练来源。

重点:这一步产生的预测值我们可以转成 890 X 1 (890 行,1列),记作 P1 (大写P)

接着说 Test Data 有 418 行。(请对应图中的下层部分,对对对,绿绿的那些框框)

每1次的fold,713行 小train训练出来的Model 1要去预测我们全部的Test Data(全部!因为Test Data没有加入5-fold,所以每次都是全部!)。此时,Model 1的预测结果是长度为418的预测值。

这样的动作走5次!我们可以得到一个 5 X 418 的预测值矩阵。然后我们根据行来就平均值,最后得到一个 1 X 418 的平均预测值。

重点:这一步产生的预测值我们可以转成 418 X 1 (418行,1列),记作 p1 (小写p)

走到这里,你的第一层的Model 1完成了它的使命。

第一层还会有其他Model的,比如Model 2,同样的走一遍, 我们有可以得到 890 X 1 (P2) 和 418 X 1 (p2) 列预测值。

这样吧,假设你第一层有3个模型,这样你就会得到:

来自5-fold的预测值矩阵 890 X 3,(P1,P2, P3) 和 来自Test Data预测值矩阵 418 X 3, (p1, p2, p3)。

-----------------------------------------

到第二层了..................

来自5-fold的预测值矩阵 890 X 3 作为你的Train Data,训练第二层的模型 来自Test Data预测值矩阵 418 X 3 就是你的Test Data,用训练好的模型来预测他们吧。

---------------------------------------

最后 ,放出一张Python的Code,在网上为数不多的stacking内容里, 这个几行的code你也早就看过了吧,我之前一直卡在这里,现在加上一点点注解,希望对你有帮助:

谢谢您观看本文,最后........

原文发布于微信公众号 - IT派(it_pai)

原文发表时间:2017-10-02

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏罗晖的专栏

使用 Q-Learning 实现 FlappyBird AI

Q-Learning 是一种 off-policy 的强化学习算法, 本文介绍如何使用 Q-Learning 来实现FlappyBird 游戏的 AI,附件中给...

1.6K00
来自专栏人工智能头条

机器学习在web攻击检测中的应用实践

16950
来自专栏Renderbus云渲染农场

如何基于一张图片来创建3D模型?包含哪些步骤?

首先,基于一张图片(一个角度)来制作一个3D模型是不大现实的。因为三维物体是立体的、多维的,单从一个角度去观察很难判断物体其他视角的构造,制作出来的模型也就一个...

16000
来自专栏CVer

CVPR 2018 收录论文名单全公布

本文将介绍 CVPR 2018 所有录用论文的标题, 包括每篇论文属于 oral, spotlight还是 poster的情况. 大家可以根据论文的标题去 go...

20720
来自专栏深度学习与数据挖掘实战

【AI头条&优质资源】时间序列预测模型:使用深度神经网络RNN+Attention机制

放arxiv那天看了一下,整篇paper思路读下来还是非常清晰的,实验效果也很不错。

50620
来自专栏人工智能

C+实现神经网络之三—神经网络的训练和测试

前言 在之前的博客中我们已经实现了Net类的设计和前向传播和反向传播的过程。可以说神经网络的核心的部分已经完成。接下来就是应用层面了。要想利用神经网络解决实际的...

22590
来自专栏AI研习社

论文推荐 | 很可能出现在下一代PS中的深度门卷积图像补全技术

Yu Jiahui /Lin Zhe /Yang Jimei /Shen Xiaohui /Lu Xin /Huang Thomas S.

21920
来自专栏机器学习人工学weekly

机器学习人工学weekly-2018/8/26

Safety-first AI for autonomous data centre cooling and industrial control

11830
来自专栏生信小驿站

R.python常见问题③(xgboost介绍和安装)

XGBoost :eXtreme Gradient Boosting 项目地址:https://github.com/dmlc/xgboost

17010
来自专栏CDA数据分析师

如何利用 Excel 进行高级数据分析?

高级的数据分析会涉及回归分析、方差分析和T检验等方法,不要看这些内容貌似跟日常工作毫无关系,其实往高处走,MBA的课程也是包含这些内容的,所以早学晚学都得学,干...

24280

扫码关注云+社区

领取腾讯云代金券