前言
交换维度顾名思义就是交换不同的维度,线性代数中矩阵的转置操作可以看成是交换第 0 个和第 1 个维度。比如下图形状为 (3, 4) 的矩阵。
交换第 0 个维度和第 1 个维度 (转置) 为形状为 (4, 3) 的矩阵。
不仅是在线性代数中经常会遇到交换维度的操作,在深度学习中交换维度的操作也非常常见。比如对于图片张量来说,在 PyTorch 中将通道维度放到最后面 [b, h, w, c],而在 TensorFlow 中将通道维度放在前面 [b, c, h, w],如果需要将 [b, h, w, c] 转换为 [b, c, h, w] 则需要使用交换维度的操作。
PyTorch 中交换维度的操作有 transpose 和 permute 两种方式。交换维度的操作至少要求张量拥有两个以及两个以上的维度才有意义,因此在介绍交换维度的方式时不再考虑 0D 和 1D 张量。
torch.transpose(input, dim0, dim1)
函数将输入张量 input 的第 dim0 个维度和第 dim1 个维度进行交换,并将交换维度后的张量返回。transpose 函数作用非常直观,使用起来也非常简单,因此使用方法不再过多的赘述。
下面是在使用 transpose 函数时的几个注意事项。
torch.t(input)
。当 input 为 2D 张量时 torch.t(input)
等价 torch.transpose(input, 0, 1)
(或 torch.transpose(input, 1, 0)
);交换 nD 张量 ( n \geq 2) 的任意两个不同的维度在很多时候并不能满足我们的需求。比如将图片张量 [b, h, w, c] 转换为 [b, c, h, w]。
>>> import torch
>>> # 使用[0, 1)均匀分布模拟图片张量
>>> # (batch_size, height, width, channels)
>>> imgs = torch.randn([1, 32, 28, 3])
>>> # 交换height和channels两个维度
>>> imgs_swap = torch.transpose(imgs, 1, 3)
>>> # (batch_size, channels, width, height)
>>> print(imgs_swap.shape)
torch.Size([1, 3, 28, 32])
>>> # 交换width和height两个维度
>>> imgs = torch.transpose(imgs_swap, 2, 3)
>>> # (batch_size, channels, height, width)
>>> print(imgs.shape)
torch.Size([1, 3, 32, 28])
虽然能够将图片张量 [b, h, w, c] 转换为 [b, c, h, w],但是使用了两次 transpose 函数,并且需要熟知每次变换后对应维度的位置,非常容易出错。PyTorch 针对这种多次交换维度的方式提供 permute 函数。
前面提到过 PyTorch 从接口的角度将张量的操作分成两种方式。比如对于 transpose 函数来说,可以使用 torch.transpose(input, dim0, dim1)
或者 input.transpose(dim0, dim1)
,两种定义方式本质上是一样的。但是 permute 函数只有 input.permute(*dims)
一种定义方式,其中 *dims 为期望维度的顺序。 来看看如何通过 permute 函数将图片张量 [b, h, w, c] 转换为 [b, c, h, w]。
>>> import torch
>>> # 使用[0, 1)均匀分布模拟图片张量
>>> # (batch_size, height, width, channels)
>>> imgs = torch.randn([1, 32, 28, 3])
>>> # 重新排列维度顺序
>>> print(imgs.permute(0, 3, 1, 2).shape)
torch.Size([1, 3, 32, 28])
[b, h, w, c] 维度序号为 (0, 1, 2, 3)
,如果想要将 [b, h, w, c] 转换为 [b, c, h, w],只需要重新排列一下维度序号 (0, 3, 1, 2)
,这也是 permute 函数的设计原理。
原文地址:
本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!