设a
是一个(n, d, l)
张量。设indices
是一个包含索引的(n, 1)
张量。我想从indices
给出的指数的中维张量中的a
中收集。因此,所得到的张量的形状为(n, l)
。
n = 3
d = 2
l = 3
a = tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]])
indices = tensor([[0],
[1],
[0]])
# Shape of result is (n, l)
result = tensor([[ 0, 1, 2], # a[0, 0, :] since indices[0] == 0
[ 9, 10, 11], # a[1, 1, :] since indices[1] == 1
[12, 13, 14]]) # a[2, 0, :] since indices[2] == 0
这确实类似于a.gather(1, indices)
,但gather
不会工作,因为indices
与a
的形状不同。如何在此设置中使用gather
?或者我应该使用什么?
发布于 2020-12-20 19:48:23
您可以手动创建索引。如果indices
张量具有示例数据的形状,则必须将其展平。
a[torch.arange(len(a)),indices.view(-1)]
# equal to a[[0,1,2],[0,1,0]]
输出:
tensor([[ 0, 1, 2],
[ 9, 10, 11],
[12, 13, 14]])
发布于 2021-05-09 18:29:06
我将我的答案添加到Michael的答案之上,以获得索引维度两边的更多维度,但我希望有人给我一个不使用arange
的更好的答案!
def squeeze_index(x, dim, index):
# flatten to rows
y = x.view((-1,) + x.shape[dim:])
# generate row indices
rows = torch.arange(y.shape[0])
# index and reshape
result_shape = x.shape[:dim] + (x.shape[dim+1:] if dim != -1 else ())
return y[rows, index.view(-1), ...].view(result_shape)
a = torch.arange(2*3*2*3).reshape((2,3,2,3))
indices = torch.tensor([0,0,1,0,0,1]).reshape((2,3))
result = squeeze_index(a, 2, i)
print("a", a.shape, a)
print("indices", indices.shape, indices)
print("result", result.shape, result)
提供:
a torch.Size([2, 3, 2, 3]) tensor([[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]],
[[[18, 19, 20],
[21, 22, 23]],
[[24, 25, 26],
[27, 28, 29]],
[[30, 31, 32],
[33, 34, 35]]]])
indices torch.Size([2, 3]) tensor([[0, 0, 1],
[0, 0, 1]])
result torch.Size([2, 3, 3]) tensor([[[ 0, 1, 2],
[ 6, 7, 8],
[15, 16, 17]],
[[18, 19, 20],
[24, 25, 26],
[33, 34, 35]]])
发布于 2021-06-30 23:55:25
在使用聚集函数之前,重塑索引,这是一个示例
def gather_righthand(src, index, check=True):
index = index.long()
i_dim = index.dim(); s_dim = src.dim(); t_dim = i_dim-1
if check:
assert s_dim > i_dim
for d in range(0, t_dim):
assert src.shape[d] == index.shape[d]
index_new_shape = list(src.shape)
index_new_shape[t_dim] = index.shape[t_dim]
for _ in range(i_dim, s_dim): index = index.unsqueeze(-1)
index_expand = index.expand(index_new_shape) # only this two line matters
return torch.gather(src, dim=t_dim, index=index_expand) # only this two line matters
gather_righthand(a, indices)
tensor([[[ 0., 1., 2.]],
[[ 9., 10., 11.]],
[[12., 13., 14.]]])
https://stackoverflow.com/questions/65378968
复制相似问题