在学习pytorch的官方文档时,发现掩码的程序贴错了,自己写了一个,大家可以参考。
torch.masked_select(input, mask, out=None) → Tensor
根据掩码张量mask
中的二元值,取输入张量中的指定项( mask
为一个 ByteTensor),将取值返回到一个新的1D张量,
张量 mask
须跟input
张量有相同数量的元素数目,但形状或维度不需要相同。
注意: 返回的张量不与原始张量共享内存空间。
参数:
torch.masked_select(x,mask)