前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >mlr3_重抽样

mlr3_重抽样

作者头像
火星娃统计
发布2021-02-05 16:39:39
8580
发布2021-02-05 16:39:39
举报
文章被收录于专栏:火星娃统计火星娃统计

mlr3_重抽样

概述

mlr3中包含的重抽样方法

  • cross validation ("cv"):交叉验证
  • leave-one-out cross validation ("loo"):留一验证
  • repeated cross validation ("repeated_cv") :重复交叉验证
  • bootstrapping ("bootstrap"):bootstrap
  • subsampling ("subsampling"):下采样
  • holdout ("holdout"):相当于3:7的分割方式
  • in-sample resampling ("insample")
  • custom resampling ("custom"):自定义重抽样

设置任务

代码语言:javascript
复制
task = tsk("iris")
learner = lrn("classif.rpart")
# 查看mlr的重抽样方法有哪些
as.data.table(mlr_resamplings)
##            key        params iters
## 1:   bootstrap repeats,ratio    30
## 2:      custom                   0
## 3:          cv         folds    10
## 4:     holdout         ratio     1
## 5:    insample                   1
## 6:         loo                  NA
## 7: repeated_cv repeats,folds   100
## 8: subsampling repeats,ratio    30

# 通过rsmp函数提取采样方法
resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

这里$is_instantiated是false,这表示,我们没有将采样方法设置再数据集中。同时这里默认的采样比例是0.6667,可以通过下面两种方式更改

代码语言:javascript
复制
resampling$param_set$values = list(ratio = 0.8)
rsmp("holdout", ratio = 0.8)

实例化

通过instantiate函数对任务进行分组

代码语言:javascript
复制
resampling = rsmp("cv", folds = 3L)
resampling$instantiate(task)
resampling$iters
## [1] 3
# 查看训练和测试集的id号
str(resampling$train_set(1))
##  int [1:100] 2 3 4 5 10 12 14 15 18 19 ...
str(resampling$test_set(1))
##  int [1:50] 7 9 13 16 17 21 22 25 35 37 ...

执行重抽样

将task、learner和resample组合起来形成一个新的对象,

代码语言:javascript
复制
task = tsk("pima")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3L)
# 将三者组合起来
rr = resample(task, learner, resampling, store_models = TRUE)
print(rr)

## <ResampleResult> of 3 iterations
## * Task: pima
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

#通过aggregate函数将多个结果平均
rr$aggregate(msr("classif.ce"))
## classif.ce 
##     0.2721


# 查看每个模型的性能
rr$score(msr("classif.ce"))
##                 task task_id                   learner    learner_id
## 1: <TaskClassif[45]>    pima <LearnerClassifRpart[34]> classif.rpart
## 2: <TaskClassif[45]>    pima <LearnerClassifRpart[34]> classif.rpart
## 3: <TaskClassif[45]>    pima <LearnerClassifRpart[34]> classif.rpart
##            resampling resampling_id iteration              prediction
## 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[19]>
## 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[19]>
## 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[19]>
##    classif.ce
## 1:     0.3164
## 2:     0.2617
## 3:     0.2383

查看迭代结果

代码语言:javascript
复制
# 查看错误和警告
rr$warnings
rr$errors
# 查看抽样策略
rr$resampling
# 产看迭代次数
rr$resampling$iters
# 查看第一测试集和训练集
str(rr$resampling$test_set(1))
str(rr$resampling$train_set(1))

# 查看指定的学习器
lrn = rr$learners[[1]]
lrn$model

# 提取预测结果;这里将所有预测整合再一个表中
rr$prediction() 
# 提取第一次迭代结果
rr$predictions()[[1]]

自定义抽样

自己选择样本的编号,进行抽样,傻子才这样做

代码语言:javascript
复制
resampling = rsmp("custom")
resampling$instantiate(task,
  train = list(c(1:10, 51:60, 101:110)),
  test = list(c(11:20, 61:70, 111:120))
)
resampling$iters

绘制结果

代码语言:javascript
复制
library("mlr3viz")
autoplot(rr)

hu

绘制roc曲线

autoplot(rr, type = "roc")

结束语

对于重抽样的操作,建议在高性能的服务器上进行,或者测试数据较少或者特征较少的数据集。

love&peace

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-01-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 火星娃统计 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • mlr3_重抽样
    • 概述
      • 设置任务
        • 实例化
          • 执行重抽样
            • 自定义抽样
              • 绘制结果
                • 结束语
                领券
                问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档