前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >R语言机器学习caret-09:决策树的小例子

R语言机器学习caret-09:决策树的小例子

作者头像
医学和生信笔记
发布2023-08-30 11:55:25
2630
发布2023-08-30 11:55:25
举报
文章被收录于专栏:医学和生信笔记

前面已经铺垫了超多caret的基础知识,所以下面就是具体的实战演示了。

今天给大家演示下caret做决策树的例子,但其实并不是很好用,还不如之前介绍的直接使用rpart,或者tidymodelsmlr3

加载数据和R包

代码语言:javascript
复制
library(caret)
代码语言:javascript
复制
library(modeldata)
str(penguins)
代码语言:javascript
复制
## tibble [344 × 7] (S3: tbl_df/tbl/data.frame)
##  $ species          : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
##  $ island           : Factor w/ 3 levels "Biscoe","Dream",..: 3 3 3 3 3 3 3 3 3 3 ...
##  $ bill_length_mm   : num [1:344] 39.1 39.5 40.3 NA 36.7 39.3 38.9 39.2 34.1 42 ...
##  $ bill_depth_mm    : num [1:344] 18.7 17.4 18 NA 19.3 20.6 17.8 19.6 18.1 20.2 ...
##  $ flipper_length_mm: int [1:344] 181 186 195 NA 193 190 181 195 193 190 ...
##  $ body_mass_g      : int [1:344] 3750 3800 3250 NA 3450 3650 3625 4675 3475 4250 ...
##  $ sex              : Factor w/ 2 levels "female","male": 2 1 1 NA 1 2 1 2 NA NA ...

用这个企鹅数据集做演示。一共有377行,7列,其中species是结果变量,三分类,因子型,其余列是预测变量。

首先还是简单探索下数据:

代码语言:javascript
复制
library(GGally)
代码语言:javascript
复制
ggbivariate(penguins, "species")
代码语言:javascript
复制
## Warning: Removed 2 rows containing non-finite values (`stat_boxplot()`).
## Removed 2 rows containing non-finite values (`stat_boxplot()`).
## Removed 2 rows containing non-finite values (`stat_boxplot()`).
## Removed 2 rows containing non-finite values (`stat_boxplot()`).
代码语言:javascript
复制
## Warning: Removed 11 rows containing non-finite values (`stat_prop()`).

plot of chunk unnamed-chunk-2

这个数据还可以,sex有一些缺失值,其他看着还行。

预处理

做个简单的预处理,连续性变量中心化,分类变量设置哑变量。预处理这部分不如tidymodels好用。

代码语言:javascript
复制
cent <- preProcess(penguins, method = c("center","scale"))
pen_df <- predict(cent, newdata = penguins)

class <- pen_df$species
dummy <- dummyVars(species ~. , data=pen_df)
pen_df <- predict(dummy, newdata = pen_df)
代码语言:javascript
复制
pen_df <- as.data.frame(pen_df)
pen_df$species <- class

str(pen_df)
代码语言:javascript
复制
## 'data.frame': 344 obs. of  10 variables:
##  $ island.Biscoe    : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ island.Dream     : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ island.Torgersen : num  1 1 1 1 1 1 1 1 1 1 ...
##  $ bill_length_mm   : num  -0.883 -0.81 -0.663 NA -1.323 ...
##  $ bill_depth_mm    : num  0.784 0.126 0.43 NA 1.088 ...
##  $ flipper_length_mm: num  -1.416 -1.061 -0.421 NA -0.563 ...
##  $ body_mass_g      : num  -0.563 -0.501 -1.187 NA -0.937 ...
##  $ sex.female       : num  0 1 1 NA 1 0 1 0 NA NA ...
##  $ sex.male         : num  1 0 0 NA 0 1 0 1 NA NA ...
##  $ species          : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...

建立模型

caret是可以调用rpart包实现决策树的,但是只支持一个超参数cp,感觉不如之前介绍的好用:

代码语言:javascript
复制
# 设定种子数
set.seed(3456)

# 根据结果变量的类别多少划分
trainIndex <- createDataPartition(pen_df$species, p = 0.7, 
                                  list = FALSE)
head(trainIndex)
代码语言:javascript
复制
##      Resample1
## [1,]         2
## [2,]         7
## [3,]         8
## [4,]         9
## [5,]        10
## [6,]        12
代码语言:javascript
复制
penTrain <- pen_df[ trainIndex,]
penTest  <- pen_df[-trainIndex,]

dim(penTrain)
代码语言:javascript
复制
## [1] 242  10
代码语言:javascript
复制
dim(penTest)
代码语言:javascript
复制
## [1] 102  10
代码语言:javascript
复制
# 选择重抽样方法,10折交叉验证
trControl <- trainControl(method = "cv", number = 10,
                          classProbs = T
                          )

set.seed(8)
tree_fit <- train(x = pen_df[,-1],
               y = pen_df$species,
               method = "rpart",
               trControl = trControl,
               tuneLength = 20
               )

tree_fit
代码语言:javascript
复制
## CART 
## 
## 344 samples
##   9 predictor
##   3 classes: 'Adelie', 'Chinstrap', 'Gentoo' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 310, 309, 310, 309, 310, 310, ... 
## Resampling results across tuning parameters:
## 
##   cp          Accuracy   Kappa    
##   0.00000000  1.0000000  1.0000000
##   0.03399123  1.0000000  1.0000000
##   0.06798246  1.0000000  1.0000000
##   0.10197368  1.0000000  1.0000000
##   0.13596491  1.0000000  1.0000000
##   0.16995614  1.0000000  1.0000000
##   0.20394737  1.0000000  1.0000000
##   0.23793860  1.0000000  1.0000000
##   0.27192982  1.0000000  1.0000000
##   0.30592105  1.0000000  1.0000000
##   0.33991228  1.0000000  1.0000000
##   0.37390351  0.8023203  0.6725971
##   0.40789474  0.8023203  0.6725971
##   0.44188596  0.8023203  0.6725971
##   0.47587719  0.8023203  0.6725971
##   0.50986842  0.8023203  0.6725971
##   0.54385965  0.8023203  0.6725971
##   0.57785088  0.8023203  0.6725971
##   0.61184211  0.8023203  0.6725971
##   0.64583333  0.6525957  0.3976620
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.3399123.
代码语言:javascript
复制
plot(tree_fit)

plot of chunk unnamed-chunk-7

代码语言:javascript
复制
library(rpart.plot)
代码语言:javascript
复制
## Loading required package: rpart
代码语言:javascript
复制
rpart.plot(tree_fit$finalModel)

plot of chunk unnamed-chunk-8

代码语言:javascript
复制
library(treeheatr)
代码语言:javascript
复制
heat_tree(partykit::as.party(tree_fit$finalModel))

plot of chunk unnamed-chunk-9

其他图形就不演示了,大家可以参考我们之前的推文。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-05-16,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 医学和生信笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 加载数据和R包
  • 预处理
  • 建立模型
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档