前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >python中实现ROC curve

python中实现ROC curve

作者头像
生信编程日常
发布2020-09-03 15:24:18
9370
发布2020-09-03 15:24:18
举报

以下是使用scikit learn预测、做出决策边界并画出ROC曲线的一个示例,以鸢尾花数据集为例。

1. 导入鸢尾花的数据
代码语言:javascript
复制
import numpy as np
import matplotlib.pyplot as plt
import warnings
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import metrics 
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
warnings.filterwarnings('ignore')

iris = datasets.load_iris()

X = iris.data
y = iris.target

X = X[y<2,:2] 
y = y[y<2] # 方便可视化
2. 标准化数据并使用SVM预测
代码语言:javascript
复制
standardScaler = StandardScaler()
standardScaler.fit(X)
X_standard = standardScaler.transform(X)

X_train, X_test, y_train, y_test = train_test_split(X_standard, y, test_size=0.75, random_state=1)
svc2 = LinearSVC(C=0.001)
svc2.fit(X_train, y_train)
3. 做出决策边界
代码语言:javascript
复制
# 决策边界函数
def plot_boundary(model, X, y):
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    
    h = .02  # step size in the mesh
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure(1, figsize=(4, 3))
    plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Set3)

    # Plot also the training points
    plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=plt.cm.Greens)
    plt.show()

plot_boundary(svc2, X_train, y_train)
4. ROC曲线
代码语言:javascript
复制
y_pred_proba = poly_kernel_svc.predict_proba(X_test)[::,1]
fpr, tpr, _ = metrics.roc_curve(y_test,  y_pred_proba)
auc = metrics.roc_auc_score(y_test, y_pred_proba)

plt.plot(fpr,tpr,label='SVM model AUC %0.2f' % auc, color='blue', lw = 2)
plt.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating Curve')
plt.legend(loc="lower right")
plt.show()

示例数据集比较简单,所以效果非常好,一般的数据集画出的效果如下:

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 导入鸢尾花的数据
  • 2. 标准化数据并使用SVM预测
  • 3. 做出决策边界
  • 4. ROC曲线
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档