为更好的理解,本节再进行类比操作
a = torch.rand([16, 8])
b = torch.rand([15, 8])
c = torch.cat([a, b], dim=0)
print('c.shape =', c.shape)
输出为
c.shape = torch.Size([31, 8])
若改用.stack
a = torch.rand([16, 8])
b = torch.rand([15, 8])
d = torch.stack([a, b], dim=0)
print('d.shape =', d.shape)
输出会报错
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 16 and 15 in dimension 1 at ..\aten\src\TH/generic/THTensor.cpp:711
因为.stack函数需要保持各个维度上一致才可以进行
下面进行拆分操作的讲解
首先介绍根据长度进行拆分的split
其API为:.split(self, split_size, dim)
举例
a = torch.rand([15, 8])
b = torch.rand([15, 8])
c = torch.stack([a, b], dim=0)
print('c.shape =', c.shape)
new_a, new_b = c.split(1, dim=0)
# 括号内的1表示长度,每组为1个长度,因此将2分为2个1
print('new_a.shape =', new_a.shape, 'new_b.shape =', new_b.shape)
输出
c.shape = torch.Size([2, 15, 8])
new_a.shape = torch.Size([1, 15, 8]) new_b.shape = torch.Size([1, 15, 8])
此外也可以分为更复杂的组,如分为.split([2, 1], dim=0),表示一组是2个长度、另一组的1个长度。因使用简单,这里不再赘述。
下面介绍按数量区分的.chunk函数
其API为:.chunk(self, chunks, dim)
举例
a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print('c.shape =', c.shape)
a1, b1 = c.chunk(2, dim=0)
print('a1.shape =', a1.shape, 'b1.shape =', b1.shape)
# 进行了2组/2个数量的除法操作
输出为
c.shape = torch.Size([2, 32, 8])
a1.shape = torch.Size([1, 32, 8]) b1.shape = torch.Size([1, 32, 8])
这两组操作较为简单,但又十分常见,一定要牢固掌握。
本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!