首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在PyTorch中计算成批成对辛克霍恩距离

在PyTorch中计算成批成对辛克霍恩距离
EN

Stack Overflow用户
提问于 2020-12-04 21:00:09
回答 1查看 544关注 0票数 0

我有两个张量,两个都是相同的形状。我想用GeomLoss计算成对的辛克霍恩距离。

我尝试过的:

代码语言:javascript
运行
复制
import torch
import geomloss  # pip install git+https://github.com/jeanfeydy/geomloss

a = torch.rand((8,4))
b = torch.rand((8,4))

geomloss.SamplesLoss('sinkhorn')(a,b)
# ^ input shape [batch, feature_dim]
# will return a scalar value

geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(1),b.unsqueeze(1))  
# ^ input shape [batch, n_points, feature_dim]
# will return a tensor of size [batch] of distances between a[i] and b[i] for each i

然而,我想计算成对的距离,其中的结果张量应该是大小[batch, batch]。为了达到这个目的,我尝试使用以下广播方式:

代码语言:javascript
运行
复制
geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(0), b.unsqueeze(1))

但我收到了一条错误消息:

ValueError:示例xy应该具有相同的批次大小.

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-04 21:49:50

由于文档没有给出如何使用距离的前向函数的示例。这里有一种方法,这将要求您调用距离函数batch倍。

我们将逐行构造距离矩阵。行i对应于从a[i]<->b[0]a[i]<->b[1]a[i]<->b[batch]的距离。为此,我们需要为每一行i构造一个(8x4)重复版本的张量a[i]

这样做可以:

代码语言:javascript
运行
复制
a_i = torch.stack(8*[a[i]], dim=0)

然后计算出a[i]b中的每一批之间的距离。

代码语言:javascript
运行
复制
dist(a_i.unsqueeze(1), b.unsqueeze(1))

有了总共的batch线,我们就可以构造出最后的张量stack

以下是完整的代码:

代码语言:javascript
运行
复制
batch = a.shape[0]
dist = geomloss.SamplesLoss('sinkhorn')
distances = [dist(torch.stack(batch*[a[i]]).unsqueeze(1), b.unsqueeze(1)) for i in range(batch)]
D = torch.stack(distances)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65150672

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档