专栏首页python pytorch AI机器学习实践pytorch基础知识-维度变换-(上)

pytorch基础知识-维度变换-(上)

维度变换是pytorch中的重要操作,尤其是在图片处理中。本文对pytorch中的维度变换进行讲解。

维度变换有四种操作:

(1)view/reshape (类似于numpy中的reshape操作,可在不改变数据情况下,将tensor转换shape)

(2)Squeeze/unsqueeze (前一个为删减维度,第二个为增加维度)

(3)Transpose/t/permute (矩阵的单次、多次交换操作)

(4)Expand/repeat (维度的扩展)

下面分别介绍view和reshape操作

View和reshape两者可实现相同的操作,他们之所以不同是因为他们分别出现在pytorch 0.3和 0.4版本中。

下面以具体代码介绍其功能

import torch

a = torch.rand(4, 1, 28, 28)
# 创建4张图片,一个通道(黑白照片),28*28像素点的图片
print(a.shape)
# 查看a的shape

b = a.view(4, 28*28)
# 代表将后面28*28的像素点全合并在一起,这样变为[4, 784]的shape,各个像素点失去了上下级的关系,全部被打平。这样处理的数据适合于全连接层的计算
print(b.shape)
print(b)
torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
tensor([[0.0340, 0.7099, 0.4833,  ..., 0.2647, 0.2162, 0.7296],
        [0.3813, 0.8616, 0.8151,  ..., 0.9942, 0.2087, 0.8194],
        [0.9671, 0.1019, 0.7831,  ..., 0.0181, 0.1157, 0.7953],
        [0.5822, 0.3729, 0.2884,  ..., 0.0572, 0.4267, 0.1253]])

同时也可以进行其他类型维度转换

# 进行其他类型维度转换
c = a.view(4*28, 28)
# 代表只关心行数据
print(c.shape)
print(c)

输出

torch.Size([112, 28])
tensor([[0.5150, 0.7596, 0.2775,  ..., 0.5091, 0.2392, 0.8344],
        [0.9011, 0.5765, 0.2927,  ..., 0.5699, 0.3586, 0.8753],
        [0.7888, 0.3819, 0.2285,  ..., 0.0180, 0.8713, 0.2788],
        ...,
        [0.7574, 0.8183, 0.7526,  ..., 0.8110, 0.7561, 0.3782],
        [0.6651, 0.1608, 0.4255,  ..., 0.6095, 0.8237, 0.4504],
        [0.1026, 0.9139, 0.0975,  ..., 0.7457, 0.7392, 0.9503]])

在进行高维向低维进行转换时,

# 进行高维至低维时
d = c.view(8, 14, 7, 4)
print(d.shape)
print(d)
torch.Size([8, 14, 7, 4])
tensor([[[[6.1078e-01, 6.2038e-01, 4.8822e-01, 2.2785e-01],
          [4.3340e-01, 7.3079e-01, 5.3903e-01, 5.7458e-01],
          [2.6486e-01, 6.3135e-01, 6.4279e-01, 1.0438e-01],
          ...,
          [9.3821e-01, 4.5505e-01, 2.4141e-01, 5.1869e-01],
          [6.4698e-01, 2.9054e-01, 4.0393e-01, 3.7668e-01],
          [9.6046e-03, 9.4442e-01, 2.5506e-01, 2.3018e-01]],
 
 ...,
 
         [[5.2430e-01, 8.8357e-02, 7.4498e-01, 2.3804e-01],
          [5.9681e-01, 4.2410e-01, 7.5572e-01, 6.7488e-01],
          [6.7680e-01, 5.3951e-01, 1.5074e-01, 2.8085e-01],
          ...,
          [5.3654e-01, 1.8240e-01, 2.3953e-01, 8.9576e-01],
          [9.3373e-01, 4.6603e-01, 8.9729e-01, 1.0706e-01],
          [7.7679e-01, 7.8289e-01, 3.5601e-01, 4.7687e-02]]]])

由此可见,只要总tensor数值相同,既可以进行任意维度上的转换。但要注意变换后的物理意义,不要随便进行转换。

总之,view和shape操作,不变的是数据本身,变的是对数据的理解方式。要时刻注意数据的排列方式(数据的batch_size、channel number、height、width),不要破坏数据。

本文分享自微信公众号 - python pytorch AI机器学习实践(gh_a7878fd5de90),作者:王某某搞AI

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-10-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • pytorch基础知识:张量(下)

    其中一维标量主要用于Bias(偏差)中,如在构建神经元中多组数据导入到一个神经元中,由激活函数激活输出一个数值,则该神经元主要使用bias功能。线性层输入(Li...

    用户6719124
  • Pytorch-nn.Module

    (1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。

    用户6719124
  • CIFAR10数据集实战-LeNet5神经网络(上)

    上次课我们讲解了对于CIFAR10数据读取部分代码的编写,本节讲解如何编写经典的LeNet5神经网络。

    用户6719124
  • 一个有趣的时间段重叠问题

    版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.n...

    用户1148526
  • MySQL中InnoDB引擎对索引的扩展

    MySQL中,使用InnoDB引擎的每个表,创建的普通索引(即非主键索引),都会同时保存主键的值。

    数据和云
  • pytorch基础知识:张量(下)

    其中一维标量主要用于Bias(偏差)中,如在构建神经元中多组数据导入到一个神经元中,由激活函数激活输出一个数值,则该神经元主要使用bias功能。线性层输入(Li...

    用户6719124
  • 丢给你个环形队列玩玩

    假设我需要处理10000个字节的数据,就是串口一次性会发过来10000个字节,然后单片机每次取10个字节处理,然后处理1000次就处理完了

    杨奉武
  • 重叠时间段问题优化算法详解

    这是一个实际业务需求中的问题。某一直播业务表中记录了如下格式的用户进出直播间日志数据:

    用户1148526
  • 十分钟快速了解Pandas的常用操作!

    原文 | https://pandas.pydata.org/pandas-docs/version/0.18.0/

    刘早起
  • 案例分析:闰秒带来的BUG是否影响了你?

    闰秒如何影响了IT世界?在2016年底我们写下的文章里曾经提到2017开年多出这一秒,大家是否平稳度过?欢迎大家留言讲诉你遇到的真实故事。 毫无疑问,根据墨菲...

    数据和云

扫码关注云+社区

领取腾讯云代金券