torch.masked_select(input,mask,out=None) 函数返回一个根据布尔掩码 (boolean mask) 索引输入张量的 1D 张量,其中布尔掩码和输入张量就是 torch.masked_select(input, mask, out = None) 函数的两个关键参数,函数的参数有:
「masked_select 函数最关键的参数就是布尔掩码 mask,传入 mask 参数的布尔张量通过 True 和 False (或 1 和 0) 来决定输入张量对应位置的元素是否保留,既然是一一对应的关系,这就需要传入 mask 中的布尔张量和传入 input 中的输入张量形状要相同。」 这里需要注意此时的形状相同包括显式的相等,还包括隐式的相等。
input.size() = mask.size()
;>>> import torch
>>> x = torch.randn([3, 4])
>>> print(x)
tensor([[ 1.2001, 1.2968, -0.6657, -0.6907],
[-2.0099, 0.6249, -0.5382, 1.4458],
[ 0.0684, 0.4118, 0.1011, -0.5684]])
>>> # 将x中的每一个元素与0.5进行比较
>>> # 当元素大于等于0.5返回True,否则返回False
>>> mask = x.ge(0.5)
>>> print(mask)
tensor([[ True, True, False, False],
[False, True, False, True],
[False, False, False, False]])
>>> print(torch.masked_select(x, mask))
tensor([1.2001, 1.2968, 0.6249, 1.4458])
简单回顾广播机制:广播机制 (Broadcast) 是在科学运算中经常使用的小技巧,它是一种轻量级的张量复制手段,只在逻辑层面扩展和复制张量,并不进行实际的存储复制操作,从而大大的减少了计算代价。
有了广播机制,并不是所有形状不一致的张量都能进行广播,需要满足一定的规则。比如对于两个张量来说:
「对于 masked_select 函数中的广播机制比较简单,因为无论在什么情况下都是需要将传入 mask 参数的布尔张量广播成与传入 input 参数中的输入张量相同的形状。简单来说,输入张量不变只对布尔张量进行广播,而广播后的形状和输入张量的形状一致。」
假如此时输入张量为:
形状为 (2, 2),布尔张量为
,形状为 (2, )。
将输入张量和广播后的布尔张量一一对应,通过 True 和 False 决定是否筛选出该元素,最终筛选出来的元素为 0 和 2,由于使用 masked_select 函数返回的都是 1D 张量,因此最终的结果为 tensor([0, 2])。
>>> import torch
>>> # mask.size() ≠ x.size()但是mask能够广播成x.size()的形状
>>> x = torch.arange(4).view([2, 2])
>>> mask2 = torch.tensor([True, False])
>>> print(torch.masked_select(x, mask2))
tensor([0, 2])
>>> # mask.size() ≠ x.size()并且mask不符合广播规则
>>> mask3 = torch.tensor([True, False, True, False])
>>> print(torch.masked_select(x, mask3))
Traceback (most recent call last):
File "/home/chenkc/code/masked_select.py", line 100, in <module>
print(torch.masked_select(x, mask3))
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1
masked_select 函数虽然简单,但是有几点需要注意: