首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Seaborn配对图:使用色调时map_lower中丢失的数据

Seaborn配对图:使用色调时map_lower中丢失的数据
EN

Stack Overflow用户
提问于 2017-10-26 12:08:50
回答 2查看 984关注 0票数 0

当我定义hue来着色我的绘图时,map_lower会更频繁地调用它的函数,并且与不使用hue的等效调用相比,它会松散数据。这是个窃听器还是我犯了个错误?

请参阅下面的代码

代码语言:javascript
运行
复制
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
import seaborn as sns


def corrfunc(x, y, **kws):
    r, _ = stats.pearsonr(x, y)
    print(x)
    print(y)
    print(r)

iris = sns.load_dataset("iris")
seax = sns.pairplot(iris, size=2, vars=["petal_width", "petal_length", "sepal_width"])
seax.map_lower(corrfunc)
plt.show()

如果你改变了

代码语言:javascript
运行
复制
sns.pairplot(iris, size=2, vars=["petal_width", "petal_length", "sepal_width"])

代码语言:javascript
运行
复制
seax = sns.pairplot(iris, hue="sepal_length", size=2, vars=["petal_width", "petal_length", "sepal_width"])

代码被破坏了,但情节看起来很好。因此,如果您运行的代码没有色调,corrfunc将被调用3次,在下面的3个地块。如果我添加hue=" class“来按字段类对情节着色,corrfunc将被调用8次左右。我不明白为什么用颜色着色对map_lower有影响。

EN

Stack Overflow用户

回答已采纳

发布于 2017-10-28 03:31:36

所以也许有一天这会帮助那些想做我想做的事的人。这是我丑陋但有效的解决方案:

代码语言:javascript
运行
复制
#!/usr/bin/env python
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns

# Global variables to keep track of data chunks if you
# use hue to color the data points. map_lower will with
# hue group data in chunks of identical hue values

dataLength = xName = yName = xData = yData = ''


# Function to group data pairs to plot their correlation
def assemble_data_subplot(x, y, **kwargs):
    global xName, yName, xData, yData, dataLength
    if xName == '' and yName == '':
        xName = x.name
        yName = y.name
        xData = x
        yData = y
    elif xName == x.name and yName == y.name:
        xData = xData.append(x)
        yData = yData.append(y)

    if len(xData) == dataLength:
        correlate_data(xData, yData)
        xName = yName = xData = yData = ''


# Correlation function
def correlate_data(xData, yData):
    r, _ = stats.pearsonr(xData, yData)
    r = r**2
    sax = plt.gca()
    sax.annotate("$r^2$={:.2f}".format(r),
                 xy=(.02, .86),
                 xycoords=sax.transAxes)


# Main function to plot the pairwise correlation plot
def main():
    # Init global variable to set it later
    global dataLength

    # Path to CSV file and data frame builder
    df = sns.load_dataset("iris")

    # Example without hue
    g = sns.pairplot(df, size=2, hue="petal_width",
                     vars=["petal_width",
                           "petal_length",
                           "sepal_width"])

    # Get the number of data entries to check when the assembled data
    # is complete. Used in assemble_data_subplot
    dataLength = len(df)

    # Plot the r^2 value on the lower part of the pair plot
    g.map_lower(assemble_data_subplot)

    # Generate the output
    g.savefig("output.png")
    plt.show()


if __name__ == "__main__":
    main()
票数 1
EN
查看全部 2 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46953958

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档