首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >循环中添加的动态调整大小的子图

循环中添加的动态调整大小的子图
EN

Stack Overflow用户
提问于 2019-06-13 03:45:29
回答 1查看 831关注 0票数 0

我对在python中绘制图形非常陌生,在创建子图时遇到了麻烦。

我目前有一个循环,它将产生2个数组,然后我将它们分配给Dataframe列,并使用df.plot()来绘制图形。这导致我每次通过循环都会得到一个不同的图。

我想试着把所有的东西都集中在一个图上。我在循环之外创建了figure对象,在循环中,我尝试了下面的代码。我知道我需要定义子图的大小,但问题是循环的数量是用户定义的。此外,关于形状-如果它是4个循环,2x2就可以了,但如果它是25,我希望它尽可能地近似一个正方形。我不确定这是不是可行。

代码语言:javascript
复制
        ax = plt.subplot(i)
        ax.scatter(y_df['y_pred'], y_df['y_test'])

但我一直收到以下错误:

“三位数,而不是{}".format(args))

ValueError:整型子绘图规范必须是三位数字,而不是1

这是我的完整代码。我删除了许多不相关的行,以便更容易理解:

代码语言:javascript
复制
  fig = plt.figure()


    tscv = TimeSeriesSplit(n_splits=self.no_splits)
    for train_index, test_index in tqdm(tscv.split(X)):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]



        self.regressor.fit(X_train, y_train.ravel())

        # predict y values
        y_pred = self.regressor.predict(X_test)


        # plot y_pred vs y_test
        y_df = pd.DataFrame()
        y_pred = y_pred.reshape(len(y_pred), )
        y_test = y_test.reshape(len(y_test), )
        y_df['y_pred'] = y_pred
        y_df['y_test'] = y_test

        ax = plt.subplot(i)
        ax.scatter(y_df['y_pred'], y_df['y_test'])
EN

回答 1

Stack Overflow用户

发布于 2019-06-13 07:26:11

add_subplot有三个参数:

代码语言:javascript
复制
fig.add_subplot(nrows, ncols, index)

如果您想要更新子绘图的指定位置,可以在各个轴上使用"change_geometry“,它采用相同的三个参数,例如:

代码语言:javascript
复制
for i,ax in enumerate(fig.axes):
    if isinstance(ax,matplotlib.axes.SubplotBase):
        ax.change_geometry(len(fig.axes),1,i)

我用“change_geometry”做了一个小例子:

代码语言:javascript
复制
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import numpy as np
import random


def run(val):
    n_axes = int(val)
    ax_names = random.sample(range(max_size),n_axes)

    i=0
    for ax_name in range(1,max_size):
        # Delete outdated axes
        if ax_name not in ax_names and ax_name in my_axes.keys():
            fig.delaxes(my_axes[ax_name])
            del my_axes[ax_name]
        if ax_name in ax_names:
            i+=1
            #Plot new data on new axes
            if ax_name not in my_axes.keys():
                print(i,n_axes)
                y = np.random.rand(x.shape[0])
                my_axes[ax_name] = ax = fig.add_subplot(n_axes,1,i)
                ax.plot(x,y)
            # Relocate "old" ax to new position
            else:
                my_axes[ax_name].change_geometry(n_axes,1,i)



fig = plt.figure()
my_axes = {}
x = np.linspace(0,1,100)
max_size=81

ax = plt.axes([0.1, 0.03, 0.8, 0.03], facecolor='#cccc00')
sNum = Slider(ax, '#plots', 1, max_size, valinit=1, valstep=1)
sNum.on_changed(run)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56569394

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档