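I like this correlation matrix from the chart.Correlation function of the PerformanceAnalytics R package.
How can I create it in Python? The correlation matrix plots I have seen are mostly heatmaps, such as this seaborn example.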
Posted 2018-01-08 14:50:40
The cor_matrix function below does this, and also adds a bivariate kernel density plot. Thanks to @karl-anka's comment for getting me started.
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
sns.set(style='white')
iris = sns.load_dataset('iris')
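def corrfunc(x, y, **kws):
    # Pearson correlation coefficient and its two-sided p-value.
    r, p = stats.pearsonr(x, y)
    # Significance stars: each stricter cutoff overwrites the looser one.
    p_stars = ''
    if p <= 0.05:
        p_stars = '*'
    if p <= 0.01:
        p_stars = '**'
    if p <= 0.001:
        p_stars = '***'
    # PairGrid makes the target panel current before calling this function.
    ax = plt.gca()
    ax.annotate('r = {:.2f} '.format(r) + p_stars,
                xy=(0.05, 0.9), xycoords=ax.transAxes)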
def annotate_colname(x, **kws):
    ax = plt.gca()
    ax.annotate(x.name, xy=(0.05, 0.9), xycoords=ax.transAxes,
                fontweight='bold')
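def cor_matrix(df):
    g = sns.PairGrid(df, palette=['red'])
    # Upper triangle: scatter plot with a linear fit.
    # Use normal regplot as `lowess=True` doesn't provide CIs.
    g.map_upper(sns.regplot, scatter_kws={'s':10})
    # Diagonal: distribution of each variable plus its name.
    # Note: distplot is deprecated since seaborn 0.11 (histplot replaces it).
    g.map_diag(sns.distplot)
    g.map_diag(annotate_colname)
    # Lower triangle: bivariate KDE plus the r value and significance stars.
    g.map_lower(sns.kdeplot, cmap='Blues_d')
    g.map_lower(corrfunc)
    # Remove axis labels, as they're in the diagonals.
    for ax in g.axes.flatten():
        ax.set_ylabel('')
        ax.set_xlabel('')
    return g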
cor_matrix(iris)
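In corrfunc the three if statements overlap, and each stricter cutoff overwrites the stars set by the looser one. A minimal sketch of a table-driven equivalent, which makes the cutoffs explicit and easy to change; p_to_stars, format_r_label and STAR_THRESHOLDS are hypothetical names, not part of scipy or seaborn:

```python
# Hypothetical helpers (not part of scipy or seaborn) that reproduce
# corrfunc's significance-star logic in table-driven form.

# Cutoffs ordered from strictest to loosest, so the first match wins.
STAR_THRESHOLDS = ((0.001, "***"), (0.01, "**"), (0.05, "*"))


def p_to_stars(p, thresholds=STAR_THRESHOLDS):
    """Return the significance stars for p-value p ('' if not significant)."""
    for cutoff, stars in thresholds:
        if p <= cutoff:
            return stars
    return ""


def format_r_label(r, p):
    """Build the same 'r = 0.87 ***' style label that corrfunc annotates."""
    return "r = {:.2f} ".format(r) + p_to_stars(p)


print(format_r_label(0.5, 0.02))  # r = 0.50 *
```

Inside corrfunc, the if chain could then be replaced by p_stars = p_to_stars(p), giving the same stars for the 0.05, 0.01 and 0.001 cutoffs.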
Posted 2020-12-22 08:43:39
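To fix the error "'numpy.ndarray' object has no attribute 'name'" raised by the line ax.annotate(x.name, xy=(0.05, 0.9), xycoords=ax.transAxes, fontweight='bold'), while keeping the function generic, create an iterator inside cor_matrix and move annotate_colname into cor_matrix, as shown below. The error means map_diag handed annotate_colname a plain NumPy array, which (unlike a pandas Series) carries no column name; iterating over a DataFrame yields its column names in order, so each call can take the next name from the iterator instead.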
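import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
def corrfunc(x, y, **kws):
    r, p = stats.pearsonr(x, y)
    p_stars = ''
    if p <= 0.05:
        p_stars = '*'
    if p <= 0.01:
        p_stars = '**'
    if p <= 0.001:
        p_stars = '***'
    ax = plt.gca()
    # xycoords (not ycoords) places the label in axes-relative coordinates.
    ax.annotate('r = {:.2f} '.format(r) + p_stars, xy=(0.05, 0.9), xycoords=ax.transAxes)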
def cor_matrix(df, save=False):
    # ======= NEW ITERATION FUNCTION ====
    # Iterating over a DataFrame yields its column names in order,
    # so each call to label_iter() returns the next name.
    label_iter = iter(df).__next__
    # ====================================
    def annotate_colname(x, **kws):
        ax = plt.gca()
        # ===== CHANGED: x.name replaced by label_iter() ======
        ax.annotate(label_iter(), xy=(0.05, 0.9), xycoords=ax.transAxes, fontweight='bold')
    g = sns.PairGrid(df, palette=['red'])
    # Use normal regplot as `lowess=True` doesn't provide CIs.
    g.map_upper(sns.regplot, scatter_kws={'s':10}, line_kws={"color": "red"})
    g.map_diag(sns.histplot, kde=True)  # replaces the deprecated distplot
    g.map_diag(annotate_colname)
    g.map_lower(sns.kdeplot, cmap='Blues_d')
    g.map_lower(corrfunc)
    # Remove axis labels, as they're in the diagonals.
    for ax in g.axes.flatten():
        ax.set_ylabel('')
        ax.set_xlabel('')
    if save:
        plt.savefig('corr_mat.png')
    return g
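The iterator relies on annotate_colname being called exactly once per diagonal panel, in column order. It mislabels panels if df has a non-numeric column before a numeric one (PairGrid skips non-numeric columns by default, but iter(df) does not), or if a hue variable makes map_diag call the function once per hue level. A sketch of an alternative that avoids depending on call order: after drawing the grid, read the names from g.x_vars and annotate the diagonal panels g.axes[i, i] directly. label_diagonal and the demo data are hypothetical, and the sketch assumes a square grid (the default, where x_vars and y_vars are the same):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def label_diagonal(g):
    """Hypothetical helper: write each variable name on its diagonal panel.

    Takes the names from g.x_vars instead of relying on the order in which
    map_diag calls a plotting function. Assumes a square PairGrid.
    """
    for var, ax in zip(g.x_vars, g.axes.diagonal()):
        ax.annotate(var, xy=(0.05, 0.9), xycoords=ax.transAxes,
                    fontweight="bold")
    return g


# Hypothetical demo data: a text column sits between the numeric ones,
# which is the case where the iterator approach would mislabel panels.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "a": rng.normal(size=50),
    "b": rng.normal(size=50),
    "label": ["x"] * 50,
    "c": rng.normal(size=50),
})

g = sns.PairGrid(df)  # the non-numeric "label" column is skipped
g.map_diag(plt.hist)
label_diagonal(g)
```

Inside cor_matrix, calling label_diagonal(g) after the map_* calls could take the place of g.map_diag(annotate_colname), so the labels no longer depend on column order or on how often the diagonal function is called.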
https://stackoverflow.com/questions/48139899