前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >mlr3基础(二)

mlr3基础(二)

作者头像
王诗翔呀
发布2021-09-24 16:02:58
2.7K0
发布2021-09-24 16:02:58
举报
文章被收录于专栏:优雅R优雅R

重采样

重采样策略通常用来评估学习算法的性能。mlr3包含了以下预定义的重采样策略:

  • 交叉验证 - cv[29]
  • 留一交叉验证 - loo[30]
  • 重复交叉验证 - repeated_cv[31]
  • bootstrapping - bootstrap[32]
  • 二次抽样 - subsampling[33]
  • holdout - holdout[34]
  • 样本重采样 - insample[35]
  • 自定义重采样 - custom[36]

以下部分提供了如何设置和选择重采样策略以及如何随后实例化重采样过程的指导。

以下是重采样过程的图示:

机器学习流程 source: https://mlr3book.mlr-org.com/images/ml_abstraction.svg

Figure 3: 机器学习流程 source: https://mlr3book.mlr-org.com/images/ml_abstraction.svg

设置

在本例中,我们再次使用了penguins任务和rpart包中的一个简单分类树。

代码语言:javascript
复制
library("mlr3verse")

task = tsk("penguins")
learner = lrn("classif.rpart")

在对数据集执行重采样时,我们首先需要定义应该使用哪种方法。mlr3重采样策略及其参数可以通过查看数据进行查询。mlr_resamplings字典的表输出:

代码语言:javascript
复制
as.data.table(mlr_resamplings)
out            key        params iters
out 1:   bootstrap ratio,repeats    30
out 2:      custom                  NA
out 3:   custom_cv                  NA
out 4:          cv         folds    10
out 5:     holdout         ratio     1
out 6:    insample                   1
out 7:         loo                  NA
out 8: repeated_cv folds,repeats   100
out 9: subsampling ratio,repeats    30

用于特殊用例的额外重采样方法将通过扩展包提供,例如用于空间数据的mlr3spatiotemporal[37]

在前面进行的模型拟合相当于“holdout 重采样”,所以让我们首先考虑这个。同样,我们可以通过$get()或方便的function rsmp()从字典mlr_resamplings中检索元素:

代码语言:javascript
复制
resampling = rsmp("holdout")
print(resampling)
out <ResamplingHoldout> with 1 iterations
out * Instantiated: FALSE
out * Parameters: ratio=0.6667

注意$is_instantiated字段被设置为FALSE。这意味着我们还没有在数据集上实际应用该策略。在下一节实例化中对数据集应用该策略。

默认情况下,我们得到.66/.33数据的分割。有两种方法可以改变比例:

  1. 使用命名列表覆盖param_setvalues中的槽:
代码语言:javascript
复制
resampling$param_set$values = list(ratio = 0.8)
  1. 使用时直接指定重采样参数:
代码语言:javascript
复制
rsmp("holdout", ratio = 0.8)
out <ResamplingHoldout> with 1 iterations
out * Instantiated: FALSE
out * Parameters: ratio=0.8

实例化

到目前为止,我们只是设置和选择了重采样策略。

为了实际执行分割并获得训练和测试分割的指标,重采样需要一个Task。通过调用instantiate()方法,我们将数据的索引分解为用于训练集和测试集的索引。这些结果索引存储在Resampling对象中。为了更好地说明以下操作,我们切换到一个3折交叉验证:

代码语言:javascript
复制
resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)
resampling$iters
out [1] 3
str(resampling$train_set(1))
out  int [1:229] 3 4 5 8 14 17 20 28 30 35 ...
str(resampling$test_set(1))
out  int [1:115] 1 6 9 10 11 19 21 24 25 31 ...

请注意,如果你想以公平的方式比较多个学习器,则必须对每个学习器使用相同的实例化重采样。下一节基准测试将讨论一种大大简化多个学习器之间比较的方法。

执行

对于一个任务,一个学习者和一个重采样对象,我们可以调用resample(),它根据给定的重采样策略重复地将学习器应用于手头的任务。这又创建了一个ResampleResult对象。我们告诉resample()通过将store_models选项设置为true来保留拟合的模型,然后开始计算:

代码语言:javascript
复制
task = tsk("penguins")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3)

rr = resample(task, learner, resampling, store_models = TRUE)
out INFO  [21:44:36.748] [mlr3]  Applying learner 'classif.rpart' on task 'penguins' (iter 3/3) 
out INFO  [21:44:36.795] [mlr3]  Applying learner 'classif.rpart' on task 'penguins' (iter 1/3) 
out INFO  [21:44:36.829] [mlr3]  Applying learner 'classif.rpart' on task 'penguins' (iter 2/3)
print(rr)
out <ResampleResult> of 3 iterations
out * Task: penguins
out * Learner: classif.rpart
out * Warnings: 0 in 0 iterations
out * Errors: 0 in 0 iterations

rr存储返回的ResampleResult提供了各种getter方法来访问存储的信息:

  • 计算所有重采样迭代的平均性能:
代码语言:javascript
复制
rr$aggregate(msr("classif.ce"))
out classif.ce 
out 0.06969235
  • 提取单个重采样迭代的性能:
代码语言:javascript
复制
rr$score(msr("classif.ce"))
out                 task  task_id                   learner    learner_id
out 1: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
out 2: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
out 3: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
out            resampling resampling_id iteration              prediction
out 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[19]>
out 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[19]>
out 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[19]>
out    classif.ce
out 1: 0.05217391
out 2: 0.11304348
out 3: 0.04385965
  • 检查警告或错误:
代码语言:javascript
复制
rr$warnings
out Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
out Empty data.table (0 rows and 2 cols): iteration,msg

提取并检查重采样分割:

代码语言:javascript
复制
rr$resampling
out <ResamplingCV> with 3 iterations
out * Instantiated: TRUE
out * Parameters: folds=3
rr$resampling$iters
out [1] 3
str(rr$resampling$test_set(1))
out  int [1:115] 1 3 15 19 20 24 29 31 34 37 ...
str(rr$resampling$train_set(1))
out  int [1:229] 5 7 8 9 10 11 13 14 17 27 ...
  • 检索特定迭代的学习器并检查它:
代码语言:javascript
复制
lrn = rr$learners[[1]]
lrn$model
out n= 229 
out 
out node), split, n, loss, yval, (yprob)
out       * denotes terminal node
out 
out 1) root 229 124 Adelie (0.458515284 0.196506550 0.344978166)  
out   2) flipper_length< 207.5 148  44 Adelie (0.702702703 0.290540541 0.006756757)  
out     4) bill_length< 43.35 105   4 Adelie (0.961904762 0.038095238 0.000000000) *
out     5) bill_length>=43.35 43   4 Chinstrap (0.069767442 0.906976744 0.023255814) *
out   3) flipper_length>=207.5 81   3 Gentoo (0.012345679 0.024691358 0.962962963) *
  • 提取预测:
代码语言:javascript
复制
rr$prediction() # all predictions merged into a single Prediction object
out <PredictionClassif> for 344 observations:
out     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
out           1    Adelie    Adelie  0.96190476     0.03809524  0.00000000
out           3    Adelie    Adelie  0.96190476     0.03809524  0.00000000
out          15    Adelie    Adelie  0.96190476     0.03809524  0.00000000
out ---                                                                   
out         337 Chinstrap    Gentoo  0.02127660     0.03191489  0.94680851
out         338 Chinstrap Chinstrap  0.06666667     0.91111111  0.02222222
out         340 Chinstrap    Gentoo  0.02127660     0.03191489  0.94680851
rr$predictions()[[1]] # prediction of first resampling iteration
out <PredictionClassif> for 115 observations:
out     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
out           1    Adelie    Adelie  0.96190476     0.03809524  0.00000000
out           3    Adelie    Adelie  0.96190476     0.03809524  0.00000000
out          15    Adelie    Adelie  0.96190476     0.03809524  0.00000000
out ---                                                                   
out         339 Chinstrap Chinstrap  0.06976744     0.90697674  0.02325581
out         343 Chinstrap    Gentoo  0.01234568     0.02469136  0.96296296
out         344 Chinstrap Chinstrap  0.06976744     0.90697674  0.02325581
  • 过滤器只保留指定的迭代:
代码语言:javascript
复制
rr$filter(c(1, 3))
print(rr)
out <ResampleResult> of 2 iterations
out * Task: penguins
out * Learner: classif.rpart
out * Warnings: 0 in 0 iterations
out * Errors: 0 in 0 iterations

自定义重采样

有时需要使用自定义分割进行重采样,例如重现研究报告中的结果。可以使用“custom”模板创建手动重采样实例。

代码语言:javascript
复制
resampling = rsmp("custom")
resampling$instantiate(task,
  train = list(c(1:10, 51:60, 101:110)),
  test = list(c(11:20, 61:70, 111:120))
)
resampling$iters
out [1] 1
resampling$train_set(1)
out  [1]   1   2   3   4   5   6   7   8   9  10  51  52  53  54  55  56  57  58  59
out [20]  60 101 102 103 104 105 106 107 108 109 110
resampling$test_set(1)
out  [1]  11  12  13  14  15  16  17  18  19  20  61  62  63  64  65  66  67  68  69
out [20]  70 111 112 113 114 115 116 117 118 119 120

使用预定义组进行重采样

与定义列角色“group”(表示特定的观察结果应该总是在测试集或训练集中一起出现)相反,我们还可以提供一个因子变量来预定义所有分区(还在进行中)。

这意味着该变量的每个因素级别单独组成测试集。因此,此方法不允许设置“fold”参数,因为折叠的数量是由因子级别的数量决定的。

这种预定义的方法在mlr2中称为“阻塞”。它不应该与mlr3spatiotempcv中的术语“块”混淆,后者指的是利用平方/矩形分割的一类重采样方法。

可视化重采样结果

mlr3viz提供了一个autoplot()方法。为了展示一些图,我们创建了一个具有两个特征的二元分类任务,使用10倍交叉验证执行重采样并可视化结果:

代码语言:javascript
复制
task = tsk("pima")
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")
rr = resample(task, learner, rsmp("cv"), store_models = TRUE)
out INFO  [21:44:37.234] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 4/10) 
out INFO  [21:44:37.247] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 9/10) 
out INFO  [21:44:37.266] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 7/10) 
out INFO  [21:44:37.289] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 3/10) 
out INFO  [21:44:37.315] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 10/10) 
out INFO  [21:44:37.337] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 8/10) 
out INFO  [21:44:37.365] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 2/10) 
out INFO  [21:44:37.379] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 6/10) 
out INFO  [21:44:37.393] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 5/10) 
out INFO  [21:44:37.405] [mlr3]  Applying learner 'classif.rpart' on task 'pima' (iter 1/10)
# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))

img

代码语言:javascript
复制
# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")

img

代码语言:javascript
复制
# learner predictions for first fold
rr$filter(1)
autoplot(rr, type = "prediction")

img

autoplot.ResampleResult()[38]的手册页列出了所有可用的绘图类型。

可视化重采样分区

Mlr3spatiotempcv提供autoplot()方法来可视化时空数据集的重采样分区。更多信息,请参阅函数参考[39]和vignette“时空可视化”[40]

img

基准测试

比较不同学习器在多个任务和/或不同重采样方案上的表现是一个常见的任务。在机器学习领域,这种操作通常被称为“基准测试”。mlr3包提供了方便的benchmark()函数。

设计创建

在mlr3中,我们要求你提供基准实验的“设计”。这样的设计本质上是你想要执行的设置表。它由任务、学习者和重采样三方面的唯一组合组成。

我们使用benchmark_grid()函数来创建一个详尽的设计并正确地实例化重采样,这样对于每个任务,所有的学习器都在相同的训练/测试分割上执行。我们设置学习器预测概率,并告诉他们预测训练集的观察值(通过设置predict_sets为c(“train”,“test”))。此外,我们使用tsks()lrns()rsmps()来检索Task、Learner和Resampling的列表,其方式与tsk()lrn()rsmp()相同。

代码语言:javascript
复制
library("mlr3verse")

design = benchmark_grid(
  tasks = tsks(c("spam", "german_credit", "sonar")),
  learners = lrns(c("classif.ranger", "classif.rpart", "classif.featureless"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 3)
)
print(design)
out                 task                         learner         resampling
out 1: <TaskClassif[47]>      <LearnerClassifRanger[36]> <ResamplingCV[19]>
out 2: <TaskClassif[47]>       <LearnerClassifRpart[36]> <ResamplingCV[19]>
out 3: <TaskClassif[47]> <LearnerClassifFeatureless[36]> <ResamplingCV[19]>
out 4: <TaskClassif[47]>      <LearnerClassifRanger[36]> <ResamplingCV[19]>
out 5: <TaskClassif[47]>       <LearnerClassifRpart[36]> <ResamplingCV[19]>
out 6: <TaskClassif[47]> <LearnerClassifFeatureless[36]> <ResamplingCV[19]>
out 7: <TaskClassif[47]>      <LearnerClassifRanger[36]> <ResamplingCV[19]>
out 8: <TaskClassif[47]>       <LearnerClassifRpart[36]> <ResamplingCV[19]>
out 9: <TaskClassif[47]> <LearnerClassifFeatureless[36]> <ResamplingCV[19]>

创建的设计可以传递给benchmark()以开始计算。也可以手动创建自定义设计。然而,如果你使用data.table()创建一个自定义任务,如果你在创建设计之前没有手动实例化重采样,那么设计的每一行的train/test分割将是不同的。查看benchmark_grid()帮助页面[41]以获得一个示例。

结果的执行和汇总

基准设计完成后,可以直接调用benchmark()

代码语言:javascript
复制
# execute the benchmark
bmr = benchmark(design)
out INFO  [21:44:39.493] [mlr3]  Running benchmark with 27 resampling iterations 
out INFO  [21:44:39.501] [mlr3]  Applying learner 'classif.rpart' on task 'sonar' (iter 1/3) 
out INFO  [21:44:39.528] [mlr3]  Applying learner 'classif.featureless' on task 'sonar' (iter 1/3) 
out INFO  [21:44:39.537] [mlr3]  Applying learner 'classif.rpart' on task 'spam' (iter 1/3) 
out INFO  [21:44:39.864] [mlr3]  Applying learner 'classif.featureless' on task 'sonar' (iter 3/3) 
out INFO  [21:44:39.876] [mlr3]  Applying learner 'classif.ranger' on task 'sonar' (iter 3/3) 
out INFO  [21:44:40.066] [mlr3]  Applying learner 'classif.rpart' on task 'sonar' (iter 3/3) 
out INFO  [21:44:40.112] [mlr3]  Applying learner 'classif.ranger' on task 'sonar' (iter 2/3) 
out INFO  [21:44:40.283] [mlr3]  Applying learner 'classif.featureless' on task 'sonar' (iter 2/3) 
out INFO  [21:44:40.299] [mlr3]  Applying learner 'classif.rpart' on task 'sonar' (iter 2/3) 
out INFO  [21:44:40.338] [mlr3]  Applying learner 'classif.rpart' on task 'german_credit' (iter 1/3) 
out INFO  [21:44:40.410] [mlr3]  Applying learner 'classif.featureless' on task 'german_credit' (iter 3/3) 
out INFO  [21:44:40.423] [mlr3]  Applying learner 'classif.ranger' on task 'german_credit' (iter 2/3) 
out INFO  [21:44:40.898] [mlr3]  Applying learner 'classif.featureless' on task 'german_credit' (iter 1/3) 
out INFO  [21:44:40.908] [mlr3]  Applying learner 'classif.ranger' on task 'spam' (iter 3/3) 
out INFO  [21:44:43.498] [mlr3]  Applying learner 'classif.ranger' on task 'sonar' (iter 1/3) 
out INFO  [21:44:43.624] [mlr3]  Applying learner 'classif.ranger' on task 'german_credit' (iter 3/3) 
out INFO  [21:44:44.128] [mlr3]  Applying learner 'classif.ranger' on task 'spam' (iter 2/3) 
out INFO  [21:44:47.058] [mlr3]  Applying learner 'classif.ranger' on task 'german_credit' (iter 1/3) 
out INFO  [21:44:47.537] [mlr3]  Applying learner 'classif.rpart' on task 'german_credit' (iter 2/3) 
out INFO  [21:44:47.568] [mlr3]  Applying learner 'classif.rpart' on task 'spam' (iter 3/3) 
out INFO  [21:44:47.648] [mlr3]  Applying learner 'classif.featureless' on task 'spam' (iter 1/3) 
out INFO  [21:44:47.669] [mlr3]  Applying learner 'classif.rpart' on task 'german_credit' (iter 3/3) 
out INFO  [21:44:47.717] [mlr3]  Applying learner 'classif.featureless' on task 'spam' (iter 2/3) 
out INFO  [21:44:47.741] [mlr3]  Applying learner 'classif.featureless' on task 'spam' (iter 3/3) 
out INFO  [21:44:47.762] [mlr3]  Applying learner 'classif.ranger' on task 'spam' (iter 1/3) 
out INFO  [21:44:50.098] [mlr3]  Applying learner 'classif.rpart' on task 'spam' (iter 2/3) 
out INFO  [21:44:50.177] [mlr3]  Applying learner 'classif.featureless' on task 'german_credit' (iter 2/3) 
out INFO  [21:44:50.189] [mlr3]  Finished benchmark

注意,我们没有手动实例化重采样实例。benchmark_grid()会为我们处理它:在构建穷举网格期间,每个重采样策略都会为每个任务实例化一次。

基准测试完成后,我们可以使用$aggregate()聚合性能结果。我们创建了两个度量来计算训练集和预测集的AUC:

代码语言:javascript
复制
measures = list(
  msr("classif.auc", predict_sets = "train", id = "auc_train"),
  msr("classif.auc", id = "auc_test")
)

tab = bmr$aggregate(measures)
print(tab)
out    nr      resample_result       task_id          learner_id resampling_id
out 1:  1 <ResampleResult[20]>          spam      classif.ranger            cv
out 2:  2 <ResampleResult[20]>          spam       classif.rpart            cv
out 3:  3 <ResampleResult[20]>          spam classif.featureless            cv
out 4:  4 <ResampleResult[20]> german_credit      classif.ranger            cv
out 5:  5 <ResampleResult[20]> german_credit       classif.rpart            cv
out 6:  6 <ResampleResult[20]> german_credit classif.featureless            cv
out 7:  7 <ResampleResult[20]>         sonar      classif.ranger            cv
out 8:  8 <ResampleResult[20]>         sonar       classif.rpart            cv
out 9:  9 <ResampleResult[20]>         sonar classif.featureless            cv
out    iters auc_train  auc_test
out 1:     3 0.9994729 0.9843972
out 2:     3 0.9109017 0.9002759
out 3:     3 0.5000000 0.5000000
out 4:     3 0.9983689 0.8050245
out 5:     3 0.8049514 0.7278807
out 6:     3 0.5000000 0.5000000
out 7:     3 1.0000000 0.9191759
out 8:     3 0.9319927 0.7546365
out 9:     3 0.5000000 0.5000000

我们可以进一步汇总这些结果。例如,我们可能有兴趣知道哪个学习器在同时完成所有任务时表现最好。简单地将性能与平均值相加通常在统计上并不合理。相反,我们按任务分组计算每个学习器的等级统计量。然后将计算得到的按学习器分组的秩用data.table进行汇总。由于需要最大化AUC,我们将这些值乘以−1,使最好的学习者的排名为1。

代码语言:javascript
复制
library("data.table")
# group by levels of task_id, return columns:
# - learner_id
# - rank of col '-auc_train' (per level of learner_id)
# - rank of col '-auc_test' (per level of learner_id)
ranks = tab[, .(learner_id, rank_train = rank(-auc_train), rank_test = rank(-auc_test)), by = task_id]
print(ranks)
out          task_id          learner_id rank_train rank_test
out 1:          spam      classif.ranger          1         1
out 2:          spam       classif.rpart          2         2
out 3:          spam classif.featureless          3         3
out 4: german_credit      classif.ranger          1         1
out 5: german_credit       classif.rpart          2         2
out 6: german_credit classif.featureless          3         3
out 7:         sonar      classif.ranger          1         1
out 8:         sonar       classif.rpart          2         2
out 9:         sonar classif.featureless          3         3
# group by levels of learner_id, return columns:
# - mean rank of col 'rank_train' (per level of learner_id)
# - mean rank of col 'rank_test' (per level of learner_id)
ranks = ranks[, .(mrank_train = mean(rank_train), mrank_test = mean(rank_test)), by = learner_id]

# print the final table, ordered by mean rank of AUC test
ranks[order(mrank_test)]
out             learner_id mrank_train mrank_test
out 1:      classif.ranger           1          1
out 2:       classif.rpart           2          2
out 3: classif.featureless           3          3

可视化基准测试结果

与绘制任务、预测或重新取样结果类似,mlr3viz还提供了用于基准测试结果的autoplot()方法。

代码语言:javascript
复制
autoplot(bmr) + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))

img

我们也可以绘制ROC曲线。为此,我们首先需要过滤BenchmarkResult,使其只包含一个Task:

代码语言:javascript
复制
bmr_small = bmr$clone()$filter(task_id = "german_credit")
autoplot(bmr_small, type = "roc")

img

提取结果

一个BenchmarkResult对象本质上是多个ResampleResult对象的集合。由于这些数据存储在聚合的data.table()的列中,我们可以很容易地提取它们:

代码语言:javascript
复制
tab = bmr$aggregate(measures)
rr = tab[task_id == "german_credit" & learner_id == "classif.ranger"]$resample_result[[1]]
print(rr)
out <ResampleResult> of 3 iterations
out * Task: german_credit
out * Learner: classif.ranger
out * Warnings: 0 in 0 iterations
out * Errors: 0 in 0 iterations

我们现在可以使用前一节中所示的方法之一来研究这个重采样,甚至是单个重采样迭代:

代码语言:javascript
复制
measure = msr("classif.auc")
rr$aggregate(measure)
out classif.auc 
out   0.8050245
# get the iteration with worst AUC
perf = rr$score(measure)
i = which.min(perf$classif.auc)

# get the corresponding learner and train set
print(rr$learners[[i]])
out <LearnerClassifRanger:classif.ranger>
out * Model: -
out * Parameters: num.threads=1
out * Packages: ranger
out * Predict Type: prob
out * Feature types: logical, integer, numeric, character, factor, ordered
out * Properties: importance, multiclass, oob_error, twoclass, weights
head(rr$resampling$train_set(i))
out [1]  4  7 10 11 17 23

转换和合并

可以使用转换器as_benchmark_result()将ResampleResult转换为BenchmarkResult。另外,两个BenchmarkResults可以合并到一个更大的结果对象中。

代码语言:javascript
复制
task = tsk("iris")
resampling = rsmp("holdout")$instantiate(task)

rr1 = resample(task, lrn("classif.rpart"), resampling)
out INFO  [21:44:52.627] [mlr3]  Applying learner 'classif.rpart' on task 'iris' (iter 1/1)
rr2 = resample(task, lrn("classif.featureless"), resampling)
out INFO  [21:44:52.751] [mlr3]  Applying learner 'classif.featureless' on task 'iris' (iter 1/1)
# Cast both ResampleResults to BenchmarkResults
bmr1 = as_benchmark_result(rr1)
bmr2 = as_benchmark_result(rr2)

# Merge 2nd BMR into the first BMR
bmr1$combine(bmr2)

bmr1
out <BenchmarkResult> of 2 rows with 2 resampling runs
out  nr task_id          learner_id resampling_id iters warnings errors
out   1    iris       classif.rpart       holdout     1        0      0
out   2    iris classif.featureless       holdout     1        0      0

二分类

目标变量只包含两个类的分类问题称为“二分类”。对于这样的二分类目标变量,你可以在任务创建期间在分类任务对象中指定正类。如果在构造过程中没有显式设置,则阳性类默认为目标变量的第一个水平。

代码语言:javascript
复制
# during construction
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "R")

# switch positive class to level 'M'
task$positive = "M"

ROC和阈值

ROC分析是机器学习的一个子领域,研究对二元预测系统的评价。我们前面已经看到,可以通过访问$confusion字段来检索Prediction的混淆矩阵:

代码语言:javascript
复制
learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)
C = pred$confusion
print(C)
out         truth
out response  M  R
out        M 95 10
out        R 16 87

混淆矩阵包含正确和不正确的类分配的计数,按类标签分组。列显示真实的(观察到的)标签,行显示预测的标签。正数总是在混淆矩阵的第一行或第一行。因此,C11中的元素是我们的模型预测阳性类并正确的次数。类似地,C22中的元素是我们的模型预测负类的次数,并且是正确的。对角线上的元素被称为真阳性(TP)和真阴性(TN)。元素C12是我们错误预测阳性标签的次数,被称为假阳性(FP)。元素C21被称为假阴性(FN)。

我们现在可以将混乱矩阵的行和列规范化,从而得出一些有用的指标。

img

很难同时实现高TPR和低FPR,所以我们使用它们来构建ROC曲线。我们通过分类器的TPR和FPR值来描述分类器,并在坐标系中绘制它们。最好的分类器位于左上角。最差的分类器位于对角线。对角线上的分类器产生随机标签(具有不同的比例)。如果每个阳性的x将被随机分为25%的“阳性”,我们得到的TPR为0.25。如果我们将每个负x随机分配给“正”,我们得到的FPR为0.25。在实践中,我们永远不应该得到对角线以下的分类器,因为将预测的标签倒置将导致对角线上的反射。

评分分类器是产生分数或概率的模型,而不是离散标签。为了从mlr3中的学习者获得概率,你必须为ref(“LearnerClassif”)设置predict_type = "prob"。分类器是否能预测概率在其$predict_types字段中给出。阈值灵活地将测量的概率转换为标签。如果f^(x)>τ else预测0,则预测1(正类)通常情况下,可以使用τ=0.5将概率转换为标签,但对于不平衡或成本敏感的情况,另一个阈值可能更合适。阈值设置之后,可以使用标签上定义的任何度量。

代码语言:javascript
复制
library("mlr3viz")

# TPR vs FPR / Sensitivity vs (1 - Specificity)
autoplot(pred, type = "roc")

img

代码语言:javascript
复制
# Precision vs Recall
autoplot(pred, type = "prc")

img

阈值调整

能够预测出正向分类概率的学习者器常使用简单的规则来确定预测的分类标签:如果概率超过阈值t=0.5,则选择正向分类标签,否则选择负向分类标签。如果模型没有很好地校准或类标签严重不平衡,选择一个不同的阈值可以帮助提高预测性能。

在这里,我们将阈值更改为t=0.2,提高了真实阳性率(TPR)。注意,有了新的阈值,更多来自正类别的观察将被正确地归类为正的标签,但与此同时,真实正阴性率(TNR)下降。根据应用的不同,这可能是一种需要的权衡。

代码语言:javascript
复制
measures = msrs(c("classif.tpr", "classif.tnr"))
pred$confusion
out         truth
out response  M  R
out        M 95 10
out        R 16 87
pred$score(measures)
out classif.tpr classif.tnr 
out   0.8558559   0.8969072
pred$set_threshold(0.2)
pred$confusion
out         truth
out response   M   R
out        M 104  25
out        R   7  72

阈值还可以用mlr3pipelines包进行调优,例如使用PipeOpTuneThreshold[42]

参考资料

[1]

补充高级技术: https://mlr3book.mlr-org.com/extending.html#extending-learners

[2]

嵌套重采样部分和模型优化: https://mlr3book.mlr-org.com/optimization.html#optimization

[3]

R6的vignette: https://r6.r-lib.org/

[4]

介绍: https://r6.r-lib.org/articles/Introduction.html

[5]

TaskClassif: https://mlr3.mlr-org.com/reference/TaskClassif.html

[6]

TaskRegr: https://mlr3.mlr-org.com/reference/TaskRegr.html

[7]

mlr3proba::TaskSurv: https://mlr3proba.mlr-org.com/reference/TaskSurv.html

[8]

mlr3proba: https://mlr3proba.mlr-org.com/

[9]

mlr3proba::TaskDens: https://mlr3proba.mlr-org.com/reference/TaskDens.html

[10]

mlr3proba: https://mlr3proba.mlr-org.com/

[11]

mlr3cluster::TaskClust: https://mlr3cluster.mlr-org.com/reference/TaskClust.html

[12]

mlr3cluster: https://mlr3cluster.mlr-org.com/

[13]

mlr3spatiotempcv::TaskRegrST: https://www.rdocumentation.org/packages/mlr3spatiotempcv/topics/TaskRegrST

[14]

mlr3spatiotempcv::TaskClassifST: https://www.rdocumentation.org/packages/mlr3spatiotempcv/topics/TaskClassifST

[15]

mlr3spatiotempcv: https://mlr3spatiotempcv.mlr-org.com/

[16]

mlr3ordinal: https://github.com/mlr-org/mlr3ordinal

[17]

id]`下找到相应的手册页,例如[mlr_tasks_german_credit: https://mlr3.mlr-org.com/reference/mlr_tasks_german_credit.html

[18]

mlr3viz::autoplot: https://mlr3viz.mlr-org.com/reference/autoplot.TaskClassif.html

[19]

更高级的主题: https://mlr3book.mlr-org.com/extending.html#extending-learners

[20]

mlr_learners_classif.featureless: https://mlr3.mlr-org.com/reference/mlr_learners_classif.featureless.html

[21]

mlr_learners_regr.featureless: https://mlr3.mlr-org.com/reference/mlr_learners_regr.featureless.html

[22]

mlr_learners_classif.rpart: https://mlr3.mlr-org.com/reference/mlr_learners_classif.rpart.html

[23]

mlr_learners_regr.rpart: https://mlr3.mlr-org.com/reference/mlr_learners_regr.rpart.html

[24]

mlr3extralearners: https://github.com/mlr-org/mlr3extralearners/

[25]

这个交互式列表: https://mlr3extralearners.mlr-org.com/articles/learners/list_learners.html

[26]

这里: https://mlr3extralearners.mlr-org.com/articles/learners/learner_status.html

[27]

mlr3的文档: https://mlr3.mlr-org.com/reference/mlr_reflections.html#examples

[28]

mlr3文档: https://mlr3.mlr-org.com/reference/mlr_reflections.html#examples

[29]

交叉验证 - cv: https://mlr3.mlr-org.com/reference/mlr_resamplings_cv.html

[30]

留一交叉验证 - loo: https://mlr3.mlr-org.com/reference/mlr_resamplings_loo.html

[31]

重复交叉验证 - repeated_cv: https://mlr3.mlr-org.com/reference/mlr_resamplings_repeated_cv.html

[32]

bootstrapping - bootstrap: https://mlr3.mlr-org.com/reference/mlr_resamplings_bootstrap.html

[33]

二次抽样 - subsampling: https://mlr3.mlr-org.com/reference/mlr_resamplings_subsampling.html

[34]

holdout - holdout: https://mlr3.mlr-org.com/reference/mlr_resamplings_holdout.html

[35]

样本重采样 - insample: https://mlr3.mlr-org.com/reference/mlr_resamplings_insample.html

[36]

自定义重采样 - custom: https://mlr3.mlr-org.com/reference/mlr_resamplings_custom.html

[37]

mlr3spatiotemporal: https://github.com/mlr-org/mlr3spatiotemporal

[38]

autoplot.ResampleResult(): https://mlr3viz.mlr-org.com/reference/autoplot.ResampleResult.html

[39]

函数参考: https://mlr3spatiotempcv.mlr-org.com/reference

[40]

“时空可视化”: https://mlr3spatiotempcv.mlr-org.com/articles/spatiotemp-viz.html

[41]

帮助页面: https://mlr3.mlr-org.com/reference/benchmark_grid.html

[42]

PipeOpTuneThreshold: https://mlr3pipelines.mlr-org.com/reference/mlr_pipeops_tunethreshold.html

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-09-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 优雅R 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 重采样
    • 设置
      • 实例化
        • 执行
          • 自定义重采样
            • 使用预定义组进行重采样
              • 可视化重采样结果
                • 可视化重采样分区
                • 基准测试
                  • 设计创建
                    • 结果的执行和汇总
                      • 可视化基准测试结果
                        • 提取结果
                          • 转换和合并
                          • 二分类
                            • ROC和阈值
                              • 阈值调整
                                • 参考资料
                                相关产品与服务
                                数据保险箱
                                数据保险箱(Cloud Data Coffer Service,CDCS)为您提供更高安全系数的企业核心数据存储服务。您可以通过自定义过期天数的方法删除数据,避免误删带来的损害,还可以将数据跨地域存储,防止一些不可抗因素导致的数据丢失。数据保险箱支持通过控制台、API 等多样化方式快速简单接入,实现海量数据的存储管理。您可以使用数据保险箱对文件数据进行上传、下载,最终实现数据的安全存储和提取。
                                领券
                                问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档