前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-分割chunk函数

PyTorch入门笔记-分割chunk函数

作者头像
触摸壹缕阳光
发布2021-02-26 16:25:06
5.9K0
发布2021-02-26 16:25:06
举报

chunk

torch.chunk(input, chunks, dim = 0) 函数会将输入张量(input)沿着指定维度(dim)均匀的分割成特定数量的张量块(chunks),并返回元素为张量块的元组。torch.chunk 函数有三个参数:

  • input(Tensor)- 待分割的输入张量
  • chunks(int)- 均匀分割张量块的数量
  • dim(int)- 进行分割的维度

以包含批量维度的图像张量为例,设张量

A

保存了 128 张,长和宽为 32 的三通道像素矩阵,则张量

A

的形状为

[128, 3, 32, 32]

(PyTorch将通道维度放在前面,即

(batch\_{size}, channels, height, width)

)。

现在我们想将张量

A

这 128 张图片均匀的分割成 16 块,每块包含 8 张图片。可以使用 torch.chunk 函数沿着第 0 个维度(批量维度,dim = 0)均匀的将张量

A

(input = A)分割成 16 块(chunks = 16)。

代码语言:javascript
复制
import torch

A = torch.randint(0, 255, (128, 3, 32, 32))

result = torch.chunk(input=A,
                     chunks=16,
                     dim=0)

print(type(result))
# <class 'tuple'>

print(len(result))
# 16

print(type(result[0]))
# <class 'torch.Tensor'>

print(result[0].size())
# torch.Size([8, 3, 32, 32])

将形状为

[128, 3, 32, 32]

的张量

A

,沿着第 0 个维度(批量维度)均匀分割成 16 块(

\frac{128}{16} = 8

),其中每一块都是形状为

[8, 3, 32, 32]

的张量。

如果将将张量

A

这 128 张图片均匀的分割成 14 块(

\frac{128}{14} \approx 9.14

),显然不能像分割成 16 块那样能够均匀的分割。在这种情况下,torch.chunk 函数会先按照每块 10 张图片进行分割,即每一块都是形状为

[10, 3, 32, 32]

的张量,余下的作为最后一块。

代码语言:javascript
复制
import torch

A = torch.randint(0, 255, (128, 3, 32, 32))

result = torch.chunk(input=A,
                     chunks=14,
                     dim=0)


print(len(result))
# 13

print(result[0].size())
# torch.Size([10, 3, 32, 32])

print(result[-1].size())
# torch.Size([8, 3, 32, 32])

小结

可以沿着输入张量的任意维度均匀分割。使用 torch.chunk 函数沿着 dim 维度将张量均匀的分割成 chunks 块,若式子

\frac{input.size(dim)}{chunks}

结果为:

  • 整数(整除),表示能够将其均匀的分割成 chunks 块,直接进行分割即可;
  • 浮点数(不能够整除),先按每块
\lceil \frac{input.size(dim)}{chunks} \rceil

\lceil \ \rceil

为向上取整)进行分割,余下的作为最后一块;

比如,将形状为

[2, 3]

的张量

B

,现在沿着第 1 个维度均匀的分割成 2 块。B.size(1) = 3、chunks = 2,即:

\frac{input.size(dim)}{chunks} = \frac{B.size(3)}{chunks}
=\frac{3}{2} = 1.5

1.5 不是整数,则将其向上取整

\lceil 1.5 \rceil = 2

,先将 3 按每块 2 个进行分割,余下的作为最后一块。

代码语言:javascript
复制
import torch

B = torch.arange(6).reshape(2, 3)

result = torch.chunk(input = B,
                     chunks = 2,
                     dim = 1)
print(B)
# tensor([[0, 1, 2],
#        [3, 4, 5]]) 

print(result)
# tensor([[0, 1],
#        [3, 4]]), tensor([[2],
#        [5]]))
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-02-10,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • chunk
  • 小结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档