首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >如何在matplotlib中为条形图添加组标签

如何在matplotlib中为条形图添加组标签
EN

Stack Overflow用户
提问于 2013-10-04 22:59:35
回答 2查看 14.5K关注 0票数 43

我想使用matplotlib条形图绘制以下形式的数据:

代码语言:javascript
复制
data = {'Room A':
           {'Shelf 1':
               {'Milk': 10,
                'Water': 20},
            'Shelf 2':
               {'Sugar': 5,
                'Honey': 6}
           },
        'Room B':
           {'Shelf 1':
               {'Wheat': 4,
                'Corn': 7},
            'Shelf 2':
               {'Chicken': 2,
                'Cow': 1}
           }
       }

条形图应该看起来像是

从x轴上的标签应该可以看到条形组。有没有办法用matplotlib做到这一点?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2013-10-08 15:57:48

由于我在matplotlib中找不到内置的解决方案,所以我编写了自己的解决方案:

代码语言:javascript
复制
#!/usr/bin/env python

from matplotlib import pyplot as plt

def mk_groups(data):
    try:
        newdata = data.items()
    except:
        return

    thisgroup = []
    groups = []
    for key, value in newdata:
        newgroups = mk_groups(value)
        if newgroups is None:
            thisgroup.append((key, value))
        else:
            thisgroup.append((key, len(newgroups[-1])))
            if groups:
                groups = [g + n for n, g in zip(newgroups, groups)]
            else:
                groups = newgroups
    return [thisgroup] + groups

def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='black')
    line.set_clip_on(False)
    ax.add_line(line)

def label_group_bar(ax, data):
    groups = mk_groups(data)
    xy = groups.pop()
    x, y = zip(*xy)
    ly = len(y)
    xticks = range(1, ly + 1)

    ax.bar(xticks, y, align='center')
    ax.set_xticks(xticks)
    ax.set_xticklabels(x)
    ax.set_xlim(.5, ly + .5)
    ax.yaxis.grid(True)

    scale = 1. / ly
    for pos in xrange(ly + 1):  # change xrange to range for python3
        add_line(ax, pos * scale, -.1)
    ypos = -.2
    while groups:
        group = groups.pop()
        pos = 0
        for label, rpos in group:
            lxpos = (pos + .5 * rpos) * scale
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos * scale, ypos)
            pos += rpos
        add_line(ax, pos * scale, ypos)
        ypos -= .1

if __name__ == '__main__':
    data = {'Room A':
               {'Shelf 1':
                   {'Milk': 10,
                    'Water': 20},
                'Shelf 2':
                   {'Sugar': 5,
                    'Honey': 6}
               },
            'Room B':
               {'Shelf 1':
                   {'Wheat': 4,
                    'Corn': 7},
                'Shelf 2':
                   {'Chicken': 2,
                    'Cow': 1}
               }
           }
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    label_group_bar(ax, data)
    fig.subplots_adjust(bottom=0.3)
    fig.savefig('label_group_bar_example.png')

mk_groups函数接受字典(或任何带有items()方法的内容,如collections.OrderedDict),并将其转换为数据格式,然后用于创建图表。它基本上是一个表单列表:

代码语言:javascript
复制
[ [(label, bars_to_span), ...], ..., [(tick_label, bar_value), ...] ]

add_line函数在子图中的指定位置(在轴坐标中)创建一条垂直线。

label_group_bar函数获取一个字典,并在下面带有标签的子图中创建条形图。然后,该示例的结果看起来是like this

更容易或更好的解决方案和建议仍然非常受欢迎。

票数 63
EN

Stack Overflow用户

发布于 2016-09-15 09:56:54

我寻找这个解决方案已经有一段时间了。我对它进行了一些修改,以便与pandas数据表一起使用。只有公平地分享。

代码语言:javascript
复制
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby

def test_table():
    data_table = pd.DataFrame({'Room':['Room A']*4 + ['Room B']*4,
                               'Shelf':(['Shelf 1']*2 + ['Shelf 2']*2)*2,
                               'Staple':['Milk','Water','Sugar','Honey','Wheat','Corn','Chicken','Cow'],
                               'Quantity':[10,20,5,6,4,7,2,1],
                               'Ordered':np.random.randint(0,10,8)
                               })
    return data_table

def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='black')
    line.set_clip_on(False)
    ax.add_line(line)

def label_len(my_index,level):
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for i in g)) for k,g in groupby(labels)]
    
def label_group_bar_table(ax, df):
    ypos = -.1
    scale = 1./df.index.size
    for level in range(df.index.nlevels)[::-1]:
        pos = 0
        for label, rpos in label_len(df.index,level):
            lxpos = (pos + .5 * rpos)*scale
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos*scale, ypos)
            pos += rpos
        add_line(ax, pos*scale , ypos)
        ypos -= .1

df = test_table().groupby(['Room','Shelf','Staple']).sum()
fig = plt.figure()
ax = fig.add_subplot(111)
df.plot(kind='bar',stacked=True,ax=fig.gca())
#Below 3 lines remove default labels
labels = ['' for item in ax.get_xticklabels()]
ax.set_xticklabels(labels)
ax.set_xlabel('')
label_group_bar_table(ax, df)
fig.subplots_adjust(bottom=.1*df.index.nlevels)
plt.show()

票数 29
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/19184484

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档