前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >数据科学31 |机器学习-模型评价

数据科学31 |机器学习-模型评价

作者头像
王诗翔呀
发布2020-07-09 15:24:29
1K0
发布2020-07-09 15:24:29
举报
文章被收录于专栏:优雅R优雅R

错误类型

二元预测:

决策类型:真阳性、假阳性、真阴性、假阴性。

关键指标:

・灵敏度:真阳性/(真阳性+假阴性)

・特异性:真阴性/(假阳性+真阴性)

・阳性预测值:真阳性/(真阳性+假阳性)

・阴性预测值:真阴性/(假阴性+真阴性)

・准确性:(真阳性+真阴性)/(真阳性+假阳性+真阴性+假阴性)

连续数据:

均方误差(Mean squared error,MSE):

均方根误差(Root mean squared error,RMSE):

常见错误指标:
  1. MSE/RMSE 用于连续型数据,对离群点敏感
  2. 中值绝对偏差 取观测值和预测值之间的距离的绝对值的中位数,用于连续型数据
  3. 灵敏度 减少假阴性
  4. 特异性 减少假阳性
  5. 准确性 对假阳性、假阴性平均加权
  6. 一致性

ROC曲线

在二元预测中,通常会估计样本出现其中一种结局(如阳性)的概率,需要找到一个常数,即阈值(threshold)或门槛值(cutoff value),若概率值大于阈值,则预测为阳性。通过变动这一阈值,可以改变预测的特异性和灵敏度。

变动阈值可能带来的影响可以通过来进一步观察,ROC曲线可对一个区间内的门槛值画出特异性和敏感度之间的关系。

应用:利用ROC曲线可以找出合适的阈值,通过比较不同算法的ROC曲线可以选择最有效的算法。

ROC 曲线是以灵敏度(真阳性)为y轴、以1-特异性(假阴性)为x 轴,曲线上的点对应特定的阈值。

图1.ROC曲线

ROC曲线下的面积AUC:

ROC曲线下的面积AUC(Area under the curve)可以用来比较不同算法的优劣。

・AUC=0.5,预测算法表示为图中45º斜线,相当于随机对样本进行分类。

・AUC=1,预测算法表示为图中左上角顶点,在这个阈值下,可以得到100%的灵敏度和特异性,是个完美的分类器。

・通常AUC>0.8时可以认为是良好的预测算法。

图2.AUC评价算法优劣

交叉验证(cross validation)

使用训练集建立模型,然后将模型回代到训练集验证模型的有效性,通常会得到较好的验证效果,但由于可能存在过度拟合,而模型未必真的有效,因此需要用独立的新的数据集验证模型是否有效,来获得更好的模型参数估计、更高的测试集准确性。但是实际上不能用测试集进行验证,否则某种意义上测试集变成训练集的一部分,特别是新的样本数据难以收集时。

交叉验证法可以评价模型的泛化能力,而且可以用于某些参数的确定、变量的筛选等。

交叉验证将已有的样本训练集再分为训练集和测试集两部分,根据新的训练集建立模型,使用另一部分测试集进行验证,重复过程可以计算平均估计误差。常见的误差衡量标准是均方误差和均方根误差, 分别为交叉验证的方差和标准差。

1. 随机再抽样验证(random subsampling cross-validation):

图3.随机再抽样验证

重复随机抽取测试集样本,计算平均估计误差。

2. K重交叉验证(K-fold cross-validation):

图4.K重交叉验证

将样本分为k个子样本,轮流将k–1个子样本组合作为训练集建立模型,另外1个子样本作为测试集,计算平均估计误差。

3. 留一交叉验证(leave-one-out cross-validation, LOOCV)

图5.留一交叉验证

只使用原本样本中的一项来当做测试集,而其余的作为训练集,重复步骤直到每个样本都被当作一次测试集,相当于k为原本样本个数的K重交叉验证。

所有这些模型的建立和评估都在训练集中进行,我们将其分为子训练集和子测试集以评估模型。

注意:

  1. 对于时间序列数据,一个时间点可能取决于先前的时间点,如果仅对数据进行随机再抽样,可能会忽略许多重要信息,因此必须使用连续的时间段数据。
  2. 对于K重交叉验证,一般而言,随着k的增加,偏差会变小(模型回代效果好),但方差会增大(验证效果差)。
  3. 随机抽样必须是无放回抽样,有放回抽样(bootstrap,自举法)会低估误差。
  4. 交叉验证得到的模型必须应用到新的独立的训练数据集以得到实际的训练集误差。

数据要求

预测有关X的某些信息,请尽可能使用与X密切相关的数据,数据相关性越低,预测越难。了解数据实际上如何与实际尝试预测的事物相关联非常重要,这是机器学习中最常犯的错误,机器学习通常被认为是一种黑箱预测程序,在一端输入数据,在另一端得到预测结果。

caret 包

内置函数:

・预处理:preProcess()函数

・数据分割:createDataPartition()函数、createTimeSlices()函数、createResample()函数

・训练和测试:train()函数、predict()函数

・模型比较:confusionMatrix()函数

R中内置的机器学习算法:

・线性判别分析(Linear discriminant analysis)

・回归分析(Regression)

・朴素(Naive Bayes)

・支持向量机(Support vector machines)

・分类回归树(Classification and regression trees)

・随机森林(Random forests)

・提升(boosting),等等

R中有来自不同的开发者开发的不同的机器学习算法,每种算法都略有不同。

表1 不同R包中的机器学习算法的预测函数

算法类型

R包

predict()函数语法

lda

MASS

predict(obj)(不需设置选项)

glm

stats

predict(obj, type = "response")

gbm

gbm

predict(obj, type = "response", n.trees)

mda

mda

predict(obj, type = "posterior")

rpart

rpart

predict(obj, type = "prob")

Weka

RWeka

predict(obj, type = "probability")

LogitBoost

caTools

predict(obj, type = "raw", nIter)

使用以上算法应用predict()函数预测时必须传递不同的type选项参数。caret包提供了一个统一的框架,允许只使用一种函数且不需指定选项来进行预测。

例:spam数据集

将数据分为训练集和测试集:

代码语言:javascript
复制
library(caret)
library(kernlab)
data(spam)
inTrain <- createDataPartition(y=spam$type,
                               p=0.75, list = FALSE) #75%的数据作为训练集
training <- spam[inTrain, ]
testing <- spam[-inTrain, ]
dim(training)
[1] 3451   58

拟合模型:

代码语言:javascript
复制
set.seed(32343)
modelFit <- train(type ~., data = training, method="glm")
modelFit
Generalized Linear Model

3451 samples
  57 predictor
   2 classes: 'nonspam', 'spam'

No pre-processing
Resampling: Bootstrapped (25 reps) 
Summary of sample sizes: 3451, 3451, 3451, 3451, 3451, 3451, ... 
Resampling results:

  Accuracy   Kappa    
  0.9207464  0.8333755

采用重抽样的方式测试模型,选择最佳模型。进行25次有放回重抽样,并校正了自举抽样可能带来的潜在偏差。

查看拟合的最佳模型:

代码语言:javascript
复制
modelFit <- train(type ~., data = training, method="glm")
modelFit$finalModel
Call:  NULL

Coefficients:
      (Intercept)               make            address                all  
       -1.541e+00         -3.248e-01         -1.448e-01         -4.829e-02  
            num3d                our               over             remove  
        1.968e+00          4.750e-01          5.455e-01          2.218e+00  
         internet              order               mail            receive  
        6.016e-01          5.926e-01          1.638e-01         -8.654e-01  
             will             people             report          addresses  
       -1.416e-01          8.584e-02          1.268e-01          2.017e+00  
             free           business              email                you  
        9.468e-01          1.084e+00          1.746e-01          9.129e-02  
           credit               your               font             num000  
        2.903e+00          3.356e-01          1.740e-01          1.933e+00  
            money                 hp                hpl             george  
        2.605e-01         -2.556e+00         -7.866e-01         -9.418e+00  
           num650                lab               labs             telnet  
        6.231e-01         -2.372e+00         -3.618e-01         -7.320e-02  
           num857               data             num415              num85  
        1.305e-01         -4.910e-01         -1.918e+01         -5.939e+00  
       technology            num1999              parts                 pm  
        1.180e+00          1.585e-01         -5.811e-01         -9.438e-01  
           direct                 cs            meeting           original  
       -2.426e-01         -4.281e+01         -3.005e+00         -1.240e+00  
          project                 re                edu              table  
       -1.790e+00         -8.871e-01         -1.212e+00         -1.981e+00  
       conference      charSemicolon   charRoundbracket  charSquarebracket  
       -4.245e+00         -1.269e+00         -1.424e-01         -4.982e-01  
  charExclamation         charDollar           charHash         capitalAve  
        3.084e-01          5.776e+00          3.337e+00         -1.680e-02  
      capitalLong       capitalTotal  
        1.199e-02          6.359e-04

Degrees of Freedom: 3450 Total (i.e. Null);  3393 Residual
Null Deviance:	    4628 
Residual Deviance: 1367 	AIC: 1483

进行预测:

代码语言:javascript
复制
predictions <- predict(modelFit, newdata = testing)
predictions
 [1] spam    spam    spam    nonspam spam    spam    spam    spam    spam    spam   
[11] spam    spam    spam    spam    spam    spam    spam    spam    spam    nonspam
[21] spam    spam    spam    nonspam spam    spam    spam    nonspam spam    nonspam
[31] spam    spam    spam    spam    spam    spam    spam    spam    spam    spam   
[41] spam    spam    spam    spam    spam    nonspam spam    spam    spam    spam   
[51] nonspam spam    spam    nonspam spam    spam    spam    spam    spam    nonspam
  ……
Levels: nonspam spam

比较预测结果与实际结果,得到汇总统计信息:

代码语言:javascript
复制
confusionMatrix(predictions, testing$type)
Confusion Matrix and Statistics

          Reference
Prediction nonspam spam  #预测效果
   nonspam     660   59
   spam         37  394
                                         
               Accuracy : 0.9165         
                 95% CI : (0.899, 0.9319) #准确性的置信区间
    No Information Rate : 0.6061         
    P-Value [Acc > NIR] : < 2e-16        
                                         
                  Kappa : 0.8237         
                                         
 Mcnemar's Test P-Value : 0.03209        
                                         
            Sensitivity : 0.9469  #灵敏度       
            Specificity : 0.8698  #特异性     
         Pos Pred Value : 0.9179         
         Neg Pred Value : 0.9142         
             Prevalence : 0.6061         
         Detection Rate : 0.5739         
   Detection Prevalence : 0.6252         
      Balanced Accuracy : 0.9083         
                                         
       'Positive' Class : nonspam
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-07-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 优雅R 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 错误类型
    • 常见错误指标:
    • ROC曲线
    • 交叉验证(cross validation)
    • 数据要求
    • caret 包
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档