首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >将seaborn regplot参数注记到绘图中

将seaborn regplot参数注记到绘图中
EN

Stack Overflow用户
提问于 2021-05-19 14:25:57
回答 1查看 190关注 0票数 0

我正在尝试使用seaborn.regplotr2prmse值绘制散点图。但是下面的代码返回了一个AttributeError: 'AxesSubplot' object has no attribute 'map_dataframe'错误

代码语言:javascript
运行
复制
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)


g = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_surface'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_surface'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)


g = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax = axes[1] )
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_profile'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_profile'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)

有没有办法解决这个问题?我真的很感谢任何人的帮助。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-05-19 15:53:55

海运的一个重要方面是difference between figure-level and axes-level functionssns.regplot是一个轴级函数。它获取一个ax (指示子绘图)作为可选参数,并始终返回在其上创建绘图的ax

map_dataframe旨在使用图形级函数(创建一个子图网格)。它可以与诸如relplot之类的函数一起工作。请注意,图形级函数不接受ax作为参数,它们总是创建自己的新图形。

在您的示例中,可以使用ax参数修改annotate函数,也可以使用xy的参数修改该函数,使其适用于这两个子图。( Python中的一个重要概念是"DRY - Don't Repeat Yourself"。)

以下是修改后的代码,从一些测试数据开始。(进一步的改进是将对regplot的调用放到annotate函数中,将该函数重命名为类似于“regplot_with_annotation”的名称)。

代码语言:javascript
运行
复制
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import scipy
from sklearn.metrics import mean_squared_error

def annotate(ax, data, x, y):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x=data[x], y=data[y])
    rmse = mean_squared_error(data[x], data[y], squared=False)
    ax.text(.02, .9, f'r2={rvalue ** 2:.2f}, p={pvalue:.2g}, rmse={rmse:.2f}', transform=ax.transAxes)

est_fmc = np.random.uniform(0, 10, 100)
oneh_surface = 2 * est_fmc + np.random.normal(0, 5, 100) + 10
oneh_profile = 3 * est_fmc + np.random.normal(0, 3, 100) + 5

new_df = pd.DataFrame({'est_fmc': est_fmc, '1h_surface': oneh_surface, '1h_profile': oneh_profile})

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)

ax = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
annotate(ax, data=new_df, x='est_fmc', y='1h_surface')

ax = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax=axes[1])
annotate(ax, data=new_df, x='est_fmc', y='1h_profile')

plt.tight_layout()
plt.show()

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67597972

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档