首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >火炬聚集中等维度

火炬聚集中等维度
EN

Stack Overflow用户
提问于 2020-12-20 18:33:04
回答 3查看 343关注 0票数 4

a是一个(n, d, l)张量。设indices是一个包含索引的(n, 1)张量。我想从indices给出的指数的中维张量中的a中收集。因此,所得到的张量的形状为(n, l)

代码语言:javascript
运行
复制
n = 3
d = 2
l = 3

a = tensor([[[ 0,  1,  2],
             [ 3,  4,  5]],

            [[ 6,  7,  8],
             [ 9, 10, 11]],

            [[12, 13, 14],
             [15, 16, 17]]])

indices = tensor([[0],
                  [1],
                  [0]])

# Shape of result is (n, l)
result = tensor([[ 0,  1,  2],  # a[0, 0, :] since indices[0] == 0

                 [ 9, 10, 11],  # a[1, 1, :] since indices[1] == 1

                 [12, 13, 14]]) # a[2, 0, :] since indices[2] == 0

这确实类似于a.gather(1, indices),但gather不会工作,因为indicesa的形状不同。如何在此设置中使用gather?或者我应该使用什么?

EN

回答 3

Stack Overflow用户

发布于 2020-12-20 19:48:23

您可以手动创建索引。如果indices张量具有示例数据的形状,则必须将其展平。

代码语言:javascript
运行
复制
a[torch.arange(len(a)),indices.view(-1)]
# equal to a[[0,1,2],[0,1,0]]

输出:

代码语言:javascript
运行
复制
tensor([[ 0,  1,  2],
        [ 9, 10, 11],
        [12, 13, 14]])
票数 1
EN

Stack Overflow用户

发布于 2021-05-09 18:29:06

我将我的答案添加到Michael的答案之上,以获得索引维度两边的更多维度,但我希望有人给我一个不使用arange的更好的答案!

代码语言:javascript
运行
复制
def squeeze_index(x, dim, index):
  # flatten to rows
  y = x.view((-1,) + x.shape[dim:])

  # generate row indices
  rows = torch.arange(y.shape[0])

  # index and reshape
  result_shape = x.shape[:dim] + (x.shape[dim+1:] if dim != -1 else ())
  return y[rows, index.view(-1), ...].view(result_shape)

a = torch.arange(2*3*2*3).reshape((2,3,2,3))
indices = torch.tensor([0,0,1,0,0,1]).reshape((2,3))
result = squeeze_index(a, 2, i)
print("a", a.shape, a)
print("indices", indices.shape, indices)
print("result", result.shape, result)

提供:

代码语言:javascript
运行
复制
a torch.Size([2, 3, 2, 3]) tensor([[[[ 0,  1,  2],
          [ 3,  4,  5]],

         [[ 6,  7,  8],
          [ 9, 10, 11]],

         [[12, 13, 14],
          [15, 16, 17]]],


        [[[18, 19, 20],
          [21, 22, 23]],

         [[24, 25, 26],
          [27, 28, 29]],

         [[30, 31, 32],
          [33, 34, 35]]]])
indices torch.Size([2, 3]) tensor([[0, 0, 1],
        [0, 0, 1]])
result torch.Size([2, 3, 3]) tensor([[[ 0,  1,  2],
         [ 6,  7,  8],
         [15, 16, 17]],

        [[18, 19, 20],
         [24, 25, 26],
         [33, 34, 35]]])
票数 0
EN

Stack Overflow用户

发布于 2021-06-30 23:55:25

在使用聚集函数之前,重塑索引,这是一个示例

代码语言:javascript
运行
复制
def gather_righthand(src, index, check=True):
    index = index.long()
    i_dim = index.dim(); s_dim = src.dim(); t_dim = i_dim-1
    if check:
        assert s_dim > i_dim
        for d in range(0, t_dim): 
            assert src.shape[d] == index.shape[d]
    index_new_shape = list(src.shape)
    index_new_shape[t_dim] = index.shape[t_dim]
    for _ in range(i_dim, s_dim): index = index.unsqueeze(-1)

    index_expand = index.expand(index_new_shape)            # only this two line matters
    return torch.gather(src, dim=t_dim, index=index_expand) # only this two line matters


gather_righthand(a, indices)
tensor([[[ 0.,  1.,  2.]],
        [[ 9., 10., 11.]],
        [[12., 13., 14.]]])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65378968

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档