cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
c = torch.FloatTensor([1, 2, 4])
b = torch.FloatTensor([1, 2, 3])
simi = cos(b,c)
tensor(0.9915)
我在这个函数中使用了dim=-1,这是否意味着它是一个一维浮点列表?这是正确的吗?
发布于 2022-01-20 22:15:30
与大多数python中的索引一样,-1指的是最后一个维度(-2将是第二到最后的,等等.)。当初始化余弦相似度时,使用dim=-1
表示将沿着输入的最后一维计算余弦相似度。
例如,如果b
和c
是尺寸为[X,Y,Z]
的三维张量,则结果是尺寸为[X,Y]
的二维张量。在这种情况下,由于输入张量只有一维(大小为[3]
),因此最终得到了大小为[]
的结果张量,即标量。
https://stackoverflow.com/questions/70793278
复制相似问题