首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >pytorch基础知识-维度变换-(中)

pytorch基础知识-维度变换-(中)

作者头像
用户6719124
发布2019-11-17 23:15:34
发布2019-11-17 23:15:34
83000
代码可运行
举报
运行总次数:0
代码可运行

下面介绍sequeeze和 unsqueeze API

squeeze为维度挤压

unsqueeze为维度展开

用法:

API: a.unsqueeze(dim)

代码语言:javascript
代码运行次数:0
运行
复制
print(a.unsqueeze(0).shape)
# 在原[4, 1, 28, 28]的第0个维度上插入一个维度

代码语言:javascript
代码运行次数:0
运行
复制
torch.Size([1, 4, 1, 28, 28])

同理也可以进行其他类型转换

代码语言:javascript
代码运行次数:0
运行
复制
print(a.unsqueeze(-1).shape)
# 在原[4, 1, 28, 28]的第后面那个维度上插入一个维度
print(a.unsqueeze(-2).shape)
# 在原[4, 1, 28, 28]的第倒数第2个维度上插入一个维度

代码语言:javascript
代码运行次数:0
运行
复制
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 28, 1, 28])

使用时也要注意设定的维度是否具有合理性,如正着数是0~4,倒着数是-1~-5。

代码语言:javascript
代码运行次数:0
运行
复制
print(a.unsqueeze(-5).shape)
# 在原[4, 1, 28, 28]的第后面那个维度上插入一个维度
print(a.unsqueeze(5).shape)
# 在原[4, 1, 28, 28]的第倒数第2个维度上插入一个维度

输出结果为

代码语言:javascript
代码运行次数:0
运行
复制
torch.Size([1, 4, 1, 28, 28])

物理意义上来说可以理解为在batch前面加了一个组

第一个可以正常输出,后面那个因维度不存在,程序报错

代码语言:javascript
代码运行次数:0
运行
复制
Traceback (most recent call last):
  File "E:/公众号/维度变换/1.1.py", line 24, in <module>
    print(a.unsqueeze(5).shape)
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)

这里可以总结为括号内的index的范围为[-a.dim()-1, a.dim()+1)。注意是包含前面不包含后面(括号类型)。

先将插入的idx和插入位置总结为下图

另外

为更深入了解,我们先构建一个tensor

代码语言:javascript
代码运行次数:0
运行
复制
b = torch.tensor([1.1, 2.2])
print(b)
print(b.shape)
c = b.unsqueeze(-1)
print(c)
print(c.shape)
d = b.unsqueeze(0)
print(d)
print(d.shape)

分别输出为

代码语言:javascript
代码运行次数:0
运行
复制
tensor([1.1000, 2.2000])
torch.Size([2])
tensor([[1.1000],
        [2.2000]])
torch.Size([2, 1])
tensor([[1.1000, 2.2000]])
torch.Size([1, 2])

由此可见确定好插入维度的位置,对于进行维度变换十分重要

下面以具体例子介绍维度变换操作

代码语言:javascript
代码运行次数:0
运行
复制
# 先构建一个32的通道
b = torch.rand(32)
# 再设定一个转换的最终目标
f = torch.rand(4, 32, 14, 14)
# 思考如何通过维度变换将b转化为f
# 本节考虑先将[32]变为[1, 32, 1, 1],后续通过维度扩张的讲解再进行后续步骤
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
# 进行了[32]=>[32, 1]=>[32, 1, 1]=>[1, 32, 1, 1]
print(b.shape)

输出为

代码语言:javascript
代码运行次数:0
运行
复制
torch.Size([1, 32, 1, 1])

这种维度增加方式在以后的图片处理中十分常见,要求务必掌握。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-10-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档