前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >broadcasting(自动扩展)与合并分割(下)

broadcasting(自动扩展)与合并分割(下)

作者头像
用户6719124
发布2019-11-17 23:05:28
5110
发布2019-11-17 23:05:28
举报

为更好的理解,本节再进行类比操作

代码语言:javascript
复制
a = torch.rand([16, 8])
b = torch.rand([15, 8])
c = torch.cat([a, b], dim=0)
print('c.shape =', c.shape)

输出为

代码语言:javascript
复制
c.shape = torch.Size([31, 8])

若改用.stack

代码语言:javascript
复制
a = torch.rand([16, 8])
b = torch.rand([15, 8])
d = torch.stack([a, b], dim=0)
print('d.shape =', d.shape)

输出会报错

代码语言:javascript
复制
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 16 and 15 in dimension 1 at ..\aten\src\TH/generic/THTensor.cpp:711

因为.stack函数需要保持各个维度上一致才可以进行

下面进行拆分操作的讲解

首先介绍根据长度进行拆分的split

其API为:.split(self, split_size, dim)

举例

代码语言:javascript
复制
a = torch.rand([15, 8])
b = torch.rand([15, 8])
c = torch.stack([a, b], dim=0)
print('c.shape =', c.shape)
new_a, new_b = c.split(1, dim=0)
# 括号内的1表示长度,每组为1个长度,因此将2分为2个1
print('new_a.shape =', new_a.shape, 'new_b.shape =', new_b.shape)

输出

代码语言:javascript
复制
c.shape = torch.Size([2, 15, 8])
new_a.shape = torch.Size([1, 15, 8]) new_b.shape = torch.Size([1, 15, 8])

此外也可以分为更复杂的组,如分为.split([2, 1], dim=0),表示一组是2个长度、另一组的1个长度。因使用简单,这里不再赘述。

下面介绍按数量区分的.chunk函数

其API为:.chunk(self, chunks, dim)

举例

代码语言:javascript
复制
a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print('c.shape =', c.shape)
a1, b1 = c.chunk(2, dim=0)
print('a1.shape =', a1.shape, 'b1.shape =', b1.shape)
# 进行了2组/2个数量的除法操作

输出为

代码语言:javascript
复制
c.shape = torch.Size([2, 32, 8])
a1.shape = torch.Size([1, 32, 8]) b1.shape = torch.Size([1, 32, 8])

这两组操作较为简单,但又十分常见,一定要牢固掌握。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-10-10,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档