首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >计算多类的sklearn.roc_auc_score

计算多类的sklearn.roc_auc_score
EN

Stack Overflow用户
提问于 2016-09-25 18:17:43
回答 6查看 54.4K关注 0票数 27

我想为我的分类器计算AUC,精度,准确度。我正在做有监督的学习:

这是我的工作代码。这段代码适用于二进制类,但不适用于多类。请假设您有一个包含二进制类的数据帧:

代码语言:javascript
运行
复制
sample_features_dataframe = self._get_sample_features_dataframe()
labeled_sample_features_dataframe = retrieve_labeled_sample_dataframe(sample_features_dataframe)
labeled_sample_features_dataframe, binary_class_series, multi_class_series = self._prepare_dataframe_for_learning(labeled_sample_features_dataframe)

k = 10
k_folds = StratifiedKFold(binary_class_series, k)
for train_indexes, test_indexes in k_folds:
    train_set_dataframe = labeled_sample_features_dataframe.loc[train_indexes.tolist()]
    test_set_dataframe = labeled_sample_features_dataframe.loc[test_indexes.tolist()]

    train_class = binary_class_series[train_indexes]
    test_class = binary_class_series[test_indexes]
    selected_classifier = RandomForestClassifier(n_estimators=100)
    selected_classifier.fit(train_set_dataframe, train_class)
    predictions = selected_classifier.predict(test_set_dataframe)
    predictions_proba = selected_classifier.predict_proba(test_set_dataframe)

    roc += roc_auc_score(test_class, predictions_proba[:,1])
    accuracy += accuracy_score(test_class, predictions)
    recall += recall_score(test_class, predictions)
    precision += precision_score(test_class, predictions)

最后我将结果除以K,当然是为了得到平均AUC,精度等。这段代码工作得很好。但是,我不能对多类进行同样的计算:

代码语言:javascript
运行
复制
    train_class = multi_class_series[train_indexes]
    test_class = multi_class_series[test_indexes]

    selected_classifier = RandomForestClassifier(n_estimators=100)
    selected_classifier.fit(train_set_dataframe, train_class)

    predictions = selected_classifier.predict(test_set_dataframe)
    predictions_proba = selected_classifier.predict_proba(test_set_dataframe)

我发现对于多类,我必须添加参数“加权”来表示平均值。

代码语言:javascript
运行
复制
    roc += roc_auc_score(test_class, predictions_proba[:,1], average="weighted")

我遇到错误:提升ValueError(“不支持{0}格式”.format(Y_type))

ValueError:不支持多类格式

EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39685740

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档