本节介绍pytorch中where和gather两个高阶代码。
.where (根据条件的选取,选择出源头)
API: torch.where(condition, x, y) => Tensor
在输出时选择返回x或y当中的一个,选择条件依据前面给定的condition
因此 输出的out_tensor为x(当条件为真时) 或 y(当条件不满足时)
举例
import torch
cond = torch.tensor([[0.6, 0.7], [0.8, 0.4]])
# 将cond以可能性表示,数值越接近于1,是a的可能性越大,越接近于0,是b的可能性越大
# 我们假设以是否大于0.5作为考察条件,大于为1,小于为0
# 将上述分类器构思完毕后,我们选择源头的两个元素分别设定为a和b
a = torch.tensor([[1., 1.], [1., 1.]])
b = torch.tensor([[0., 0.], [0., 0.]])
# a为大于1,b为小于1
print(torch.where(cond > 0.5, a, b))
# 输出选择结果
输出
tensor([[1., 1.],
[1., 0.]])
0.6、0.7、0.8大于0.5,输出选择为1,。0.4小于0.5,输出选择为0.
当然这里也可以使用两个for循环来输出相应的值,但会占用更高的CPU运算空间。因此这里介绍的高阶操作: torch.where(condition, x, y) 集成了上述两个连续的操作,使代码更简洁
.gather (收集,查表操作)
其API为:torch.gather(input, dim, index, out=None) => Tensor,使用时根据设定的dim信息指定收集位置。
举例具体说明:
对于Mnist数据集,假设得到的输出为一个[4, 10]矩阵,里面的最大值为0.9(当做算出的最高“可能性”),由此得到的index为[10, 1]的label,里面是类似于[1,0, 2, 3, 1...]的label,此时使用gather来进行查表操作,将label数值与之前的index对应起来。
下面以具体案例来进行讲解
import torch
prob = torch.rand(4, 10)
print(prob)
idex = prob.topk(dim=1, k=3)
# k=3 表明取最有可能的三种“结果”
print(idex[1])
# 查看index
label = torch.arange(10)+100
# 假设做的是一个大项目,里面有大量的index
print(torch.gather(label.expand(4, 10), dim=1, index=idex.idex[1]))
# 将结果表扩展成[4, 10], 在1维上进行查表,即dim=1,在[4, 3]上进行查表4次(batch=4),对每个batch查询3次,以得到3个索引
输出结果
tensor([[1, 3, 5],
[0, 7, 8],
[7, 3, 2],
[3, 7, 5]])
这个结果表明第一个图片最有可能对应于index中的第1个,第二个图片最有可能对应于index中的第0个。
tensor([[101, 103, 105],
[100, 107, 108],
[107, 103, 102],
[103, 107, 105]])
在大量数据中,第一个图片最有可能对应于index中的第101个
由实际操作来看,gather主要完成的是映射功能
本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体分享计划 ,欢迎热爱写作的你一起参与!