前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-masked_select选择函数

PyTorch入门笔记-masked_select选择函数

作者头像
触摸壹缕阳光
发布2020-12-02 12:28:55
3.7K0
发布2020-12-02 12:28:55
举报

masked_select

torch.masked_select(input,mask,out=None) 函数返回一个根据布尔掩码 (boolean mask) 索引输入张量的 1D 张量,其中布尔掩码和输入张量就是 torch.masked_select(input, mask, out = None) 函数的两个关键参数,函数的参数有:

  • input(Tensor) - 需要进行索引操作的输入张量;
  • mask(BoolTensor) - 要进行索引的布尔掩码
  • out(Tensor, optional) - 指定输出的张量。比如执行 torch.zeros([2, 2], out = tensor_a),相当于执行 tensor_a = torch.zeros([2, 2]);

「masked_select 函数最关键的参数就是布尔掩码 mask,传入 mask 参数的布尔张量通过 True 和 False (或 1 和 0) 来决定输入张量对应位置的元素是否保留,既然是一一对应的关系,这就需要传入 mask 中的布尔张量和传入 input 中的输入张量形状要相同。」 这里需要注意此时的形状相同包括显式的相等,还包括隐式的相等。

  • 显式相等非常好理解,input.size() = mask.size()
代码语言:javascript
复制
>>> import torch
>>> x = torch.randn([3, 4])
>>> print(x)

tensor([[ 1.2001,  1.2968, -0.6657, -0.6907],
        [-2.0099,  0.6249, -0.5382,  1.4458],
        [ 0.0684,  0.4118,  0.1011, -0.5684]])

>>> # 将x中的每一个元素与0.5进行比较
>>> # 当元素大于等于0.5返回True,否则返回False
>>> mask = x.ge(0.5)
>>> print(mask)

tensor([[ True,  True, False, False],
        [False,  True, False,  True],
        [False, False, False, False]])

>>> print(torch.masked_select(x, mask))

tensor([1.2001, 1.2968, 0.6249, 1.4458])
  • 隐式相等其实就是 PyTorch 中的广播机制,换句话说,传入 mask 参数的布尔张量和传入 input 参数的输入张量的形状可以不相等,但是这两个张量必须能够通过 PyTorch 中的广播机制广播成相同形状的张量;

简单回顾广播机制:广播机制 (Broadcast) 是在科学运算中经常使用的小技巧,它是一种轻量级的张量复制手段,只在逻辑层面扩展和复制张量,并不进行实际的存储复制操作,从而大大的减少了计算代价。

有了广播机制,并不是所有形状不一致的张量都能进行广播,需要满足一定的规则。比如对于两个张量来说:

  • 如果两个张量的维度不同,则将维度小的张量进行扩展,直到两个张量的维度一样;
  • 如果两个张量在对应维度上的长度相同或者其中一个张量的长度为 1,那么就说这两个张量在该维度上是相容的;
  • 如果两个张量在所有维度上都是相容的,表示这两个张量能够进行广播,否则会出错;
  • 在任何一个维度上,如果一个张量的长度为 1,另一个张量的长度大于 1,那么在该维度上,就好像是对第一个张量进行了复制;

「对于 masked_select 函数中的广播机制比较简单,因为无论在什么情况下都是需要将传入 mask 参数的布尔张量广播成与传入 input 参数中的输入张量相同的形状。简单来说,输入张量不变只对布尔张量进行广播,而广播后的形状和输入张量的形状一致。」

假如此时输入张量为:

\left[\begin{matrix} 0 & 1 \\ 2 & 3 \end{matrix} \right]

形状为 (2, 2),布尔张量为

[True, False]

,形状为 (2, )。

  • 由于只需要对布尔张量进行广播,因此只关注布尔张量,首先为布尔张量添加新的维度,最终两个张量的维度都是 2;
  • 由于布尔张量的第一个维度上的长度和输入张量第一个维度上的长度相等,因此第一个维度相容。布尔张量的第二个维度上的长度为 1,同样在第二个维度上也相同;
  • 布尔张量的两个维度上都是相容的,因此布尔张量可以进行广播;
  • 在布尔张量的第二个维度上进行复制,最终的布尔张量为:
\left [\begin{matrix} True & False \\ True & False \end{matrix} \right]

将输入张量和广播后的布尔张量一一对应,通过 True 和 False 决定是否筛选出该元素,最终筛选出来的元素为 0 和 2,由于使用 masked_select 函数返回的都是 1D 张量,因此最终的结果为 tensor([0, 2])。

代码语言:javascript
复制
>>> import torch
>>> # mask.size() ≠ x.size()但是mask能够广播成x.size()的形状
>>> x = torch.arange(4).view([2, 2])
>>> mask2 = torch.tensor([True, False])
>>> print(torch.masked_select(x, mask2))

tensor([0, 2])

>>> # mask.size() ≠ x.size()并且mask不符合广播规则
>>> mask3 = torch.tensor([True, False, True, False])
>>> print(torch.masked_select(x, mask3))

Traceback (most recent call last):
  File "/home/chenkc/code/masked_select.py", line 100, in <module>
    print(torch.masked_select(x, mask3))
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1

masked_select 函数虽然简单,但是有几点需要注意:

  • 使用 masked_select 函数返回的结果都是 1D 张量,张量中的元素就是被筛选出来的元素值;
  • 传入 input 参数中的输入张量和传入 mask 参数中的布尔张量形状可以不一致,但是布尔张量必须要能够通过广播机制扩展成和输入张量相同的形状;
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-11-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • masked_select
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档