前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-拼接cat函数

PyTorch入门笔记-拼接cat函数

作者头像
触摸壹缕阳光
修改2021-02-01 23:24:14
5.3K0
修改2021-02-01 23:24:14
举报

合并

合并是指将多个张量在某个维度上合并成一个张量。张量的合并可以使用拼接(Concatenate)和堆叠(Stack)操作实现,其中拼接操作不会产生新的维度,仅在现有维度上合并,而堆叠会创建新的维度。选择使用拼接还是堆叠操作来合并张量,取决于具体的场景是否需要创建新维度。

拼接

在 PyTorch 中,可以通过 torch.cat(tensors, dim = 0) 函数拼接张量,其中参数 tensor 保存了所有需要合并张量的序列(任何Python的序列对象,比如列表、元组等),dim 参数指定了需要合并的维度索引。

以包含批量维度的图像张量为例,设张量 A 保存了 4 张,长和宽为 32 的三通道像素矩阵,则张量 A 的形状为 [4, 3, 32, 32](PyTorch将通道维度放在前面,即 (batch_size, channels,height, width))。同样的方式,张量 B 保存了另外 5 张,长和宽为 32 的三通道像素矩阵,张量 B 的形状为 [5, 3, 32, 32]。现在需要在批量维度上合并两个包含批量维度的图像张量,这里批量维度索引号为 0,即 dim = 0,合并张量 AB 的代码如下:

代码语言:txt
复制
import torch

# 模拟图像张量A
a = torch.randn(4, 3, 32, 32)
# 模拟图像张量B
b = torch.randn(5, 3, 32, 32)

# 在批量维度上合并张量A和B
cat_ab = torch.cat([a, b], dim = 0)
print(cat_ab.size())
# torch.Size([9, 3, 32, 32])

torch.cat(tensors, dim = 0) 使用需要一些约束,这也是在使用 torch.cat(tensors, dim = 0) 函数时需要注意的地方。

  • 参数 tensors 中所有需要合并的张量必须是相同的数据类型
代码语言:txt
复制
import torch

# 模拟图像张量A
a = torch.randint(0, 255, (4, 3, 32, 32))
# 模拟图像张量B
b = torch.randn(5, 3, 32, 32)

print(a.dtype)
# torch.int64

print(b.dtype)
# torch.float32

# 非法拼接,张量A和B的数据类型不相同
cat_ab = torch.cat([a, b], dim = 0)
print(cat_ab.size())

'''
Traceback (most recent call last):
  File "/home/chenkc/code/pytorch/cat_test.py", line 13, in <module>
    cat_ab = torch.cat([a, b], dim = 0)
RuntimeError: Expected object of scalar type long int but got scalar type float for sequence element 1.
'''
  • 非合并维度的长度必须一致

「从语法上来说,拼接合并操作可以在任意的维度上进行,唯一的约束是非合并维度的长度必须一致。」 比如图像张量形状为 [4, 3, 32, 32] 和形状为 [5, 1, 32, 32] 的张量不能直接在批量维度上进行合并,因为两个图像张量的通道维度的长度不一致,一个是单通道(channels = 1),另一个是三通道(channels = 3)。

代码语言:txt
复制
import torch

# 模拟图像张量A
a = torch.randn(4, 1, 32, 32)
# 模拟图像张量B
b = torch.randn(5, 3, 32, 32)

# 非法拼接,其它维度长度不相同
cat_ab = torch.cat([a, b], dim = 0)
print(cat_ab.size())

'''
Traceback (most recent call last):
  File "/home/chenkc/code/pytorch/repeat_function.py", line 9, in <module>
    cat_ab = torch.cat([a, b], dim = 0)
RuntimeError: Sizes of tensors must match except in dimension 0. Got 1 and 3 in dimension 1
'''

References: 1. 《TensorFlow深度学习》

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-01-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 合并
    • 拼接
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档