比较R语言机器学习算法的性能

原文:Compare The Performance of Machine Learning Algorithms in R 译文:http://geek.csdn.net/news/detail/58172 作者: Jason Brownlee 译者:刘翔宇 审校:赵屹华 责编:周建丁 你如何有效地计算出不同机器学习算法的估计准确性?在这篇文章中,你将会学到8种技术,用来比较R语言机器学习算法。你可以使用这些技术来选择最精准的模型,并能够给出统计意义方面的评价,以及相比其它算法的绝对优势。

选择最好的机器学习模型

你如何根据需求选择最好的模型?

在你进行机器学习项目的时候,往往会有许多良好模型可供选择。每个模型都有不同的性能特点。

使用重采样方法,如交叉验证,就可以得到每个模型在未知数据上精准度的估计。你需要利用这些估计从你创建的一系列模型中选择一到两个最好的模型。

仔细比较机器学习模型

当你有了新数据集,使用多种不同的图形技术可视化数据是个好主意,你可以从不同角度来观察数据。

这种想法也可以用于模型选择。你应该使用不同的方法来进行估计机器学习算法的准确率,依此来选择一到两个模型。

你可以使用不同的可视化方法来显示平均准确率、方差和模型精度分布的其他性质。

比较并选择R语言的机器学习模型

在本节中,你将会学到如何客观地比较R语言机器学习模型。

通过本节中的案例研究,你将为皮马印第安人糖尿病数据集创建一些机器学习模型。然后你将会使用一系列不同的可视化技术来比较这些模型的估计准确率。

本案例研究分为三个部分:

  1. 准备数据集:加载库文件和数据集,准备训练模型。
  2. 训练模型:在数据集上训练标准机器学习模型,准备进行评估。
  3. 比较模型:使用8种不同的技术比较训练得到的模型。

准备数据集

本研究案例中使用的数据集是皮马印第安人糖尿病数据集,可在UCI机器学习库中获取。也可在R中的mlbench包中获取。

这是一个二元分类问题,预测患者在五年之内糖尿病是否会发作。入参是数值型,描述了女性患者的医疗信息。

现在来加载库文件和数据集。

# load librarieslibrary(mlbench)library(caret)# load the datasetdata(PimaIndiansDiabetes)

训练模型

在本节中,我们将会训练在下一节中将要比较的5个机器学习模型。

我们将使用重复交叉验证,folds为10,repeats为3,这是比较模型的常用标准配置。评估指标是精度和kappa,因为它们很容易解释。

根据算法的代表性和学习风格方式进行半随机选择。它们有:

  • 分类和回归树
  • 线性判别分析
  • 使用径向基函数的支持向量机
  • K-近邻
  • 随机森林

训练完模型之后,将它们添加到一个list中,然后调用resamples()函数。此函数可以检查模型是可比较的,并且模型都使用同样的训练方案(训练控制配置)。这个对象包含每个待评估算法每次折叠和重复的评估指标。

下一节中我们使用到的函数都需要包含这种数据的对象。

# prepare training schemecontrol <- trainControl(method="repeatedcv", number=10, repeats=3)# CARTset.seed(7)
fit.cart <- train(diabetes~., data=PimaIndiansDiabetes, method="rpart", trControl=control)# LDAset.seed(7)
fit.lda <- train(diabetes~., data=PimaIndiansDiabetes, method="lda", trControl=control)# SVMset.seed(7)
fit.svm <- train(diabetes~., data=PimaIndiansDiabetes, method="svmRadial", trControl=control)# kNNset.seed(7)
fit.knn <- train(diabetes~., data=PimaIndiansDiabetes, method="knn", trControl=control)# Random Forestset.seed(7)
fit.rf <- train(diabetes~., data=PimaIndiansDiabetes, method="rf", trControl=control)# collect resamplesresults <- resamples(list(CART=fit.cart, LDA=fit.lda, SVM=fit.svm, KNN=fit.knn, RF=fit.rf))

比较模型

在本节中,我们将看到8种不同的技术用来比较构建模型的估计精度。

汇总表(Table Summary)

这是你可以做的最简单的比较,只需要调用summary()函数,并传入resamples()函数值。它会创建一个表格,每行是一种算法,每列是评估指标。在这里我们已经整理好了结果。

Accuracy 
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.6234  0.7115 0.7403 0.7382  0.7760 0.8442    0
LDA  0.6711  0.7532 0.7662 0.7759  0.8052 0.8701    0
SVM  0.6711  0.7403 0.7582 0.7651  0.7890 0.8961    0
KNN  0.6184  0.6984 0.7321 0.7299  0.7532 0.8182    0
RF   0.6711  0.7273 0.7516 0.7617  0.7890 0.8571    0

Kappa 
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.1585  0.3296 0.3765 0.3934  0.4685 0.6393    0LDA  0.2484  0.4196 0.4516 0.4801  0.5512 0.7048    0SVM  0.2187  0.3889 0.4167 0.4520  0.5003 0.7638    0KNN  0.1113  0.3228 0.3867 0.3819  0.4382 0.5867    0RF   0.2624  0.3787 0.4516 0.4588  0.5193 0.6781    0

箱线图(Box and Whisker Plots)

这是查看不同模型评估精度伸展和联系的有效方式。

# box and whisker plots to compare models
scales <- list(x=list(relation="free"), y=list(relation="free"))
bwplot(results, scales=scales)

注意到箱线图以平均精度降序排序。我发现观察平均值(点)和箱线图的重叠(中间50%)很有用。

用箱线图比较R语言机器学习算法

密度图(Density Plots)

你可以将模型精度分布显示成密度图。这是种评估算法估计行为重叠的有效方式。

# density plots of accuracyscales <- list(x=list(relation="free"), y=list(relation="free"))
densityplot(results, scales=scales, pch = "|")

我喜欢观察波峰以及分布伸展或分布底部的差异。

比较R语言机器学习算法的密度图

点图(Dot Plots)

这些点非常有用,它显示了平均估计精度以及95%的置信区间(例如,95%观测点所落入的范围)。

# dot plots of accuracyscales <- list(x=list(relation="free"), y=list(relation="free"))
dotplot(results, scales=scales)

我发现比较均值和目测算法间的重叠伸展很有用。

比较R语言机器学习算法的点图

平行线图(Parallel Plots)

这是另一种查看数据的方式。它显示了每个被测算法每次交叉验证折叠试验的行为。它可以帮助你查看一个算法中子集相对其他算法的线性走势。

# parallel plots to compare modelsparallelplot(results)

要对此进行解释需要一些技巧。我认为这在以后对分析不同方法如何在组合预测中结合很有帮助(例如堆叠),尤其当你在相反方向看到有相关运动时。

比较R语言机器学习算法的平行线图

散点图矩阵(Scatterplot Matrix)

这创建了一个算法的所有折叠试验结果与其他算法相同折叠试验结果比较的散点图矩阵。每一对都进行了比较。

# pair-wise scatterplots of predictions to compare modelssplom(results)

这种做法对于考虑两个不同算法的预测是否相关时非常宝贵。如果弱相关,它们可以很好地用于组合预测。

比如,目测图表,好像LDA和SVM呈强相关性,SVM和RF也一样。SVM与CART似乎呈弱相关性。

比较R语言机器学习算法的散点图矩阵

成对XY图(Pairwise xyPlots)

你可以使用xy图,对两种机器学习算法的折叠试验精度进行成对比较。

# xyplot plots to compare modelsxyplot(results, models=c("LDA", "SVM"))

在这种情况下,我们可以看到LDA和SVM模型看似相关的精度。

比较R语言机器学习算法的成对散点图

统计意义检测(Statistical Significance Tests)

你可以计算不同机器学习算法间指标分布差异的意义。我们可以直接调用summary()函数汇总结果。

# difference in model predictionsdiffs <- diff(results)# summarize p-values for pair-wise comparisonssummary(diffs)

我们可以得到一个表格,记录了每对算法的统计意义分数。表格对角线下方显示的是零假设的p值(分布是相同的),值越小越好。我们可以看到CART和kNN之间没有区别,同样能看出LDA和SVM分布相差不大。

表格对角线上方显示的是不同分布的估计差异。观察前面的图表,如果我们认为LDA是最精准的模型,我们可以得出它比其他模型要具体精准多少的估计。

这些分数可以帮助你计算具体算法之间任何精度。

p-value adjustment: bonferroni 
Upper diagonal: estimates of the difference
Lower diagonal: p-value for H0: difference = 0Accuracy 
     CART      LDA       SVM       KNN       RF       
CART           -0.037759 -0.026908  0.008248 -0.023473LDA  0.0050068            0.010851  0.046007  0.014286SVM  0.0919580 0.3390336            0.035156  0.003435KNN  1.0000000 1.218e-05 0.0007092           -0.031721RF   0.1722106 0.1349151 1.0000000  0.0034441

一个好技巧是增加试验次数,来增加种群,获取可以得到更精准p值。你也可以画出它们之间的差异,但是我发现与上面的汇总表相比并没多大用处。

总结

在这篇文章中你学会了8种不同的技术,可以用来比较R语言机器学习算法模型的估计精度。

这8种技术是:

  • 表汇总
  • 箱线图
  • 密度图
  • 点图
  • 平行线图
  • 散点图矩阵
  • 成对XY图
  • 统计意义检测

原文发布于微信公众号 - 大数据挖掘DT数据分析(datadw)

原文发表时间:2016-03-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏新智元

CVPR清华大学研究,高效视觉目标检测框架RON

【新智元导读】当前最好的基于深度网络的目标检测框架可以分为两个主要方法流派:基于区域的方法(region-based)和不基于区域(region-free)的方...

3587
来自专栏AlgorithmDog的专栏

PCA模型加先验

大清牛人曰:ML派坐落美利坚合众山中,百年来武学奇才辈出,隐然成江湖第一大名门正派,门内有三套入门武功,曰:图模型加圈,神经网加层,优化目标加正则。...

17510
来自专栏Brian

数据挖掘

---- 概述 最近一直在学习数据挖掘和机器学习,无论是是服务端开发人员还是web开发人员,个人觉得最起码都要都一些最基本的数据挖掘和机器学习知识。废话少说,我...

2695
来自专栏AI科技评论

视频 | 从图像集合中学习特定类别的网格重建

AI 科技评论按:本文为雷锋字幕组编译的论文解读短视频,原标题 Learning Category-Specific Mesh Reconstruction ...

954
来自专栏PPV课数据科学社区

【R系列】概率基础和R语言

R语言是统计语言,概率又是统计的基础,所以可以想到,R语言必然要从底层API上提供完整、方便、易用的概率计算的函数。让R语言帮我们学好概率的基础课。 1. 随机...

2698
来自专栏机器之心

入门 | 10个例子带你了解机器学习中的线性代数

1366
来自专栏AI科技评论

开发 | AI股市预测实战:用LSTM神经网络预测沪深300未来五日收益率

LSTM Networks(长短期记忆神经网络)简介 LSTM Networks 是递归神经网络(RNNs)的一种,该算法由 Sepp Hochreiter...

2675
来自专栏灯塔大数据

塔秘 | 详解用深度学习方法处理结构化数据

导读 鉴于使用深度学习方法按照本文所介绍的步骤处理结构化数据有以下的好处:快;无需领域知识;表现优良,本文主要详细讲述如何用深度学习方法处理结构化数据。 在机器...

3498
来自专栏专知

【干货】Python机器学习项目实战2——模型选择,超参数调整和评估(附代码)

1312
来自专栏AI科技评论

开发 | Google 软件工程师解读:深度学习的activation function哪家强?

AI科技评论按:本文作者夏飞,清华大学计算机软件学士,卡内基梅隆大学人工智能硕士。现为谷歌软件工程师。本文首发于知乎,AI科技评论获授权转载。 ? TLDR (...

3784

扫描关注云+社区