前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-改变张量的形状

PyTorch入门笔记-改变张量的形状

作者头像
触摸壹缕阳光
发布2021-01-03 15:58:58
4.1K0
发布2021-01-03 15:58:58
举报

view和reshape

PyTorch 中改变张量形状有 view、reshape 和 resize_ (没有原地操作的resize方法未来会被丢弃) 三种方式,「其中 resize_ 比较特殊,它能够在修改张量形状的同时改变张量的大小,而 view 和 reshape 方法不能改变张量的大小,只能够重新调整张量形状。」

resize_ 方法比较特殊,后续用到的时候再详细介绍。本文主要介绍 view 和 reshape 方法,在 PyTorch 中 view 方法存在很长时间,reshape 方法是在 PyTorch0.4 的版本中引入,两种方法功能上相似,但是一些细节上稍有不同,因此这里介绍两个方法的不同之处。

  • view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储

nD 张量底层实现是使用一块连续内存的一维数组,由于 PyTorch 底层实现是 C 语言 (C/C++ 使用行优先方式),所以n维张量也使用行优先方式。比如对于下面形状为 (3 x 3) 的 2D 张量:

2D 张量在内存中实际以一维数组的形式进行存储,行优先的方式指的是存储的顺序按照 2D 张量的行依次存储。

上面形状为 (3 x 3) 的 2D 张量通常称为存储的逻辑结构,而实际存储的一维数组形式称为存储的物理结构。

  1. 如果元素在存储的逻辑结构上相邻,在存储的物理结构中也相邻,则称为连续存储的张量;
  2. 如果元素在存储的逻辑结构上相邻,但是在存储的物理结构中不相邻,则称为不连续存储的张量;

为了方便理解代码,先来熟悉一些方法。

  • 可以通过 tensor.is_contiguous() 来查看 tensor 是否为连续存储的张量;
  • PyTorch 中的转置操作能够将连续存储的张量变成不连续存储的张量;
代码语言:javascript
复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

True

>>> view_a = a.view(1, 9)
>>> reshape_a = a.reshape(9, 1)
>>> # 通过转置操作将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

False

>>> # view_t_a = a.view(1, 9) error
>>> reshape_t_a = a.reshape(1, 9)

其中 view_t_a = a.view(1, 9) 会抛出异常,再次验证了 view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储。

  • view 方法会返回原始张量的视图,而 reshape 方法可能返回的是原始张量的视图或者拷贝

原始张量的视图简单来说就是和原始张量共享数据,因此如果改变使用 view 方法返回的新张量,原始张量也会发生相对应的改变。

代码语言:javascript
复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                  [3, 4, 5],
                  [6, 7, 8]])
>>> view_a = a.view(1, 9)
>>> print(view_a)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])

>>> # 更改张量中的元素值
>>> view_a[:, 1] = 100
>>> print(a)

tensor([[  0, 100,   2],
        [  3,   4,   5],
        [  6,   7,   8]])

>>> print(view_a)

tensor([[  0, 100,   2,   3,   4,   5,   6,   7,   8]])

reshape 方法可能返回的是原始张量的视图或者拷贝,当处理连续存储的张量 reshape 返回的是原始张量的视图,而当处理不连续存储的张量 reshape 返回的是原始张量的拷贝。

代码语言:javascript
复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

True

>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)

tensor([[  0, 100,   2],
        [  3,   4,   5],
        [  6,   7,   8]])

>>> print(reshape_a)

tensor([[  0, 100,   2,   3,   4,   5,   6,   7,   8]])
代码语言:javascript
复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 通过转置将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

False

>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)

tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])

>>> print(reshape_a)

tensor([[  0, 100,   6,   1,   4,   7,   2,   5,   8]])
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-12-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • view和reshape
相关产品与服务
对象存储
对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档