前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >预测三分类变量模型的ROC介绍

预测三分类变量模型的ROC介绍

作者头像
Jamesjin63
发布2022-11-03 14:57:32
9620
发布2022-11-03 14:57:32
举报
文章被收录于专栏:EpiHubEpiHub

我们对Logistics回归很熟悉,预测变量y为二分类变量,然后对预测结果进行评估,会用到2*2 Matrix,计算灵敏度、特异度等及ROC曲线,判断模型预测准确性。

但是如果遇到y为三分类变量,那么会得到3*3 Matrix 那该选用什么指标进行评估呢?

答案:macro-average and micro-average

接下来,我们将介绍如何建立模型预测三分类变量,及对模型准确性进行评估。

1.模型构建

我们根据 iris数据集中的 Species三分类变量,建立多元回归模型,根据花的特征预测Species种类,其中我们添加xv新变量; 首先我们对 iris数据集进行拆分成 Training与Testing两个数据集,Training用于模型构建。

代码语言:javascript
复制
# https://stackoverflow.com/questions/59205776/random-forest-svm-and-multinomial-logistic-regression-with-r

library(tidyverse)
library(randomForest)
set.seed(123)
head(iris)
df=iris %>% mutate(xv=as.factor(ifelse(rnorm(150,3,4)<3,"Yes","No"))) # new predictor
## split da
split1= sample(c(rep(0, 0.7 * nrow(df)), rep(1, 0.3 * nrow(df)))) 
train <- df[split1 == 0, ]   
test <- df[split1 == 1, ]  

## Model LM
library(nnet)
fit1 = multinom(Species~.,data=train)
summary(fit1)

fit1结果解读比二分类多一个分类。参照OR的解释。

2.观测值VS预测值-Matrix

构建完模型fit1后,需要对testing 数据进行预测,然后我们创建一个真实值与预测值的矩阵。

代码语言:javascript
复制
## Model Prediction
pre=predict(fit1,test)
dfpre=tibble(actual=test$Species,predicted=pre)
table(dfpre)

            predicted
actual       setosa versicolor virginica
  setosa         13          0         0
  versicolor      0         13         0
  virginica       0          1        18

3.Performance Measures

接下来对该矩阵进行分析,需要预先对矩阵的一些参数进行计算;为后续的 Accuracy, precision, F1等。 Source: https://www.r-bloggers.com/2016/03/computing-classification-evaluation-metrics-in-r/

代码语言:javascript
复制
## basic variables
n = sum(cm) # number of instances
nc = nrow(cm) # number of classes
diag = diag(cm) # number of correctly classified instances per class 
rowsums = apply(cm, 1, sum) # number of instances per class
colsums = apply(cm, 2, sum) # number of predictions per class
p = rowsums / n # distribution of instances over the actual classes
q = colsums / n # distribution of instances over the predicted classes

## Accuracy
accuracy = sum(diag) / n 
accuracy 
precision = diag / colsums 
recall = diag / rowsums 
f1 = 2 * precision * recall / (precision + recall) 
data.frame(precision, recall, f1) 

## Macro
macroPrecision = mean(precision)
macroRecall = mean(recall)
macroF1 = mean(f1)
data.frame(macroPrecision, macroRecall, macroF1)

上述计算过程比较繁琐,有没有一键输出的,有!接下来是一键输出

3.1 Performance Measures 一键输出

这里使用 Evaluate 函数进行输出,其中Evaluate代码见连接或后台私信。 Source:https://github.com/saidbleik/Evaluation/blob/master/eval.R

代码语言:javascript
复制
results = Evaluate(actual=df3$ya, predicted=xa)
results
## output
$ConfusionMatrix
            Predicted
Actual       setosa versicolor virginica
  setosa         13          0         0
  versicolor      0         13         0
  virginica       0          1        18

$Metrics
                                setosa versicolor virginica
Accuracy                     0.9777778  0.9777778 0.9777778
Precision                    1.0000000  0.9285714 1.0000000
Recall                       1.0000000  1.0000000 0.9473684
F1                           1.0000000  0.9629630 0.9729730
MacroAvgPrecision            0.9761905  0.9761905 0.9761905
MacroAvgRecall               0.9824561  0.9824561 0.9824561
MacroAvgF1                   0.9786453  0.9786453 0.9786453
AvgAccuracy                  0.9851852  0.9851852 0.9851852
MicroAvgPrecision            0.9777778  0.9777778 0.9777778
MicroAvgRecall               0.9777778  0.9777778 0.9777778
MicroAvgF1                   0.9777778  0.9777778 0.9777778
MajorityClassAccuracy        0.4222222  0.4222222 0.4222222
MajorityClassPrecision       0.0000000  0.0000000 0.4222222
MajorityClassRecall          0.0000000  0.0000000 1.0000000
MajorityClassF1              0.0000000  0.0000000 0.5937500
Kappa                        0.9662162  0.9662162 0.9662162
RandomGuessAccuracy          0.3333333  0.3333333 0.3333333
RandomGuessPrecision         0.2888889  0.2888889 0.4222222
RandomGuessRecall            0.3333333  0.3333333 0.3333333
RandomGuessF1                0.3095238  0.3095238 0.3725490
RandomWeightedGuessAccuracy  0.3451852  0.3451852 0.3451852
RandomWeightedGuessPrecision 0.2888889  0.2888889 0.4222222
RandomWeightedGuessRecall    0.2888889  0.2888889 0.4222222
RandomWeightedGuessF1        0.2888889  0.2888889 0.4222222

4.ROC Curves Across Multi-Class Classifications

当然我们也可以绘制 The ROC curves of micro-average and macro-average, indicating the overall distinguishing ability of the three-class classification. 但是需要分几个步骤进行:

  1. 我们原来的预测值输出是Species的分类结果,这部分我们需要输出对各种类别的概率值。
  2. 哑变量设置,将我们的 testing数据集中Species分类改成哑变量
  3. 计算 macro/micro。并绘制ROC曲线 Source:https://mran.microsoft.com/snapshot/2018-02-12/web/packages/multiROC/vignettes/my-vignette.html

当然这里我们需要提到一个概念:One-vs-all confusion matrices 即针对三个变量转换成,setosa与非setosa;这样就可以得到setosa的ROC

代码语言:javascript
复制
library(multiROC)
actual=dummies::dummy.data.frame(test %>% select(Species),               
                                 sep = "_",            
                                 dummy.classes = "factor"  )

predicted=predict(fit1,test,type = "prob")# with probability


test_data=cbind(actual,predicted)
colnames(test_data)=c("setosa_true","versicolor_true" ,"virginica_true",
                      "setosa_pred_m1","versicolor_pred_m1","virginica_pred_m1")
res <- multi_roc(test_data, force_diag=T)
res

res里面存储了我们想要的信息,接下来对res进行提取各组的Specificity 与Sensitivity,绘制ROC曲线。

代码语言:javascript
复制
#### ggplot ROC
n_method <- length(unique(res$Methods))
n_group <- length(unique(res$Groups))
res_df <- data.frame(Specificity= numeric(0), Sensitivity= numeric(0), Group = character(0), AUC = numeric(0), Method = character(0))
for (i in 1:n_method) {
  for (j in 1:n_group) {
    temp_data_1 <- data.frame(Specificity=res$Specificity[[i]][j],
                              Sensitivity=res$Sensitivity[[i]][j],
                              Group=unique(res$Groups)[j],
                              AUC=res$AUC[[i]][j],
                              Method = unique(res$Methods)[i])
    colnames(temp_data_1) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
    res_df <- rbind(res_df, temp_data_1)
    
  }
  temp_data_2 <- data.frame(Specificity=res$Specificity[[i]][n_group+1],
                            Sensitivity=res$Sensitivity[[i]][n_group+1],
                            Group= "Macro",
                            AUC=res$AUC[[i]][n_group+1],
                            Method = unique(res$Methods)[i])
  temp_data_3 <- data.frame(Specificity=res$Specificity[[i]][n_group+2],
                            Sensitivity=res$Sensitivity[[i]][n_group+2],
                            Group= "Micro",
                            AUC=res$AUC[[i]][n_group+2],
                            Method = unique(res$Methods)[i])
  colnames(temp_data_2) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
  colnames(temp_data_3) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
  res_df <- rbind(res_df, temp_data_2)
  res_df <- rbind(res_df, temp_data_3)
}

ggplot(res_df, aes(x = 1-Specificity, y=Sensitivity)) + 
  geom_path(aes(color = Group, linetype=Method)) + 
  geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), colour='grey', linetype = 'dotdash') + 
  theme_bw() + 
  theme(plot.title = element_text(hjust = 0.5), 
        legend.justification=c(1, 0), 
        legend.position=c(.95, .05), 
        legend.title=element_blank(), 
        legend.background = element_rect(fill=NULL, size=0.5, linetype="solid", colour ="black"))

ggsave("ROC-SVM.pdf",width = 16,height = 12,dpi=500)

image.png

最后,附上RF,SVM的模型

代码语言:javascript
复制
#### 2.SVM
library(e1071)
fitsvm = svm(ya~ ., data = df2,probability=TRUE)
summary(fitsvm)


#### 3.RF
library(randomForest)
fitrf = randomForest(ya~ ., 
                   data = df2,
                   ntree = 300, # parameter setting
                   mtry = 8,
                   importance = TRUE,
                   proximity = TRUE)

参考: Performance Measures for Multi-Class Problems-- https://www.datascienceblog.net/post/machine-learning/performance-measures-multi-class-problems/

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-03-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.模型构建
  • 2.观测值VS预测值-Matrix
  • 3.Performance Measures
  • 3.1 Performance Measures 一键输出
  • 4.ROC Curves Across Multi-Class Classifications
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档