torch.cat(tensors, dim = 0)
函数拼接操作是在现有维度上合并数据,并不会创建新的维度。如果在合并数据时,希望创建一个新的维度,则需要使用 torch.stack 操作。
torch.stack(tensors, dim = 0)
函数可以使用堆叠的方式合并多个张量,参数 tensors 保存了所有需要合并张量的序列(任何Python的序列对象,比如列表、元组等),参数 dim 指定新维度插入的位置,torch.stack 函数中的 dim 参数与 torch.unsqueeze 函数(增加长度为 1 的新维度)中的 dim 参数用法一致:
dim ≥ 0
时,在 dim 之前插入新维度;dim < 0
时,在 dim 之后插入新维度;例如,对于形状为
的张量,在不同位置通过 torch.stack 操作插入新维度,dim 参数对应的插入位置设置如下图所示。
比如张量
是形状为
的 3 通道图片张量,张量
是另外一个形状为
的 3 通道图片张量。使用 torch.stack 合并这两个图片张量,批量维度插入在 dim = 0 的位置上,具体代码如下。
import torch
# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(3, 32, 32)
# 堆叠合并为2个图片张量,批量的维度插在最前面
stack_ab = torch.stack([a, b], dim = 0)
print(stack_ab.size())
# torch.Size([2, 3, 32, 32])
同样可以在其它位置上插入新的维度,比如,最末尾插入批量维度。
import torch
# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(3, 32, 32)
# 堆叠合并为2个图片张量,批量的维度插在最末尾
stack_ab = torch.stack([a, b], dim = -1)
print(stack_ab.size())
# torch.Size([3, 32, 32, 2])
torch.cat(tensors, dim = 0)
函数有两个约束:
显然 torch.cat 函数也能够拼接合并
和
两个图片张量。
import torch
# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(3, 32, 32)
# 拼接合并为2个图片张量,没有批量维度的概念
cat_ab = torch.cat([a, b], dim = 0)
print(cat_ab.size())
# torch.Size([6, 32, 32])
形状都是
的
和
两个图片张量,沿着第 0 个维度进行合并(通道维度)后的张量形状为
。虽然 torch.cat 函数能够顺利的拼接合并,但是在理解时,需要按照前 3 个通道来自第一张图片,后 3 个通道来自第二张图片的方式理解张量。对于这个例子,明显通过 torch.stack 的方式创建新维度的方式更为合理,得到的形状为
的张量也更容易理解。
torch.stack(tensors, dim = 0)
使用个 torch.cat 函数一样同样需要一些约束,这也是在使用 torch.stack(tensors, dim = 0)
函数时需要注意的地方。
import torch
# 模拟图像张量A
a = torch.randint(0, 255, (3, 32, 32))
# 模拟图像张量B
b = torch.randn(3, 32, 32)
print(a.dtype)
# torch.int64
print(b.dtype)
# torch.float32
# 非法堆叠,张量A和B的数据类型不相同
stack_ab = torch.stack([a, b], dim = 0)
print(stack_ab.size())
# Traceback (most recent call last):
# File "/home/chenkc/code/pytorch/test01.py", line 12, in <module>
# stack_ab = torch.stack([a, b], dim = 0)
# RuntimeError: Expected object of scalar type long int but got scalar type float for sequence element 1.
torch.stack 也需要满足张量堆叠合并的条件,它需要所有待合并的张量形状完全一致才可以进行合并。如果待合并张量的形状不一致时,进行堆叠合并会发生错误。
import torch
# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(1, 32, 32)
# 非法堆叠操作,张量的形状不相同
stack_ab = torch.stack([a, b], dim = 0)
print(stack_ab.size())
# Traceback (most recent call last):
# File "/home/chenkc/code/pytorch/test01.py", line 9, in <module>
# stack_ab = torch.stack([a, b], dim = 0)
# RuntimeError: stack expects each tensor to be equal size, but got [3, 32, 32] at entry 0 and [1, 32, 32] at entry 1
References: