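In a 2D histogram of an XY distribution, how do I find the bin number and bin height for each point?
How do I visualize the result correctly (preferably with seaborn)?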
Posted on 2021-02-17 01:08:35
I want to create a plot in which my x, y data points are overlaid on the histogram computed with numpy.histogram2d.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(9)
x = np.round(10*np.random.rand(12), 1)
y = np.round(10*np.random.rand(12), 1)

# Common range for both the NumPy histogram and the seaborn plot
binrange = ([x.min(), x.max()+1], [y.min(), y.max()+1])
h, ex, ey = np.histogram2d(x, y, bins=5, range=binrange, density=False)

# Bin number (1-based) of each point along x and along y
nx = np.digitize(x, bins=ex)
ny = np.digitize(y, bins=ey)

print('Why do my points fall into empty bins??')
print('Values:', '\n', x, '\n', y, '\n')
print('Bins', '\n', ex, '\n', ey, '\n')
print('Bin numbers:\n', nx, '\n', ny, '\n')

sns.histplot(x=x, y=y, bins=5, binrange=binrange, cbar=True)
sns.scatterplot(x=x, y=y, s=15, color='k')
plt.suptitle('What I expect to see')

Output:
Values:
[0.1 5. 5. 1.3 1.4 2.2 4.2 2.5 0.8 3.5 1.7 8.8]
[9.5 0.4 7. 5.7 9. 6.7 5.5 7. 3.9 6.9 8.2 4.7]
Bins
[0.1 2.04 3.98 5.92 7.86 9.8 ]
[ 0.4 2.42 4.44 6.46 8.48 10.5 ]
Bin numbers:
[1 3 3 1 1 2 3 2 1 2 1 5]
[5 1 4 3 5 4 3 4 2 4 4 3]
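
The bin numbers above can be turned into a per-point bin height by indexing the histogram array directly. The following is only a minimal sketch: it hard-codes the values printed above, and the names `ix`, `iy` and `heights` are illustrative and do not come from the original post. It also assumes that no point lies exactly on the last bin edge, which holds here because the range extends to `max + 1`.

```python
import numpy as np

# Values copied from the printed output above.
x = np.array([0.1, 5.0, 5.0, 1.3, 1.4, 2.2, 4.2, 2.5, 0.8, 3.5, 1.7, 8.8])
y = np.array([9.5, 0.4, 7.0, 5.7, 9.0, 6.7, 5.5, 7.0, 3.9, 6.9, 8.2, 4.7])

binrange = ([x.min(), x.max() + 1], [y.min(), y.max() + 1])
h, ex, ey = np.histogram2d(x, y, bins=5, range=binrange)

# np.digitize is 1-based for values inside the edges, so subtract 1
# to get 0-based indices into h (first axis = x, second axis = y).
ix = np.digitize(x, bins=ex) - 1
iy = np.digitize(y, bins=ey) - 1

# Height (count) of the bin that each point falls into.
heights = h[ix, iy]

for xi, yi, i, j, c in zip(x, y, ix, iy, heights):
    print(f"point ({xi}, {yi}) -> bin ({i}, {j}), count {int(c)}")
```

Because every point is counted in its own bin, each entry of `heights` is at least 1. If a plot shows points sitting in empty bins, the likely cause is the orientation of the plotted array rather than the binning.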

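A small trick here is to use np.rot90 to rotate the computed histogram into the right orientation. np.histogram2d puts x along the first axis of h, whereas imshow draws the first axis vertically, so without the rotation the points appear to land in empty bins.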
plt.imshow(np.rot90(h, 1),
           extent=[x.min(), x.max()+1, y.min(), y.max()+1],
           origin='upper', cmap='Blues')
plt.colorbar()
plt.scatter(x=x, y=y, s=10, color='k')
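
To see what the rotation does, the sketch below uses a tiny asymmetric data set (three made-up points, not taken from the post) and checks that `np.rot90(h, 1)` is the same as transposing `h` and then reversing the row order.

```python
import numpy as np

# Small asymmetric example so that axis mix-ups are visible.
x = np.array([0.5, 0.5, 2.5])
y = np.array([0.5, 1.5, 1.5])
h, ex, ey = np.histogram2d(x, y, bins=(3, 2), range=[[0, 3], [0, 2]])

# h has shape (nx, ny): the first axis is x, the second is y.
print(h.shape)

# After rot90 the rows correspond to y bins (largest y in the top row)
# and the columns correspond to x bins, which is what imshow expects
# with origin='upper'.
rotated = np.rot90(h, 1)
print(rotated)

# The same array can be built by transposing and flipping the rows.
assert np.array_equal(rotated, h.T[::-1])
```

An equivalent option should be to pass `h.T` to `plt.imshow` with `origin='lower'` and the same `extent`, which avoids the explicit rotation.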

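With this, the problem is almost solved. However, drawing the last plot with sns.heatmap takes a bit more work. The main difficulty is setting the extent on the axes, since the heatmap axes run in cell units rather than data units. Alternatively, we can rescale the original data to the limits (0, number_of_bins).
For example: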
def transform(distrA, limitsA, limitsB):
    '''Transforms a distribution of unevenly distributed points in space A to space B.
    Input:
        distrA - numpy 2D array [[arrdim1 ...], [arrdim2 ...], [arrdim3 ...], [arrdim4 ...]] -
            Distribution to be transformed.
        limitsA and limitsB - (array of pairs) -
            Limits of space A and B, correspondingly, in the form (lower, higher)
    Output:
        distrB - transformed distribution'''
    shape = distrA.shape
    distrB = np.empty(shape=shape)
    for i in range(shape[0]):
        # Linear rescaling of dimension i from limitsA[i] to limitsB[i]
        spanA = limitsA[i][1] - limitsA[i][0]
        spanB = limitsB[i][1] - limitsB[i][0]
        for j in range(shape[1]):
            distrB[i, j] = spanB * (distrA[i, j]-limitsA[i][0]) / spanA + limitsB[i][0]
    return distrB
hm = sns.heatmap(np.rot90(h, 1), cmap='Blues', annot=True)
# Heatmap cells span 0..5 on both axes and the y axis points down,
# hence the target limits ((0, 5), (5, 0)).
h_trans = transform(np.asarray([x, y]),
                    [[x.min(), x.max()+1], [y.min(), y.max()+1]],
                    ((0, 5), (5, 0)))
sns.scatterplot(x=h_trans[0], y=h_trans[1], s=20, color='k')
plt.title('Desired seaborn heatmap')
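
The same linear rescaling can also be written without explicit loops by using NumPy broadcasting. This is only a sketch of one alternative: `transform_vectorized` is a hypothetical name, not part of the original answer, and it assumes the limits are given as one (lower, higher) pair per row of the input.

```python
import numpy as np

def transform_vectorized(distrA, limitsA, limitsB):
    """Linearly map each row of distrA from limitsA[i] to limitsB[i]."""
    distrA = np.asarray(distrA, dtype=float)
    limitsA = np.asarray(limitsA, dtype=float)
    limitsB = np.asarray(limitsB, dtype=float)
    # Keep the limits as column vectors so that they broadcast
    # across all points of the corresponding row.
    lowA, lowB = limitsA[:, :1], limitsB[:, :1]
    spanA = limitsA[:, 1:] - lowA
    spanB = limitsB[:, 1:] - lowB
    return spanB * (distrA - lowA) / spanA + lowB

# The corners of the data range should land on the corners of the heatmap:
# x goes 0 -> 5, y is flipped and goes 5 -> 0.
pts = np.array([[0.1, 9.8], [0.4, 10.5]])
out = transform_vectorized(pts, [[0.1, 9.8], [0.4, 10.5]], ((0, 5), (5, 0)))
print(out)
```

It takes the same arguments as `transform` above, so it can be swapped into the heatmap example without other changes.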

https://stackoverflow.com/questions/66228946