前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-分割split函数

PyTorch入门笔记-分割split函数

作者头像
触摸壹缕阳光
发布2021-02-26 16:25:28
7.4K0
发布2021-02-26 16:25:28
举报

split

torch.split(input, split_size_or_sections, dim = 0) 函数会将输入张量(input)沿着指定维度(dim)分割成特定数量的张量块,并返回元素为张量块的元素。简单来说,可以将 torch.split 函数看成是 torch.chunk 函数的进阶版,因为 torch.split 不仅能够指定块数均匀分割(torch.chunk 只能指定块数均匀分割),而且能够指定分割每一块的长度。 torch.split 函数有三个参数:

  • tensor(Tensor)- 待分割的输入张量,此处的 tensor 参数和 torch.chunk 函数中的 input 参数类似,只需要注意使用关键字参数时候的参数名
  • split_size_or_sections(int)or(list(int))参数
    • 指定为 int 时,和 torch.chunk(input, chunks, dim = 0) 函数中的 chunks 参数功能一样;
    • 指定为 list(int) 时,list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定;
  • dim(int)- 进行分割的维度

torch.split 函数一共有两种分割形式,而这两种分割形式是由传入 split_size_or_sections 参数的类型所决定的。

指定为 int 时

当传入 torch.split 函数中的 split_size_or_sections 参数为整型时(int),torch.split 函数和 torch.chunk 函数所实现的功能一样,torch.split 函数中的 split_size_or_sections 参数和 torch.chunk 函数中的 chunks 参数等价。

“简单回顾上一小节介绍的 torch.chunk: 使用 torch.chunk 函数沿着 dim 维度将张量均匀的分割成 chunks 块,若式子

\frac{input.size(dim)}{chunks}

结果为:

  • 整数(整除),表示能够将其均匀的分割成 chunks 块,直接进行分割即可;
  • 浮点数(不能够整除),先按每块
\lceil \frac{input.size(dim)}{chunks} \rceil

\lceil \ \rceil

为向上取整)进行分割,余下的作为最后一块;

比如,将形状为

[2, 3]

的张量

B

,现在沿着第 1 个维度均匀的分割成 2 块。B.size(1) = 3、chunks = 2,即:

\frac{input.size(dim)}{chunks} = \frac{B.size(3)}{chunks}
=\frac{3}{2} = 1.5

1.5 不是整数,则将其向上取整

\lceil 1.5 \rceil = 2

,先将 3 按每块 2 个进行分割,余下的作为最后一块。

代码语言:javascript
复制
import torch

B = torch.arange(6).reshape(2, 3)

# 使用torch.chunk函数
result_chunk = torch.chunk(input = B,
                     chunks = 2,
                     dim = 1)

# 使用torch.split函数
result_split = torch.split(tensor = B,
                       split_size_or_sections = 2,
                       dim = 1)

print(B)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

print(result_chunk)
# (tensor([[0, 1],
#          [3, 4]]), 
#  tensor([[2],
#          [5]]))

print(result_split)
# (tensor([[0, 1],
#          [3, 4]]), 
#  tensor([[2],
#          [5]]))

实验结果显示,当传入 torch.split 函数中的 split_size_or_sections 参数为整型时(int),torch.split 和 torch.chunk 两个函数完全一样。

指定为 list 时

当传入 torch.split 函数中的参数 split_size_or_sections 为列表类型时(具体来说应该是元素为 int 整型的 list 列表),list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定。

比如,将形状为

[2, 3]

的张量

B

,现在沿着第 1 个维度分割成 2 块,第一块长度为 1,而第二块长度为 2。使用 torch.split 函数,只需要为 split_size_or_sections 参数传入 [1, 2] 列表即可。

代码语言:javascript
复制
import torch

B = torch.arange(6).reshape(2, 3)

result = torch.split(tensor = B,
                 split_size_or_sections = [1, 2],
                 dim = 1)

print(B)
# tensor([[0, 1, 2],
#       [3, 4, 5]])

print(result)
# (tensor([[0],
#          [3]]), 
#  tensor([[1, 2],
#          [4, 5]]))

传入 split_size_or_sections 参数的 list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定,因此待分割维度的长度应该等于 list 中所有整型元素之和,否则程序会报错。

代码语言:javascript
复制
import torch

B = torch.arange(6).reshape(2, 3)

result = torch.split(tensor = B,
                 split_size_or_sections = [1, 4],
                 dim = 1)

'''
return super(Tensor, self).split_with_sizes(split_size, dim)
RuntimeError: start (1) + length (4) exceeds dimension size (3).
'''
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-02-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • split
    • 指定为 int 时
      • 指定为 list 时
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档