前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >推荐:一个画出高大上的机器学习图的神器(论文必备)

推荐:一个画出高大上的机器学习图的神器(论文必备)

作者头像
石晓文
发布2019-08-01 14:50:09
1.8K0
发布2019-08-01 14:50:09
举报
文章被收录于专栏:小小挖掘机小小挖掘机

本人在写论文的时候,很多图片是用matplotlib和seaborn画的,但是,我还有一个神器,Scikit-plot,通过这个神器,画出了更加高大上的机器学习图,本文对Scikit-plot做下简单介绍。(作者:黄海广)

安装说明

安装Scikit-plot非常简单,直接用命令:

代码语言:javascript
复制
pip install scikit-plot

即可完成安装。

仓库地址:

https://github.com/reiinakano/scikit-plot

里面有使用说明和样例(py和ipynb格式)。

使用说明

简单举几个例子

  • 比如画出分类评级指标的ROC曲线的完整代码:
代码语言:javascript
复制
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
nb = GaussianNB()
nb.fit(X_train, y_train)
predicted_probas = nb.predict_proba(X_test)
# The magic happens here
import matplotlib.pyplot as plt
import scikitplot as skplt
skplt.metrics.plot_roc(y_test, predicted_probas)
plt.show()

效果如图(相当高大上!)

图:ROC曲线

  • P-R曲线就是精确率precision vs 召回率recall 曲线,以recall作为横坐标轴,precision作为纵坐标轴。首先解释一下精确率和召回率。
代码语言:javascript
复制
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.datasets import load_digits as load_data
import scikitplot as skplt
# Load dataset
X, y = load_data(return_X_y=True)
# Create classifier instance then fit
nb = GaussianNB()
nb.fit(X,y)
# Get predicted probabilities
y_probas = nb.predict_proba(X)
skplt.metrics.plot_precision_recall_curve(y, y_probas, cmap='nipy_spectral')
plt.show()

图:P-R曲线

  • 混淆矩阵是分类的重要评价标准,下面代码是用随机森林对鸢尾花数据集进行分类,分类结果画一个归一化的混淆矩阵。
代码语言:javascript
复制
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_digits as load_data
from sklearn.model_selection import cross_val_predict
import matplotlib.pyplot as plt
import scikitplot as skplt
X, y = load_data(return_X_y=True)
# Create an instance of the RandomForestClassifier
classifier = RandomForestClassifier()
# Perform predictions
predictions = cross_val_predict(classifier, X, y)
plot = skplt.metrics.plot_confusion_matrix(y, predictions, normalize=True)
plt.show()

图:归一化混淆矩阵

  • 其他图如学习曲线、特征重要性、聚类的肘点等等,都可以用几行代码搞定。

图:学习曲线、特征重要性

图:K-means肘点图

总结

本文对Scikit-plot做下简单介绍,这是一个机器学习的画图神器,几行代码就能画出高大上的机器学习图,作者当年的博士论文也是靠这个画图的。

仓库地址:

https://github.com/reiinakano/scikit-plot

里面有使用说明和样例。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-07-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小小挖掘机 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档