我有两个张量,两个都是相同的形状。我想用GeomLoss
计算成对的辛克霍恩距离。
我尝试过的:
import torch
import geomloss # pip install git+https://github.com/jeanfeydy/geomloss
a = torch.rand((8,4))
b = torch.rand((8,4))
geomloss.SamplesLoss('sinkhorn')(a,b)
# ^ input shape [batch, feature_dim]
# will return a scalar value
geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(1),b.unsqueeze(1))
# ^ input shape [batch, n_points, feature_dim]
# will return a tensor of size [batch] of distances between a[i] and b[i] for each i
然而,我想计算成对的距离,其中的结果张量应该是大小[batch, batch]
。为了达到这个目的,我尝试使用以下广播方式:
geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(0), b.unsqueeze(1))
但我收到了一条错误消息:
ValueError:示例
x
和y
应该具有相同的批次大小.
发布于 2020-12-04 21:49:50
由于文档没有给出如何使用距离的前向函数的示例。这里有一种方法,这将要求您调用距离函数batch
倍。
我们将逐行构造距离矩阵。行i
对应于从a[i]<->b[0]
,a[i]<->b[1]
到a[i]<->b[batch]
的距离。为此,我们需要为每一行i
构造一个(8x4)
重复版本的张量a[i]
。
这样做可以:
a_i = torch.stack(8*[a[i]], dim=0)
然后计算出a[i]
与b
中的每一批之间的距离。
dist(a_i.unsqueeze(1), b.unsqueeze(1))
有了总共的batch
线,我们就可以构造出最后的张量stack
。
以下是完整的代码:
batch = a.shape[0]
dist = geomloss.SamplesLoss('sinkhorn')
distances = [dist(torch.stack(batch*[a[i]]).unsqueeze(1), b.unsqueeze(1)) for i in range(batch)]
D = torch.stack(distances)
https://stackoverflow.com/questions/65150672
复制相似问题