PyTorch 中改变张量形状有 view、reshape 和 resize_ (没有原地操作的resize方法未来会被丢弃) 三种方式,「其中 resize_ 比较特殊,它能够在修改张量形状的同时改变张量的大小,而 view 和 reshape 方法不能改变张量的大小,只能够重新调整张量形状。」
resize_ 方法比较特殊,后续用到的时候再详细介绍。本文主要介绍 view 和 reshape 方法,在 PyTorch 中 view 方法存在很长时间,reshape 方法是在 PyTorch0.4 的版本中引入,两种方法功能上相似,但是一些细节上稍有不同,因此这里介绍两个方法的不同之处。
nD 张量底层实现是使用一块连续内存的一维数组,由于 PyTorch 底层实现是 C 语言 (C/C++ 使用行优先方式),所以n维张量也使用行优先方式。比如对于下面形状为 (3 x 3) 的 2D 张量:
2D 张量在内存中实际以一维数组的形式进行存储,行优先的方式指的是存储的顺序按照 2D 张量的行依次存储。
上面形状为 (3 x 3) 的 2D 张量通常称为存储的逻辑结构,而实际存储的一维数组形式称为存储的物理结构。
为了方便理解代码,先来熟悉一些方法。
tensor.is_contiguous()
来查看 tensor 是否为连续存储的张量;>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
True
>>> view_a = a.view(1, 9)
>>> reshape_a = a.reshape(9, 1)
>>> # 通过转置操作将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
False
>>> # view_t_a = a.view(1, 9) error
>>> reshape_t_a = a.reshape(1, 9)
其中 view_t_a = a.view(1, 9)
会抛出异常,再次验证了 view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储。
原始张量的视图简单来说就是和原始张量共享数据,因此如果改变使用 view 方法返回的新张量,原始张量也会发生相对应的改变。
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> view_a = a.view(1, 9)
>>> print(view_a)
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])
>>> # 更改张量中的元素值
>>> view_a[:, 1] = 100
>>> print(a)
tensor([[ 0, 100, 2],
[ 3, 4, 5],
[ 6, 7, 8]])
>>> print(view_a)
tensor([[ 0, 100, 2, 3, 4, 5, 6, 7, 8]])
reshape 方法可能返回的是原始张量的视图或者拷贝,当处理连续存储的张量 reshape 返回的是原始张量的视图,而当处理不连续存储的张量 reshape 返回的是原始张量的拷贝。
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
True
>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)
tensor([[ 0, 100, 2],
[ 3, 4, 5],
[ 6, 7, 8]])
>>> print(reshape_a)
tensor([[ 0, 100, 2, 3, 4, 5, 6, 7, 8]])
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> # 通过转置将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
False
>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)
tensor([[0, 3, 6],
[1, 4, 7],
[2, 5, 8]])
>>> print(reshape_a)
tensor([[ 0, 100, 6, 1, 4, 7, 2, 5, 8]])
本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有