我正在尝试使用seaborn.regplot
用r2
、p
和rmse
值绘制散点图。但是下面的代码返回了一个AttributeError: 'AxesSubplot' object has no attribute 'map_dataframe'
错误
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
g = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
def annotate(data, **kws):
slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_surface'] )
rmse = mean_squared_error(data['est_fmc'], data['1h_surface'], squared=False)
print(slope, intercept, rvalue, pvalue, rmse)
ax = plt.gca()
ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
transform=ax.transAxes)
g.map_dataframe(annotate)
g = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax = axes[1] )
def annotate(data, **kws):
slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_profile'] )
rmse = mean_squared_error(data['est_fmc'], data['1h_profile'], squared=False)
print(slope, intercept, rvalue, pvalue, rmse)
ax = plt.gca()
ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
transform=ax.transAxes)
g.map_dataframe(annotate)
有没有办法解决这个问题?我真的很感谢任何人的帮助。
发布于 2021-05-19 15:53:55
海运的一个重要方面是difference between figure-level and axes-level functions。sns.regplot
是一个轴级函数。它获取一个ax
(指示子绘图)作为可选参数,并始终返回在其上创建绘图的ax
。
map_dataframe
旨在使用图形级函数(创建一个子图网格)。它可以与诸如relplot
之类的函数一起工作。请注意,图形级函数不接受ax
作为参数,它们总是创建自己的新图形。
在您的示例中,可以使用ax
参数修改annotate
函数,也可以使用x
和y
的参数修改该函数,使其适用于这两个子图。( Python中的一个重要概念是"DRY - Don't Repeat Yourself"。)
以下是修改后的代码,从一些测试数据开始。(进一步的改进是将对regplot
的调用放到annotate
函数中,将该函数重命名为类似于“regplot_with_annotation”的名称)。
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import scipy
from sklearn.metrics import mean_squared_error
def annotate(ax, data, x, y):
slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x=data[x], y=data[y])
rmse = mean_squared_error(data[x], data[y], squared=False)
ax.text(.02, .9, f'r2={rvalue ** 2:.2f}, p={pvalue:.2g}, rmse={rmse:.2f}', transform=ax.transAxes)
est_fmc = np.random.uniform(0, 10, 100)
oneh_surface = 2 * est_fmc + np.random.normal(0, 5, 100) + 10
oneh_profile = 3 * est_fmc + np.random.normal(0, 3, 100) + 5
new_df = pd.DataFrame({'est_fmc': est_fmc, '1h_surface': oneh_surface, '1h_profile': oneh_profile})
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
ax = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
annotate(ax, data=new_df, x='est_fmc', y='1h_surface')
ax = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax=axes[1])
annotate(ax, data=new_df, x='est_fmc', y='1h_profile')
plt.tight_layout()
plt.show()
https://stackoverflow.com/questions/67597972
复制相似问题