首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在生成的热图图的某个地方插入一个小正方形标记

如何在生成的热图图的某个地方插入一个小正方形标记
EN

Stack Overflow用户
提问于 2022-06-20 19:17:37
回答 1查看 59关注 0票数 0

我正在创建一个包含10个子图的2D matplotlib图(i和j坐标)。每个子图包含150×150个网格单元数据。我如何能够插入一个小的黑色正方形标记(3×3)某个固定的地方(中心在坐标62和62 ),在每个生成的热图子图在这10个子图上?因此,正方形标记将包含从60到64在x和y方向的10个街区,并包含以x 62和y 62为中心的书面文本"Sale 1“。下面的代码不会生成任何补丁。任何反馈都是非常感谢的。

代码语言:javascript
复制
    from matplotlib.patches import Rectangle
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import r2_score, median_absolute_error
    import os
    import matplotlib.cm as cm
    from mpl_toolkits import axes_grid1
    import matplotlib.pyplot as plt
    #import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    import matplotlib.colors
    import matplotlib.colors as colors
    
    
    data = np.random.rand(10, 150, 150)
    data = data.reshape(-1, 1)
    
    
    property = "Sale"
    
       
    pmin = data.min()
    pmax = data.max()
    
    v = np.linspace(round(pmin,3), round(pmax,3),15, endpoint=True)
    v = [round(x,3) for x in v] 
    
    fig, ax = plt.subplots(2, 5, figsize=(160, 80))
    row_count = 0
    col_count = 0
    for i in range(10):
    
        sub_plot_data = data[(i)*(150*150):(i+1)*150*150]
        
       
        x = 150
        y = 150
        #--------------------------- Define the map boundary ---------------------- 
        xmin = 1258096.6
        xmax = 1291155.0
        ymin = 11251941.6
        ymax = 11285000.0
        
        pmin = min(sub_plot_data)
        pmax = max(sub_plot_data) 
    
        
        # ---------------------------  define color bar for Discrete  color 
        bounds = np.linspace(-1, 1, 10)
        Discrete_colors = plt.get_cmap('jet')(np.linspace(0,1,len(bounds)+1))
        # create colormap without the outmost colors
        cmap = mcolors.ListedColormap(Discrete_colors[1:-1]) # 


        actual_2d = np.reshape(sub_plot_data,(y,x)) 
        
        im1 = ax[row_count, col_count].imshow(actual_2d, interpolation=None, cmap=cmap, 
        extent=(xmin, xmax, ymin, ymax), vmin=pmin, vmax=pmax)      
        plt.text(actual_2d[62, 62], actual_2d[62, 62], '%s' % 'Sale_1', 
        horizontalalignment='center', verticalalignment='center', color= 'black', fontsize= 90)
    
        
        ax[row_count, col_count].set_title("Sale_Stores-%s - L: %s"%(i+1, layer), 
        fontsize=130, pad=44, x=0.5, y=0.999) # new
    
        ax[row_count, col_count].set_aspect('auto')
        ax[row_count, col_count].tick_params(left=False, labelleft=False, top=False, 
        labeltop=False, right=False, labelright=False, bottom=False, labelbottom=False) # new
        #ax[row_count, col_count] = plt.gca()
        plt.gca().add_patch(Rectangle((60, 60), 3, 3, edgecolor='black', 
        facecolor='black',fill=True,lw=2))
        ax[row_count, col_count].add_patch(plt.text(62, 62, '%s' % 'Sale_1', 
       horizontalalignment='center', verticalalignment='center', color= 'black', fontsize= 90))
    
        col_count +=1
        
        if col_count == 5:
            row_count +=1  
            col_count =0
    
           
       
    fig.tight_layout(h_pad=10) 
    plt.subplots_adjust(left=0.02,
                    bottom=0.1, 
                    right=0.91, 
                    top=0.8, 
                    wspace=0.1, 
                    hspace=0.2)
    
      
    cbaxes = fig.add_axes([0.94, 0.05, 0.02, 0.8]) 
    cbar = fig.colorbar(im1, ax=ax.ravel().tolist(), ticks=v, extend='both', cax =cbaxes)
    cbar.ax.tick_params(labelsize=70) 
    #cbar.set_ticks(v)
    cbar.ax.set_yticklabels([i for i in v], fontsize=120)
    
    
    output_dir = r"D/test"
    plot_dir = os.path.join(output_dir, reservoir_property)
    if not os.path.exists(plot_dir):
        os.makedirs(plot_dir)
    fig.savefig(r"%s/per_allmodel.png"%(plot_dir))
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-21 05:19:32

我尝试了您的代码并做了一些修改:第一,图的大小太大,导致错误,所以我把它变小了;第二,我简化了子图: axes有一个子图对象列表,所以我用axes.flat删除了它们;第二个是将文本修改为注释。图形大小已经缩小,字体大小和间距已经调整,所以请自己修改它。最后,由于禁用了颜色条滴答,所以没有设置tick_params

代码语言:javascript
复制
fig, axes = plt.subplots(2, 5, figsize=(16, 8))
row_count = 0
col_count = 0

for i,ax in enumerate(axes.flat):

    sub_plot_data = data[(i)*(150*150):(i+1)*150*150]

    x = 150
    y = 150
    #--------------------------- Define the map boundary ---------------------- 
    xmin = 1258096.6
    xmax = 1291155.0
    ymin = 11251941.6
    ymax = 11285000.0

    pmin = min(sub_plot_data)
    pmax = max(sub_plot_data) 
    # ---------------------------  define color bar for Discrete  color 
    bounds = np.linspace(-1, 1, 10)
    Discrete_colors = plt.get_cmap('jet')(np.linspace(0,1,len(bounds)+1))
    # create colormap without the outmost colors
    cmap = mcolors.ListedColormap(Discrete_colors[1:-1]) # 

    actual_2d = np.reshape(sub_plot_data,(y,x)) 

    #im = ax.imshow(actual_2d, interpolation=None, cmap=cmap, extent=(xmin, xmax, ymin, ymax), vmin=pmin, vmax=pmax)      
    im = ax.imshow(actual_2d, interpolation=None, cmap=cmap)      
    ax.text(actual_2d[62, 62], actual_2d[62, 62]-10, '%s' % 'Sale_1', 
        horizontalalignment='center', verticalalignment='center', color= 'black', fontsize=18)
    ax.set_title("Sale_Stores-%s - L: %s"%(i+1, 1), fontsize=14, pad=30, x=0.5, y=0.999)
    ax.set_aspect('auto')
    ax.add_patch(Rectangle((60, 60), 6, 6, edgecolor='red', facecolor='red', fill=True, lw=2))
    ax.text(62, 62, '%s' % 'Sale_1', ha='center', va='center', color='black', fontsize=14)

       
fig.tight_layout(h_pad=10) 
plt.subplots_adjust(left=0.02,
                    bottom=0.1, 
                    right=0.91, 
                    top=0.8, 
                    wspace=0.1, 
                    hspace=0.5)

cbaxes = fig.add_axes([0.94, 0.05, 0.02, 0.8])
cbar = fig.colorbar(im, ax=axes.flat, ticks=v, extend='both', cax=cbaxes)
cbar.ax.tick_params(labelsize=10) 
#cbar.set_ticks(v)
cbar.ax.set_yticklabels([str(i) for i in v], fontsize=12)

#plt.tick_params(left=False, labelleft=False, top=False, labeltop=False, right=False, labelright=False, bottom=False, labelbottom=False)

plt.show()

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72691970

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档