前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >mlr3_训练和测试

mlr3_训练和测试

作者头像
火星娃统计
发布2021-02-05 16:39:11
7720
发布2021-02-05 16:39:11
举报
文章被收录于专栏:火星娃统计火星娃统计

mlr3_训练和测试

概述

之前的章节中,我们已经建立了task和learner,接下来利用这两个R6对象,建立模型,并使用新的数据集对模型进行评估

建立task和learner

这里使用简单的tsk和lrn方法建立

代码语言:javascript
复制
task = tsk("sonar")
learner = lrn("classif.rpart")

设置训练和测试数据

这里设置的其实是task里面数据的行数目

代码语言:javascript
复制
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

训练learner

$model是learner中用来存储训练好的模型

代码语言:javascript
复制
# 可以看到目前是没有模型训练好的
learner$model
## NULL

接下来使用任务来训练learner

代码语言:javascript
复制
# 这里使用row_ids选择训练数据
learner$train(task, row_ids = train_set)
# 训练完成后查看模型

print(learner$model)

预测

使用剩余的数据进行预测 predict

代码语言:javascript
复制
# 返回每一个个案的预测结果
prediction = learner$predict(task, row_ids = test_set)
## <PredictionClassif> for 42 observations:
##     row_id truth response
##          2     R        R
##          6     R        R
##         12     R        M
## ---                      
##        191     M        M
##        199     M        M
##        204     M        M

# 为了提取预测后的数据,最好的办法是转换为data.table
head(as.data.table(prediction))

# 同时,我们需要计算混淆矩阵

prediction$confusion
##         truth
## response  M  R
##        M 15  3
##        R  8 16

改变预测的类型

这个部分主要是计算每一种类型的概率,有时候用于roc曲线的绘制

代码语言:javascript
复制
learner$predict_type = "prob"
# 重新训练
learner$train(task, row_ids = train_set)

# 重新预测
prediction = learner$predict(task, row_ids = test_set)
# 查看结果
head(as.data.table(prediction))
##    row_id truth response prob.M  prob.R
## 1:      2     R        R 0.2222 0.77778
## 2:      6     R        R 0.2222 0.77778
## 3:     12     R        M 0.9375 0.06250
## 4:     13     R        R 0.1429 0.85714
## 5:     30     R        R 0.2222 0.77778
## 6:     31     R        M 0.9535 0.04651

可以看到,里面出现了新的两列,用于描述各自的概率大小

绘制预测图

代码语言:javascript
复制
library("mlr3viz")
task = tsk("sonar")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
# 绘制默认图
autoplot(prediction)
# 绘制roc图
autoplot(prediction, type = "roc")

对于回归任务

代码语言:javascript
复制
library("mlr3viz")
library("mlr3learners")
task = tsk("mtcars")
learner = lrn("regr.lm")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

模型评估

mlr3 自带一系列的评估方法,如

代码语言:javascript
复制
mlr_measures
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
##   regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
##   regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
##   regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
##   selected_features, time_both, time_predict, time_train

# 使用msr获取评估的方法,这里是准确率
measure = msr("classif.acc")
prediction$score(measure)

## classif.acc 
##       0.875

结束语

到这里基本上mlr3的主要内容都已经更新完毕,后面涉及冲抽样,模型优化等问题 love&peace

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-01-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 火星娃统计 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • mlr3_训练和测试
    • 概述
      • 建立task和learner
        • 设置训练和测试数据
          • 训练learner
            • 预测
              • 改变预测的类型
              • 绘制预测图
              • 模型评估
            • 结束语
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档