前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >🤩 catboost | 分量变量的梯度提升机器学习算法怎么用!?~(附shap解析!~)

🤩 catboost | 分量变量的梯度提升机器学习算法怎么用!?~(附shap解析!~)

作者头像
生信漫卷
发布2024-04-10 14:39:07
910
发布2024-04-10 14:39:07
举报

撒西不理达纳,各位!~😭

终于从创伤转出来了,而且在创伤的半年里正好赶上了国自然、省自然的提交。🙃

实在是忙的不行,根本没有时间做自己的事情。🫠

现在转完出来了,也可以写点自己感兴趣的东西了。😜

大家有什么推荐的有趣的包吗,分享一下呀!~🤒

接着之前的机器学习吧,今天是Catboost。🙊

CatBoostXGBoostLightGBM并称为GBDT的三大主流神器,都是在GBDT算法框架下的一种改进实现。

CatBoost是一种基于对称决策树(oblivious trees)为机器学习器实现的参数较少、支持分类变量高准确性GBDT框架,主要解决的痛点是高效合理地处理分类特征。🥳

2用到的包

代码语言:javascript
复制
rm(list = ls())
library(tidyverse)
library(catboost)
library(survival)

3示例数据

代码语言:javascript
复制
dat <- lung %>% 
  dplyr::select(c(status,sex, ph.ecog), everything())

DT::datatable(dat)

4划分训练集和测试集

代码语言:javascript
复制
train_indices <- sample(x = 1:nrow(dat), size = 0.7 * nrow(dat), replace = F)

test_indices <- sample(setdiff(1:nrow(dat), train_indices), size = 0.3 * nrow(dat), replace = F)

train_data <- dat[train_indices, ]

test_data <- dat[test_indices, ]

5定义训练集和测试集

代码语言:javascript
复制
trainpool <- catboost.load_pool(data=train_data[,-1],label = as.integer(train_data[,1]),cat_features=c(2,3))

testpool <- catboost.load_pool(data=test_data[,-1],label = as.integer(test_data[,1]),cat_features=c(2,3))

6设置算法参数

代码语言:javascript
复制
params <- list(iterations = 1000,  
               loss_function = 'Logloss', 
               random_seed=123, 
               learning_rate = 0.01, 
               verbose = 0,  
               use_best_model = T, 
               od_type = 'Iter', 
               od_wait = 10  
               )

7模型拟合

代码语言:javascript
复制
cat_model <- catboost.train(trainpool,testpool,params)

cat_model

8模型预测和评估

代码语言:javascript
复制
pred <- catboost.predict(cat_model, 
                          testpool, 
                          prediction_type = "Probability")

9混淆矩阵

代码语言:javascript
复制
ModelMetrics::confusionMatrix(test_data[,1], pred, cutoff = 0.7)

10绘制ROC曲线

代码语言:javascript
复制
library(pROC)

cat_roc<- roc(test_data[,1], pred,
              aur = T,
              ci = T,
              smooth = T)

ggroc(cat_roc, legacy.axes = T)+
    geom_segment(aes(x = 0, xend = 1, y = 0, yend = 1), color="darkgrey", linetype=4)+
    theme_bw()+
    ggtitle('ROC') + 
    ggsci::scale_color_npg()+
    annotate("text",x=0.75,y=0.125,label=paste("AUC = ", round(cat_roc$auc,3)))

11基于SHAP值进行模型解释

代码语言:javascript
复制
library(shapviz)

shapviz.catboost.Model <- function(object, X_pred, X = X_pred, collapse = NULL, ...) {
  if (!requireNamespace("catboost", quietly = TRUE)) {
    stop("Package 'catboost' not installed")
  }
  stopifnot(
    "X must be a matrix or data.frame. It can't be an object of class catboost.Pool" =
      is.matrix(X) || is.data.frame(X),
    "X_pred must be a matrix, a data.frame, or a catboost.Pool" =
      is.matrix(X_pred) || is.data.frame(X_pred) || inherits(X_pred, "catboost.Pool"),
    "X_pred must have column names" = !is.null(colnames(X_pred))
  )
  
  if (!inherits(X_pred, "catboost.Pool")) {
    X_pred <- catboost.load_pool(X_pred)
  }

  S <- catboost.get_feature_importance(object, X_pred, type = "ShapValues", ...)

  pp <- ncol(X_pred) + 1L
  baseline <- S[1L, pp]
  S <- S[, -pp, drop = FALSE]
  colnames(S) <- colnames(X_pred)
  shapviz(S, X = X, baseline = baseline, collapse = collapse)
}

代码语言:javascript
复制
shp <- shapviz(cat_model, X_pred = test_data[,-1])

shp

12可视化

这里之前都介绍过有哪些可视化方法了,大家不清楚的可以翻看之前的推文。🥳

代码语言:javascript
复制
sv_waterfall(shp,row_id = 1)

代码语言:javascript
复制
sv_force(shp,row_id = 1)

代码语言:javascript
复制
sv_importance(shp,kind = "beeswarm")

代码语言:javascript
复制
sv_importance(shp,fill="#F2613F")

代码语言:javascript
复制
sv_dependence(shp, 
              "ph.ecog", 
              alpha = 0.5,
              size = 1.5,
              color_var = NULL)

代码语言:javascript
复制
sv_dependence(shp, 
              v = c("sex",
                    "age",
                    "ph.ecog",
                    "ph.karno"))

最后祝大家早日不卷!~

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-03-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 生信漫卷 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 2用到的包
  • 3示例数据
  • 4划分训练集和测试集
  • 5定义训练集和测试集
  • 6设置算法参数
  • 7模型拟合
  • 8模型预测和评估
  • 9混淆矩阵
  • 10绘制ROC曲线
  • 11基于SHAP值进行模型解释
  • 12可视化
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档