前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-索引和切片

PyTorch入门笔记-索引和切片

作者头像
触摸壹缕阳光
修改2022-04-26 15:49:32
3.2K0
修改2022-04-26 15:49:32
举报

前言

切片其实也是索引操作,所以切片经常被称为切片索引,为了更方便叙述,本文将切片称为切片索引。索引和切片操作可以帮助我们快速提取张量中的部分数据。

1. 基本索引

PyTorch 支持与 Python 和 NumPy 类似的基本索引操作,PyTorch 中的基本索引可以通过整数值来索引张量。

代码语言:javascript
复制
>>> import torch
>>> # 构造形状为3x3,元素值从0到8的2D张量
>>> a = torch.arange(0, 9).view([3, 3])
>>> print(a)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]) 

>>> print(a[0]) # 索引张量a的第一行

tensor([0, 1, 2])

>>> print(a[0][1]) # 索引张量a的第一行和第二列

tensor(1)

变量 a 是一个(3 x 3)的 2D 张量,即张量 a 包含两个维度:

  • 第一个维度,在 2D 张量中称为行维度;
  • 第二个维度,在 2D 张量中称为列维度;

a[0]表示在张量 a 的行维度上取索引号为 0 的元素(第一行);a[0][1]表示在张量 a 的行维度上取索引号为 0 的元素(第一行)以及在列维度上取索引号为 1 的元素(第二列),获取行维度和列维度上的元素集合的交集(位于第一行第二列上的元素集合)即为最终的索引结果。简单来说,[i][j]...[k]中的每一个[]都表示张量的一个维度,从左边开始维度依次增加,而[]中的元素值代表对应维度的索引号,「此时的索引号可以为负数,相当于从后向前索引。」

代码语言:javascript
复制
>>> import torch
>>> # 构造形状为3x3,元素值从0到8的2D张量
>>> a = torch.arange(0, 9).view([3, 3])
>>> print(a)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> print(a[-1]) # 索引张量a的最后一行

tensor([6, 7, 8])

「当张量的维度数较高的时候,使用[i][j]...[k]**的方式书写非常不方便,可以采用[i, j,...,k]的方式,两种方式是等价的。」

代码语言:javascript
复制
>>> import torch
>>> # 构造形状为2x2x3,元素值从0到11的3D张量
>>> a = torch.arange(12).view([2, 2, 3])
>>> print(a)

tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])

>>> # 第一个维度取索引号为0的元素
>>> # 第二个维度取索引号为1的元素
>>> # 第三个维度取索引号为2的元素
>>> # 满足这三个条件的元素即为索引结果
>>> print(a[0, 1, 2])

tensor(5)

>>> # 通过基本索引修改元素值
>>> a[0, 1, 2] = 100
>>> print(a)

tensor([[[  0,   1,   2],
         [  3,   4, 100]],

        [[  6,   7,   8],
         [  9,  10,  11]]])

通过对比原始张量 a 和通过基本索引的方式修改元素值之后的张量 a 可以发现,「通过基本索引出来的结果与原始的张量共享内存,如果修改一个,另一个也会被修改。」

2. 切片索引

通过 [start: end: steps](起始位置为start,终止位置为end,步长为steps)的方式索引连续的张量子集。以形状为 [4, 3, 28, 28] 的图片张量为例,在 PyTorch 中图片张量的格式为 [batch_size, channel, width, hight],[4, 3, 28, 28] 的图片张量表示 4 张拥有 RGB 三个通道且每个通道为 (28 x 28) 的像素矩阵。

代码语言:javascript
复制
>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 读取前2张图片
>>> print(a[:2].size())

torch.Size([2, 3, 28, 28])

>>> # 读取前两张图片的R通道的28x28的像素矩阵
>>> print(a[:2, :1, :, :].size())

torch.Size([2, 1, 28, 28])

>>> # 读取前两张图片的GB通道的28x28的像素矩阵
>>> print(a[:2, 1:, :, :].size())

torch.Size([2, 2, 28, 28])

>>> # 读取前两张图片的B通道的28x28的像素矩阵
>>> print(a[:2, -1:, :, :].size())

torch.Size([2, 1, 28, 28])

start: end: step切片方式有很多简写方式,其中 start、end、step 3 个参数可以根据需要选择性的省略,全部省略时即为::,表示从最考试读取到最末尾,步长为 1,即不跳过任何元素。如 x[0,::] 表示读取第一张图片的的所有通道的像素矩阵,其中::表示在通道维度上读取所有RGB三个通道,它等价于 x[0] 的写法。通常为了简洁,将::简写成单个冒号。

代码语言:javascript
复制
>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 读取第一张图片
>>> print(a[0,::].size())

torch.Size([3, 28, 28])

>>> # 为了更加简介,::可以简写为单个冒号:
>>> print(a[0,:].size())

torch.Size([3, 28, 28])

接下来总结一下start: end: step 切片的简写方式,其中从第一个元素读取时 start 可以省略,即 start = 0 是可以省略的,取到最后一个元素时 end 可以省略,步长为 1 时 step 可以省略,简写方式总结如表 4.1:

「还有点需要注意,在 PyTorch 中切片索引中的步长不能小于0,即不能为负数。」

代码语言:javascript
复制
>>> import torch
>>> # 创建元素值为0~8的1D张量
>>> a = torch.arange(9)
>>> print(a)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

>>> print(a[4: 0: -2])

Traceback (most recent call last):
  File "/home/chenkc/code/tensor.py", line 44, in <module>
    print(a[4: 0: -2])
ValueError: step must be greater than zero

当张量的维度数量较多时,不需要采样的维度一般用单冒号 : 表示采样所有元素,此时有可能出现大量的 : 出现。

代码语言:javascript
复制
>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 获取4张图片的RGB三个通道的所有行和第三列像素矩阵
>>> print(a[:, :, :, 2].size())

torch.Size([4, 3, 28])

「为了避免出现像x[:, :, :, 2] 这样过多冒号的情况,可以使用...符号表示取多个维度上所有数据,其中维度的数量需要根据规则自动推断:当切片方式出现...符号时,...符号左边的维度将自动对齐到最左边,...符号右边的维度将自动对齐到最右边,此时系统再自动推断...符号代表的维度张量,」 它的切片方式总结如表 4.2 所示(「其中表中的···都为...」)。

3. 高级索引

PyTorch 支持绝大多数 NumPy 的高级索引,高级索引可以看成是基本索引的扩展。

代码语言:javascript
复制
>>> import torch
>>> a = torch.arange(9).view([3, 3])

>>> print(a)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> print(a[[0, 1],...])

tensor([[0, 1, 2],
        [3, 4, 5]])

>>> print(a[[0, 1], [1, 2]])

tensor([1, 5])

>>> print(a[[1, 0, 2], [0]])

tensor([3, 0, 6])

这里给出了 PyTorch 中的三种高级索引方式,通过这些高级索引的输出结果,可以看出这些高级索引的本质。

  • a[[0, 1, ...]] 等价 a[0] 和 a[1],相当于索引张量的第一行和第二行元素;
  • a[[0, 1, 1, 2]] 等价 a[0, 1] 和 a[1, 2],相当于索引张量的第一行的第二列和第二行的第三列元素;
  • a[[1, 0, 2, 0]] 等价 a[1, 0] 和 a[0, 0] 和 a[2, 0],相当于索引张量的第二行第一列的元素、张量第一行和第一列的元素以及张量第三行和第一列的元素;

References:

1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm

2. 初探Numpy中的花式索引

原文地址:https://mp.weixin.qq.com/s?

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-11-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 1. 基本索引
  • 2. 切片索引
  • 3. 高级索引
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档