前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >调整模型以减少错误预测

调整模型以减少错误预测

作者头像
磐创AI
发布2024-03-12 17:53:14
950
发布2024-03-12 17:53:14
举报

介绍

在创建分类模型时,许多算法提供了predict_proba()函数,用于给出观察结果被分类到每个类别的概率。因此,通常会看到如下输出:

代码语言:javascript
复制
[0.925, 0.075]

在上述情况下,模型有92.5%的确信度认为观察结果属于类别0,而只有7.5%的机会属于类别1。

因此,如果我们请求这个同样的模型使用predict()函数来进行二元预测,我们将只会得到结果[0],对吗?

在这个例子中,很可能我们不希望模型将观察结果预测为类别1,因为它只有很小的机会。但是,让我们假设我们对另一个观察结果进行了预测,结果如下:

代码语言:javascript
复制
[0.480, 0.520]

现在怎么办?

很多模型的粗糙切割预测肯定会给我们[1]的结果。但这是最佳决策吗?有时是,有时不是。

在本文中,我们将学习如何使用Python中的catboost包,根据我们对于可接受的假阳性率[FPR]或假阴性率[FNR]的理解,为分类提供最佳的阈值值。

数据科学背景

为了将这篇文章置于上下文中,让我们了解为什么要将阈值从默认的50%更改为其他数字。

我们有一个最好的例子来自医疗保健行业。我们知道许多实验室检查和药物测试依赖于机器学习,以帮助专家得出最精确的答案。毕竟,在这个行业,每个百分点都关系到一个人的生命。

所以让我们说我们正在使用数据来诊断乳腺癌。与利益相关者讨论后,我们达成了一项协议,即我们希望我们的模型最多产生1%的假阴性。我们想要确保一个人是健康的,以便说它对乳腺癌是阴性的。如果有疑虑,我们将将其分类为阳性,并建议进行第二次检查或不同的确认测试。

正如你可能已经得出的结论,这样做将降低我们模型的准确性,因为我们将增加假阳性的数量,但这是可以接受的,因为人们始终可以再次检查并进行其他检查以确认是否是真正的阳性。另一方面,我们不会漏掉任何一个患有疾病并得到阴性结果的人。

编码

你可以在我的GitHub存储库中找到这个练习的全部代码,链接在这里。

https://github.com/gurezende/Studying/tree/master/Python/CatBoost

要安装catboost,使用pip install catboost。下面列出了一些需要的导入。

代码语言:javascript
复制
# Basics
import pandas as pd
import numpy as np
# Visualizations
import plotly.express as px
# CatBoost
from catboost import CatBoostClassifier
from catboost import Pool
# Train test
from sklearn.model_selection import train_test_split
# Metrics
from sklearn.metrics import confusion_matrix, f1_score
数据集

要使用的数据是著名的toy数据集Breast Cancer,原生来自sklearn。

代码语言:javascript
复制
# Dataset
from sklearn.datasets import load_breast_cancer

# Load data
data = load_breast_cancer()

# X
X = pd.DataFrame(data.data, columns=data.feature_names)
# y
y = data.target

正如你可能已经知道或不知道的那样,这个数据集已经准备好了。在建模之前没有太多可以探索或转换的内容。这也不是我们在这里的目的,所以我只会继续进行代码。

训练测试分割

让我们将数据拆分为训练集和测试集。

代码语言:javascript
复制
# Train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(f'Train shapes: {X_train.shape} | {y_train.shape}')
print(f'Test shapes: {X_test.shape} | {y_test.shape}')

Train shapes: (455, 30) | (455,)
Test shapes: (114, 30) | (114,)
第一个模型

接下来,我们将使用CatBoostClassifier训练第一个模型。

代码语言:javascript
复制
# Creating a Pool for training and validation sets
train_pool = Pool( data=X_train, label=y_train)
test_pool = Pool( data=X_test, label=y_test)

# Fit
model = CatBoostClassifier(iterations=500)
model.fit(train_pool, eval_set=test_pool, verbose=100)

在此序列中,F1分数为:97%。

代码语言:javascript
复制
# Predict
preds = model.predict(X_test)
f1_score(y_test, preds)

0.971830985915493

非常好。但是我们的模型有点复杂,因为它有超过30个特征。让我们尝试减少特征数量,而不会失去太多性能。Catboost具有feature_importances_属性,可以帮助我们确定要选择的最佳特征。

代码语言:javascript
复制
# Feature importances to dataframe
feature_importances = (
    pd.DataFrame({'feature': data.feature_names, 
                  'importance': model.feature_importances_})
    .sort_values(by='importance', ascending=False)
)
# Plot
px.bar(feature_importances,
       x= data.feature_names, y=model.feature_importances_,
       height=600, width=1000).update_layout(xaxis={'categoryorder':'total descending'})

在不使用任何复杂技巧的情况下,我只是随意选择了保留任何具有3+重要性的特征。这保留了其中的10个特征,位于红线的左侧。

简化模型

让我们训练更简化的模型并评估分数。

代码语言:javascript
复制
# Simpler model
features = feature_importances.feature[:10]
# Creating a Pool for training and validation sets
train_pool2 = Pool( data=X_train[features], label=y_train)
test_pool2 = Pool( data=X_test[features], label=y_test)

# Model 
model2 = CatBoostClassifier(iterations=600)
model2.fit(train_pool2, eval_set=test_pool2, verbose=100)

# Score
preds2 = model2.predict(test_pool2)
f1_score(y_test, preds2)

0.979020979020979

不错。相同的F1分数:97%。

由于我们正在处理医学诊断,我们不应该对假阴性很宽容。我们希望我们的模型只有在非常确定患者实际上是健康的情况下才会说患者是健康的。

但我们知道,CatBoost算法使用标准的50%阈值来预测结果。这意味着,如果正面概率低于50%,患者将被诊断为乳腺癌阴性。但我们可以调整该数字,以使其仅在更高程度的确定性下给出负面预测。

让我们看看如何做到这一点。以下是我们模型的一些预测。

代码语言:javascript
复制
# Regular predictions
default_preds = pd.DataFrame(model2.predict_proba(test_pool2).round(3))
default_preds['classification'] = model2.predict(test_pool2)
default_preds.sample(10)

请注意,观察结果82有63.4%的机会是阴性,但也有36%的机会是阳性,从医学标准来看可能被认为是高的。我们希望将此案例分类为阳性,即使知道它可能是错误的。因此,我们可以将这个人送去进行以后的测试。所以让我们将我们的假阴性率[FNR]容忍度设置为1%。

代码语言:javascript
复制
from catboost.utils import select_threshold
# Finding the right threshold
print(select_threshold(model2, test_pool2, FNR=0.01))

0.1420309044590601

太好了。现在,CatBoost计算出了新的阈值,被分类为负的阈值为1-0.142 = 0.858。简单来说,类别0的概率必须超过85.8%才能被标记为0,否则将被分类为1。

好的,所以我创建了一个自定义函数predict_threshold(df,threshold,rate_type)(请访问我的GitHub查看代码),该函数以解释变量的数据框、所需的阈值和速率类型(FNR或FPR)作为输入,并返回使用新切割的分类。

代码语言:javascript
复制
# Predict
new_predictions = predict_threshold(df= test_pool2, 
                                    threshold= 0.01, 
                                    rate_type= "FNR")

# Standard predictions
normal_predictions = model2.predict(test_pool2)

同样的观察结果,在索引82处,先前以63%的概率被分类为阴性(0)的现在被分类为阳性(1)。

这是使用标准50%阈值的混淆矩阵。

代码语言:javascript
复制
# Confusion Matrix 50% standard threshold
pd.DataFrame( confusion_matrix(y_true=y_test, y_pred=normal_predictions) )

这是使用更新后的阈值的新分类。

代码语言:javascript
复制
# Confusion Matrix 1% of false negatives allowed threshold
pd.DataFrame( confusion_matrix(y_true=y_test, y_pred=new_predictions) )

观察底部左侧的单元格[true=1,pred=0,FN],来自两个混淆矩阵。顶部的一个显示了一个假阴性。这个人实际上患有癌症,但模型将其分类为阴性。在新模型中解决了这个问题,没有假阴性。另一方面,我们也增加了一个假阳性。因此,这一切都是关于权衡,就像数据科学中的许多其他事情一样。

FPR(I型错误)和FNR(II型错误)是互补的。当你降低一个时,必然会增加另一个。

如果项目需要非常低数量的假阳性,同样的方法也可以用来降低FPR。

总结

总之,在这篇文章中,我们学到了以下内容:

  • 分类的默认切割阈值是概率的50%。
  • 可以调整此数字以减少假阳性或假阴性的数量。
  • FPR(I型错误)和FNR(II型错误)是互补的。降低一个将增加另一个。
  • 使用catboost包计算概率切割的阈值值。
  • 例如:predict_threshold(test_pool2,threshold=0.01,rate_type="FNR")

参考资料

https://github.com/catboost/tutorials/blob/master/events/2019_pydata_london/pydata_london_2019.ipynb

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-03-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 磐创AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 介绍
  • 数据科学背景
  • 编码
    • 数据集
      • 训练测试分割
        • 第一个模型
          • 简化模型
          • 总结
          相关产品与服务
          腾讯云服务器利旧
          云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档