比较R语言机器学习算法的性能

原文:Compare The Performance of Machine Learning Algorithms in R 译文:http://geek.csdn.net/news/detail/58172 作者: Jason Brownlee 译者:刘翔宇 审校:赵屹华 责编:周建丁 你如何有效地计算出不同机器学习算法的估计准确性?在这篇文章中,你将会学到8种技术,用来比较R语言机器学习算法。你可以使用这些技术来选择最精准的模型,并能够给出统计意义方面的评价,以及相比其它算法的绝对优势。

选择最好的机器学习模型

你如何根据需求选择最好的模型?

在你进行机器学习项目的时候,往往会有许多良好模型可供选择。每个模型都有不同的性能特点。

使用重采样方法,如交叉验证,就可以得到每个模型在未知数据上精准度的估计。你需要利用这些估计从你创建的一系列模型中选择一到两个最好的模型。

仔细比较机器学习模型

当你有了新数据集,使用多种不同的图形技术可视化数据是个好主意,你可以从不同角度来观察数据。

这种想法也可以用于模型选择。你应该使用不同的方法来进行估计机器学习算法的准确率,依此来选择一到两个模型。

你可以使用不同的可视化方法来显示平均准确率、方差和模型精度分布的其他性质。

比较并选择R语言的机器学习模型

在本节中,你将会学到如何客观地比较R语言机器学习模型。

通过本节中的案例研究,你将为皮马印第安人糖尿病数据集创建一些机器学习模型。然后你将会使用一系列不同的可视化技术来比较这些模型的估计准确率。

本案例研究分为三个部分:

  1. 准备数据集:加载库文件和数据集,准备训练模型。
  2. 训练模型:在数据集上训练标准机器学习模型,准备进行评估。
  3. 比较模型:使用8种不同的技术比较训练得到的模型。

准备数据集

本研究案例中使用的数据集是皮马印第安人糖尿病数据集,可在UCI机器学习库中获取。也可在R中的mlbench包中获取。

这是一个二元分类问题,预测患者在五年之内糖尿病是否会发作。入参是数值型,描述了女性患者的医疗信息。

现在来加载库文件和数据集。

# load librarieslibrary(mlbench)library(caret)# load the datasetdata(PimaIndiansDiabetes)

训练模型

在本节中,我们将会训练在下一节中将要比较的5个机器学习模型。

我们将使用重复交叉验证,folds为10,repeats为3,这是比较模型的常用标准配置。评估指标是精度和kappa,因为它们很容易解释。

根据算法的代表性和学习风格方式进行半随机选择。它们有:

  • 分类和回归树
  • 线性判别分析
  • 使用径向基函数的支持向量机
  • K-近邻
  • 随机森林

训练完模型之后,将它们添加到一个list中,然后调用resamples()函数。此函数可以检查模型是可比较的,并且模型都使用同样的训练方案(训练控制配置)。这个对象包含每个待评估算法每次折叠和重复的评估指标。

下一节中我们使用到的函数都需要包含这种数据的对象。

# prepare training schemecontrol <- trainControl(method="repeatedcv", number=10, repeats=3)# CARTset.seed(7)
fit.cart <- train(diabetes~., data=PimaIndiansDiabetes, method="rpart", trControl=control)# LDAset.seed(7)
fit.lda <- train(diabetes~., data=PimaIndiansDiabetes, method="lda", trControl=control)# SVMset.seed(7)
fit.svm <- train(diabetes~., data=PimaIndiansDiabetes, method="svmRadial", trControl=control)# kNNset.seed(7)
fit.knn <- train(diabetes~., data=PimaIndiansDiabetes, method="knn", trControl=control)# Random Forestset.seed(7)
fit.rf <- train(diabetes~., data=PimaIndiansDiabetes, method="rf", trControl=control)# collect resamplesresults <- resamples(list(CART=fit.cart, LDA=fit.lda, SVM=fit.svm, KNN=fit.knn, RF=fit.rf))

比较模型

在本节中,我们将看到8种不同的技术用来比较构建模型的估计精度。

汇总表(Table Summary)

这是你可以做的最简单的比较,只需要调用summary()函数,并传入resamples()函数值。它会创建一个表格,每行是一种算法,每列是评估指标。在这里我们已经整理好了结果。

Accuracy 
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.6234  0.7115 0.7403 0.7382  0.7760 0.8442    0
LDA  0.6711  0.7532 0.7662 0.7759  0.8052 0.8701    0
SVM  0.6711  0.7403 0.7582 0.7651  0.7890 0.8961    0
KNN  0.6184  0.6984 0.7321 0.7299  0.7532 0.8182    0
RF   0.6711  0.7273 0.7516 0.7617  0.7890 0.8571    0

Kappa 
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.1585  0.3296 0.3765 0.3934  0.4685 0.6393    0LDA  0.2484  0.4196 0.4516 0.4801  0.5512 0.7048    0SVM  0.2187  0.3889 0.4167 0.4520  0.5003 0.7638    0KNN  0.1113  0.3228 0.3867 0.3819  0.4382 0.5867    0RF   0.2624  0.3787 0.4516 0.4588  0.5193 0.6781    0

箱线图(Box and Whisker Plots)

这是查看不同模型评估精度伸展和联系的有效方式。

# box and whisker plots to compare models
scales <- list(x=list(relation="free"), y=list(relation="free"))
bwplot(results, scales=scales)

注意到箱线图以平均精度降序排序。我发现观察平均值(点)和箱线图的重叠(中间50%)很有用。

用箱线图比较R语言机器学习算法

密度图(Density Plots)

你可以将模型精度分布显示成密度图。这是种评估算法估计行为重叠的有效方式。

# density plots of accuracyscales <- list(x=list(relation="free"), y=list(relation="free"))
densityplot(results, scales=scales, pch = "|")

我喜欢观察波峰以及分布伸展或分布底部的差异。

比较R语言机器学习算法的密度图

点图(Dot Plots)

这些点非常有用,它显示了平均估计精度以及95%的置信区间(例如,95%观测点所落入的范围)。

# dot plots of accuracyscales <- list(x=list(relation="free"), y=list(relation="free"))
dotplot(results, scales=scales)

我发现比较均值和目测算法间的重叠伸展很有用。

比较R语言机器学习算法的点图

平行线图(Parallel Plots)

这是另一种查看数据的方式。它显示了每个被测算法每次交叉验证折叠试验的行为。它可以帮助你查看一个算法中子集相对其他算法的线性走势。

# parallel plots to compare modelsparallelplot(results)

要对此进行解释需要一些技巧。我认为这在以后对分析不同方法如何在组合预测中结合很有帮助(例如堆叠),尤其当你在相反方向看到有相关运动时。

比较R语言机器学习算法的平行线图

散点图矩阵(Scatterplot Matrix)

这创建了一个算法的所有折叠试验结果与其他算法相同折叠试验结果比较的散点图矩阵。每一对都进行了比较。

# pair-wise scatterplots of predictions to compare modelssplom(results)

这种做法对于考虑两个不同算法的预测是否相关时非常宝贵。如果弱相关,它们可以很好地用于组合预测。

比如,目测图表,好像LDA和SVM呈强相关性,SVM和RF也一样。SVM与CART似乎呈弱相关性。

比较R语言机器学习算法的散点图矩阵

成对XY图(Pairwise xyPlots)

你可以使用xy图,对两种机器学习算法的折叠试验精度进行成对比较。

# xyplot plots to compare modelsxyplot(results, models=c("LDA", "SVM"))

在这种情况下,我们可以看到LDA和SVM模型看似相关的精度。

比较R语言机器学习算法的成对散点图

统计意义检测(Statistical Significance Tests)

你可以计算不同机器学习算法间指标分布差异的意义。我们可以直接调用summary()函数汇总结果。

# difference in model predictionsdiffs <- diff(results)# summarize p-values for pair-wise comparisonssummary(diffs)

我们可以得到一个表格,记录了每对算法的统计意义分数。表格对角线下方显示的是零假设的p值(分布是相同的),值越小越好。我们可以看到CART和kNN之间没有区别,同样能看出LDA和SVM分布相差不大。

表格对角线上方显示的是不同分布的估计差异。观察前面的图表,如果我们认为LDA是最精准的模型,我们可以得出它比其他模型要具体精准多少的估计。

这些分数可以帮助你计算具体算法之间任何精度。

p-value adjustment: bonferroni 
Upper diagonal: estimates of the difference
Lower diagonal: p-value for H0: difference = 0Accuracy 
     CART      LDA       SVM       KNN       RF       
CART           -0.037759 -0.026908  0.008248 -0.023473LDA  0.0050068            0.010851  0.046007  0.014286SVM  0.0919580 0.3390336            0.035156  0.003435KNN  1.0000000 1.218e-05 0.0007092           -0.031721RF   0.1722106 0.1349151 1.0000000  0.0034441

一个好技巧是增加试验次数,来增加种群,获取可以得到更精准p值。你也可以画出它们之间的差异,但是我发现与上面的汇总表相比并没多大用处。

总结

在这篇文章中你学会了8种不同的技术,可以用来比较R语言机器学习算法模型的估计精度。

这8种技术是:

  • 表汇总
  • 箱线图
  • 密度图
  • 点图
  • 平行线图
  • 散点图矩阵
  • 成对XY图
  • 统计意义检测

原文发布于微信公众号 - 大数据挖掘DT数据分析(datadw)

原文发表时间:2016-03-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏后端之路

LinkedList源码解读

List中除了ArrayList我们最常用的就是LinkedList了。 LInkedList与ArrayList的最大区别在于元素的插入效率和随机访问效率 ...

19710
来自专栏alexqdjay

HashMap 多线程下死循环分析及JDK8修复

1K4
来自专栏MelonTeam专栏

ArrayList源码完全分析

导语: 这里分析的ArrayList是使用的JDK1.8里面的类,AndroidSDK里面的ArrayList基本和这个一样。 分析的方式是逐个API进行解析 ...

4519
来自专栏Java Edge

AbstractList源码解析1 实现的方法2 两种内部迭代器3 两种内部类3 SubList 源码分析4 RandomAccessSubList 源码:AbstractList 作为 Lis

它实现了 List 的一些位置相关操作(比如 get,set,add,remove),是第一个实现随机访问方法的集合类,但不支持添加和替换

462
来自专栏xingoo, 一个梦想做发明家的程序员

Spark踩坑——java.lang.AbstractMethodError

百度了一下说是版本不一致导致的。于是重新检查各个jar包,发现spark-sql-kafka的版本是2.2,而spark的版本是2.3,修改spark-sql-...

1210
来自专栏拭心的安卓进阶之路

Java 集合深入理解(12):古老的 Vector

今天刮台风,躲屋里看看 Vector ! 都说 Vector 是线程安全的 ArrayList,今天来根据源码看看是不是这么相...

2447
来自专栏赵俊的Java专栏

从源码上分析 ArrayList

1181
来自专栏xingoo, 一个梦想做发明家的程序员

20120918-向量实现《数据结构与算法分析》

#include <iostream> #include <list> #include <string> #include <vector> #include...

1736
来自专栏学海无涯

Android开发之奇怪的Fragment

说起Android中的Fragment,在使用的时候稍加注意,就会发现存在以下两种: v4包中的兼容Fragment,android.support.v4.ap...

3165
来自专栏Phoenix的Android之旅

Java 集合 Vector

List有三种实现,ArrayList, LinkedList, Vector, 它们的区别在于, ArrayList是非线程安全的, Vector则是线程安全...

672

扫码关注云+社区