前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch基础知识-高阶代码

pytorch基础知识-高阶代码

作者头像
用户6719124
发布2019-11-17 22:47:24
8590
发布2019-11-17 22:47:24
举报

本节介绍pytorch中where和gather两个高阶代码。

.where (根据条件的选取,选择出源头)

API: torch.where(condition, x, y) => Tensor

在输出时选择返回x或y当中的一个,选择条件依据前面给定的condition

因此 输出的out_tensor为x(当条件为真时) 或 y(当条件不满足时)

举例

代码语言:javascript
复制
import torch
cond = torch.tensor([[0.6, 0.7], [0.8, 0.4]])
# 将cond以可能性表示,数值越接近于1,是a的可能性越大,越接近于0,是b的可能性越大
# 我们假设以是否大于0.5作为考察条件,大于为1,小于为0
# 将上述分类器构思完毕后,我们选择源头的两个元素分别设定为a和b
a = torch.tensor([[1., 1.], [1., 1.]])
b = torch.tensor([[0., 0.], [0., 0.]])
# a为大于1,b为小于1
print(torch.where(cond > 0.5, a, b))
# 输出选择结果

输出

代码语言:javascript
复制
tensor([[1., 1.],
        [1., 0.]])

0.6、0.7、0.8大于0.5,输出选择为1,。0.4小于0.5,输出选择为0.

当然这里也可以使用两个for循环来输出相应的值,但会占用更高的CPU运算空间。因此这里介绍的高阶操作: torch.where(condition, x, y) 集成了上述两个连续的操作,使代码更简洁

.gather (收集,查表操作)

其API为:torch.gather(input, dim, index, out=None) => Tensor,使用时根据设定的dim信息指定收集位置。

举例具体说明:

对于Mnist数据集,假设得到的输出为一个[4, 10]矩阵,里面的最大值为0.9(当做算出的最高“可能性”),由此得到的index为[10, 1]的label,里面是类似于[1,0, 2, 3, 1...]的label,此时使用gather来进行查表操作,将label数值与之前的index对应起来。

下面以具体案例来进行讲解

代码语言:javascript
复制
import torch
prob = torch.rand(4, 10)
print(prob)
idex = prob.topk(dim=1, k=3)
# k=3 表明取最有可能的三种“结果”
print(idex[1])
# 查看index
label = torch.arange(10)+100
# 假设做的是一个大项目,里面有大量的index
print(torch.gather(label.expand(4, 10), dim=1, index=idex.idex[1]))
# 将结果表扩展成[4, 10], 在1维上进行查表,即dim=1,在[4, 3]上进行查表4次(batch=4),对每个batch查询3次,以得到3个索引

输出结果

代码语言:javascript
复制
tensor([[1, 3, 5],
        [0, 7, 8],
        [7, 3, 2],
        [3, 7, 5]])

这个结果表明第一个图片最有可能对应于index中的第1个,第二个图片最有可能对应于index中的第0个。

tensor([[101, 103, 105],

[100, 107, 108],

[107, 103, 102],

[103, 107, 105]])

在大量数据中,第一个图片最有可能对应于index中的第101个

由实际操作来看,gather主要完成的是映射功能

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-10-15,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档