前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch的python API略读--tensor(三)

pytorch的python API略读--tensor(三)

作者头像
用户9875047
发布2022-07-04 14:00:01
2120
发布2022-07-04 14:00:01
举报
文章被收录于专栏:机器视觉全栈er机器视觉全栈er

2.1.2 索引

筛选出符合某种条件的subtensor。

torch.where: 根据布尔变量的值选择tensor中的元素,用法如下:

代码语言:javascript
复制
torch.where(condition, x, y)

下面举个简单的例子:

代码语言:javascript
复制
>>> import torch
>>> cvtutorials = torch.randn(3, 4)
>>> threshold = torch.zeros(3, 4)
>>> cvtutorials
tensor([[-1.6981,  1.0443,  2.7922, -0.8736],
        [-2.0208, -0.4815, -0.1488, -0.9714],
        [ 1.1035,  0.4089,  0.6279,  2.4600]])
>>> torch.where(cvtutorials > 0, cvtutorials, threshold)
tensor([[0.0000, 1.0443, 2.7922, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [1.1035, 0.4089, 0.6279, 2.4600]])

上面torch.where函数返回tensor的某个元素的值遵循这样的选择:如果cvtutorials中的某个元素大于0,那么保留,否则设置为0,用数学公式表达如下:

torch.index_select: 沿着某个维度,通过index对输入tensor进行筛选。用法如下:

代码语言:javascript
复制
torch.index_select(input, dim, index, *, out=None)

举个例子说明下:

代码语言:javascript
复制
>>> cvtutorials = torch.randn(2,3)
>>> cvtutorials
tensor([[-0.9935, -0.9802, -0.6104],
        [ 2.6251, -1.0099,  0.4752]])
>>> indices = torch.tensor([0, 1])
>>> torch.index_select(cvtutorials, 0, indices)
tensor([[-0.9935, -0.9802, -0.6104],
        [ 2.6251, -1.0099,  0.4752]])
>>> torch.index_select(cvtutorials, 1, indices)
tensor([[-0.9935, -0.9802],
        [ 2.6251, -1.0099]])

torch.masked_select: 根据设置的mask,返回一个一维的tensor(向量)。用法如下:

代码语言:javascript
复制
torch.masked_select(input, mask, *, out=None)

举个简单的例子:

代码语言:javascript
复制
>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 1.1016, -1.5259,  1.1065],
        [ 0.4838, -0.5521,  0.1556]])
>>> mask = torch.tensor([[False, True, True], [True, False, False]])
>>> mask
tensor([[False,  True,  True],
        [ True, False, False]])
>>> torch.masked_select(cvtutorials, mask)
tensor([-1.5259,  1.1065,  0.4838])

从中可以看出,根据mask对输入tensor相应位置的元素进行筛选,mask某位置为True,则取出tensor相应位置的元素,否则,不取出。

还有一点,mask的shape不一定和tensor一样,但是需要broadcast到tensor上,例如:

代码语言:javascript
复制
>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 0.8686,  0.0910,  1.8702],
        [ 1.8140, -1.0902,  0.7051]])
>>> mask = torch.tensor([[False, True, True]])
>>> mask
tensor([[False,  True,  True]])
>>> torch.masked_select(cvtutorials, mask)
tensor([ 0.0910,  1.8702, -1.0902,  0.7051])
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-02-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器视觉全栈er 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 2.1.2 索引
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档