gather
torch.gather(*input,dim,index,sparse_grad=False, out=None*)
函数沿着指定的轴 dim 上的索引 index 采集输入张量 input 中的元素值,函数的参数有:
除了 sparse_grad 和 out 两个可选参数,其余三个参数都是必选参数。为了方便这里只考虑必选参数,即 torch.gather(input, dim, index)。
简单介绍完 gather 函数之后,来看一个简单的小例子:一次将下面 2D 张量中所有红色的元素采集出来。
2D 张量可以看成矩阵,2D 张量的第一个维度为矩阵的行 (dim = 0),2D 张量的第二个维度为矩阵的列 (dim = 1),从左向右依次看三个红色元素在矩阵中的具体位置:
通过红色元素的具体位置可以看出,三个红色元素的列索引号是有规律的:从 0 到 2 逐渐递增。假设此时列索引的规律是已知并且固定的,我们只需要给出这些红色元素在行上的索引号就可以将这些红色元素全部采集出来。
至此,对于这个 2D 张量的小例子,已知了输入张量和指定行上的索引号。回顾 torch.gather(input, dim, index) 函数沿着指定轴上的索引采集输入张量的元素值,貌似现在已知的条件和 gather 函数中所需要的参数有些谋和。下面我们来尝试一下使用 gather 函数来采集红色元素。
>>> import torch
>>> x = torch.arange(9).view(3, 3)
>>> print(x)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> index = torch.tensor([[2, 0, 1]])
>>> # dim=0: 行上的索引
>>> out = torch.gather(x, dim = 0, index = index)
>>> print(out)
tensor([[6, 1, 5]])
gather 函数的输出结果和我们在小例子中分析的结果一致。
如果按照从上到下来看三个红色元素,采集元素的顺序和从前面从左向右看的时候不同,此时采集元素的顺序为 1, 5, 6,现在看看此时这三个红色元素在矩阵中的具体位置:
现在行索引号是有规律的:从 0 到 2 逐渐递增。现在假设此时行索引的规律是已知并且固定的,我们只需要给出这些红色元素在列上的索引号就可以将这些红色元素全部采集出来了。
>>> import torch
>>> x = torch.arange(9).view(3, 3)
>>> print(x)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> index = torch.tensor([[1, 2, 0]]).t()
>>> # dim=1: 在列方向上索引
>>> out = torch.gather(x, dim = 1, index = index)
>>> print(out)
tensor([[1],
[5],
[6]])
在不同轴上 (行或列) 进行索引传入的 index 参数的张量形状不同,在 gather 函数中规定:
接下来使用一个形状为 (3 x 5) 2D 张量来详细的分析 gather 函数的原理。
2D 张量有两个轴,假定现在只采集一个元素:
dim = 0 表示在行上索引,此时假定已知且固定了在列上的索引,即 (其中 ? 为待采集元素在行上的索引号):
如果想要使用 gather 函数采集元素,需要在 index 中指定 5 个行索引号,而每列只索引一个元素且在行上索引 (dim = 0),因此最终我们需要传入 index 张量的形状为 (1, 5),其中的元素值为待采集元素的行索引号。
dim = 1 表示在列上索引,此时假定已知且固定了在行上的索引,即 (其中 ? 为待采集元素在列上的索引号):
如果想要使用 gather 函数采集元素,需要在 index 中指定 3 个列索引号,而每行只索引一个元素且在列上索引 (dim = 1),因此最终我们需要传入 index 张量的形状为 (1, 3),其中的元素值为待采集元素的列索引号。
最后来看看如何使用 gather 函数每行采集两个元素:
>>> import torch
>>> x = torch.arange(15).view(3, 5)
>>> index = torch.LongTensor([[0, 1], [2, 3], [1, 2]])
>>> out = torch.gather(x, dim = 1, index = index)
>>> print(out)
tensor([[ 0, 1],
[ 7, 8],
[11, 12]])
传入 index 的张量形状为 (3 x 2),因此最终输出张量的形状也为 (3 x 2)。dim = 1 表示在列上索引,此时假定已知且固定了在行上的索引:
本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!