首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >GridSearchCV结果热图

GridSearchCV结果热图
EN

Stack Overflow用户
提问于 2021-06-26 10:19:15
回答 1查看 152关注 0票数 4

我正在尝试为sklearn的GridSearchCV结果生成一个热图。我喜欢sklearn-evaluation的一点是它非常容易生成热图。然而,我遇到了一个问题。当我给一个参数为None时,例如

代码语言:javascript
运行
复制
max_depth = [3, 4, 5, 6, None]

在生成热图时,它显示以下错误:

代码语言:javascript
运行
复制
TypeError: '<' not supported between instances of 'NoneType' and 'int'

有什么解决方法吗?我已经找到了其他生成热图的方法,比如使用matplotlib和seaborn,但是没有什么比sklearn-evalutaion更漂亮的热图了。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-06-26 13:51:52

我摆弄了一下grid_search.py文件/lib/python3.8/site-packages/sklearn_evaluation/plot/grid_search.py。在第192/193行更改行

从…

代码语言:javascript
运行
复制
row_names = sorted(set([t[0] for t in matrix_elements.keys()]),
                   key=itemgetter(1))
col_names = sorted(set([t[1] for t in matrix_elements.keys()]),
                   key=itemgetter(1))

至:

代码语言:javascript
运行
复制
row_names = sorted(set([t[0] for t in matrix_elements.keys()]),
                   key=lambda x: (x[1] is None, x[1]))
col_names = sorted(set([t[1] for t in matrix_elements.keys()]),
                   key=lambda x: (x[1] is None, x[1]))

在排序时,将所有None移到列表的末尾是基于Andrew Clarke之前的answer

使用此调整,我的演示脚本如下所示:

代码语言:javascript
运行
复制
import numpy as np
import sklearn.datasets as datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn_evaluation import plot

data = datasets.make_classification(n_samples=200, n_features=10, n_informative=4, class_sep=0.5)


X = data[0]
y = data[1]

hyperparameters = {
    "max_depth": [1, 2, 3, None],
    "criterion": ["gini", "entropy"],
    "max_features": ["sqrt", "log2"],
}

est = RandomForestClassifier(n_estimators=5)
clf = GridSearchCV(est, hyperparameters, cv=3)
clf.fit(X, y)
plot.grid_search(clf.cv_results_, change=("max_depth", "criterion"), subset={"max_features": "sqrt"})


import matplotlib.pyplot as plt

plt.show()

输出如下所示:

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68138679

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档