前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-增删张量的维度

PyTorch入门笔记-增删张量的维度

作者头像
触摸壹缕阳光
发布2021-01-03 15:55:49
4.5K0
发布2021-01-03 15:55:49
举报

增加维度

增加一个长度为 1 的维度相当于给原有的张量添加一个新维度的概念。由于增加的新维度长度为 1,因此张量中的元素并没有发生改变,仅仅改变了张量的理解方式。比如一张 大小的灰度图片保存为形状为 的张量,在张量的头部增加一个长度为 1 的新维度,定义为通道数维度,此时张量的形状为 。

“图片张量的形状有两种约定:

  • 通道在后的约定。TensorFlow 将通道维度放在最后: ;
  • 通道在前的约定。PyTorch 将通道维度放在前面:

使用 torch.unsqueeze(input, dim) 可以在指定的 dim 维度前插入一个长度为 1 的新维度。

代码语言:javascript
复制
>>> import torch
>>> # 使用随机生成的正态分布模拟没有通道维度的图片张量
>>> input = torch.randn(28, 28)
>>> print(input.size())

torch.Size([28, 28])

>>> # 指定第0个维度前面插入新的维度
>>> image = torch.unsqueeze(input, dim = 0)
>>> print(image.size())

torch.Size([1, 28, 28])

「需要注意的是,torch.unsqueeze(input, dim) 的 dim 参数既可以为正整数也可以为负整数:」

  • 当 dim 为正整数时,表示在当前维度之前插入一个长度为 1 的新维度;
  • 当 dim 为负整数时,表示在当前维度之后插入一个长度为 1 的新维度;

以 张量为例 (为了方便叙述将其简写成 ),不同 dim 参数的实际插入位置如下所示。

通过上图可以看出,无论 dim 参数值是正整数还是负整数,其具体范围都和输入张量的维度有关。对于输入张量为 的图片张量而言,张量的维度为 4,其 dim 参数的取值范围为 ,对比不同维度的输入张量:

  • 输入张量的维度 input.dim() = 2 时,dim 参数的取值范围为
  • 输入张量的维度 input.dim() = 3 时,dim 参数的取值范围为

得到 dim 参数的取值范围为 ,其中 input.dim() 为输入张量的维度。

如果指定 dim 参数超过其取值范围,会抛出 IndexError。

代码语言:javascript
复制
>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())

torch.Size([1, 1, 28, 28])

>>> print(input.dim())

4

>>> # input.dim() = 4
>>> # [-4-1, 4+1) = [-5, 5)
>>> # 将dim设置为5,超出dim参数的取值范围
>>> # x = torch.unsqueeze(input, dim = 5) error
>>> # print(x.size())

Traceback (most recent call last):
  File "/home/chenkc/code/pytorch/test_02.py", line 19, in <module>
    x = torch.unsqueeze(input, dim = 5)
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)

删除维度

删除维度是增加维度的逆操作,与增加维度一样,「删除维度只能删除长度为 1 的维度,同时也不会改变张量的存储」。对于形状为 的张量来说,如果希望将批量维度删除 (batch_size 通常称为批量维度),可以通过 torch.squeeze(input, dim) 函数,「dim 参数为待删除维度的索引号。」

例如,删除形状为 图片张量的批量维度。

代码语言:javascript
复制
>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())

torch.Size([1, 1, 28, 28])

>>> # squeeze函数中dim参数为待删除维度的索引号
>>> # [b,c,h,w]中批量维度的索引为0
>>> x = torch.squeeze(input, dim = 0)
>>> print(x.size())

torch.Size([1, 28, 28])

与增加维度的 torch.unsqueeze(input, dim) 中 dim 参数不同,在 torch.squeeze(input, dim) 中 dim 参数表示待删除维度的索引号。同样以 张量为例 (为了方便叙述将其简写成 ),不同 dim 参数的实际删除的维度如下所示。

如果不指定维度参数 dim,即 torch.squeeze(input),它会默认的删除所有长度为 1 的维度。

代码语言:javascript
复制
>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())

torch.Size([1, 1, 28, 28])

>>> # 不指定dim参数默认删除所有长度为1的唯独
>>> x = torch.squeeze(input)
>>> print(x.size())

torch.Size([28, 28])

小结

Tips: 在 torch.squeeze(input, dim) 函数中,如果不指定维度参数 dim,即 dim = None 时,它默认会删除输入张量中所有长度为 1 的维度。

References:

  1. 《TensorFlow深度学习》
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-12-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 增加维度
  • 删除维度
  • 小结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档