前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-index_select选择函数

PyTorch入门笔记-index_select选择函数

作者头像
触摸壹缕阳光
修改2022-04-26 19:03:27
5K1
修改2022-04-26 19:03:27
举报

1. index_select 选择函数

torch.index_select(input,dim,index,out=None) 函数返回的是沿着输入张量的指定维度的指定索引号进行索引的张量子集,其中输入张量、指定维度和指定索引号就是 torch.index_select(input,dim,index,out=None) 函数的三个关键参数,函数参数有:

  • input(Tensor) - 需要进行索引操作的输入张量;
  • dim(int) - 需要对输入张量进行索引的维度;
  • index(LongTensor) - 包含索引号的 1D 张量;
  • out(Tensor, optional) - 指定输出的张量。比如执行 torch.zeros(2, 2, out = tensor_a),相当于执行 tensor_a = torch.zeros(2, 2);

接下来使用 torch.index_select(input,dim,index,out=None) 函数分别对 1D 张量、2D 张量和 3D 张量进行索引。

代码语言:python
复制
>>> import torch
>>> # 创建1D张量
>>> a = torch.arange(0, 9)
>>> print(a)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

>>> # 获取1D张量的第1个维度且索引号为2和3的张量子集
>>> print(torch.index_select(a, dim = 0, index = torch.tensor([2, 3])))

tensor([2, 3])

>>> # 创建2D张量
>>> b = torch.arange(0, 9).view([3, 3])
>>> print(b)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> # 获取2D张量的第2个维度且索引号为0和1的张量子集(第一列和第二列)
>>> print(torch.index_select(b, dim = 1, index = torch.tensor([0, 1])))

tensor([[0, 1],
        [3, 4],
        [6, 7]])

>>> # 创建3D张量
>>> c = torch.arange(0, 9).view([1, 3, 3])
>>> print(c)

tensor([[[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]])

>>> # 获取3D张量的第1个维度且索引号为0的张量子集
>>> print(torch.index_select(c, dim = 0, index = torch.tensor([0])))

tensor([[[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]])

「由于 index_select 函数只能针对输入张量的其中一个维度的一个或者多个索引号进行索引,因此可以通过 PyTorch 中的高级索引来实现。」

  • 获取 1D 张量 a 的第 1 个维度且索引号为 2 和 3 的张量子集: torch.index_select(a, dim = 0, index = torch.tensor([2, 3])) \iffa[[2, 3]]
  • 获取 2D 张量 b 的第 2 个维度且索引号为 0 和 1 的张量子集(第一列和第二列): torch.index_select(b, dim = 1, index = torch.tensor([0, 1])) \iff b[:, [0, 1]]
  • 创建 3D 张量 c 的第 1 个维度且索引号为 0 的张量子集: torch.index_select(c, dim = 0, index = torch.tensor([0])) \iff c[[0]]

index_select 函数虽然简单,但是有几点需要注意:

  • index 参数必须是 1D 长整型张量 (1D-LongTensor);
代码语言:python
复制
>>> import torch
>>> index1 = torch.tensor([1, 2])
>>> print(index.type())

torch.LongTensor

>>> index2 = torch.tensor([1., 2.])
>>> print(index2.type())

torch.FloatTensor

>>> index3 = torch.tensor([[1, 2]])
>>> # 创建1D张量
>>> a = torch.arange(0, 9)
>>> print(torch.index_select(a, dim = 0, index = index1))

tensor([1, 2])

>>> # print(torch.index_select(a, dim = 0, index = index2))

RuntimeError: index_select(): Expected dtype int64 for index

>>> # print(torch.index_select(a, dim = 0, index = index3))

IndexError: index_select(): Index is supposed to be a vector
  • 使用 index_select 函数输出的张量维度和原始的输入张量维度相同。这也是为什么即使在对输入张量的其中一个维度的一个索引号进行索引 (此时可以使用基本索引和切片索引) 时也需要使用 PyTorch 中的高级索引方式才能与 index_select 函数等价的原因所在;
代码语言:python
复制
>>> import torch
>>> # 创建2D张量
>>> d = torch.arange(0, 4).view([2, 2])
>>> # 使用index_select函数索引
>>> d1 = torch.index_select(d, dim = 0, index = torch.tensor([0]))
>>> print(d1)

tensor([[0, 1]])

>>> print(d1.size())

torch.Size([1, 2])

>>> # 使用PyTorch中的高级索引
>>> d2 = d[[0]]
>>> print(d2)

tensor([[0, 1]])

>>> print(d2.size())

torch.Size([1, 2])

>>> # 使用基本索引和切片索引
>>> d3 = d[0]
>>> print(d3)

tensor([0, 1])

>>> print(d3.size())

torch.Size([2])

通过上面的代码可以看出,三种方式索引出来的张量子集中的元素都是一样的,不同的是索引出来张量子集的形状,index_select 函数对输入张量进行索引可以使用高级索引实现。

References:

1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm

原文地址:https://mp.weixin.qq.com/s?

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-11-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. index_select 选择函数
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档