前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >01 笔记分享 : pytorch中max和nonzero使用

01 笔记分享 : pytorch中max和nonzero使用

作者头像
唐国梁Tommy
发布2021-05-28 17:29:20
7930
发布2021-05-28 17:29:20
举报

第一部分:torch.max()

1. 官网链接

https://pytorch.org/docs/stable/generated/torch.max.html#torch.max

2. 解释+案例

2.1 torch.max(input) → Tensor

返回input中所有元素中的最大值

案例:

a = torch.randn(1, 3)

a

输出:

tensor([[ 0.0557, -0.7400, -0.8941]])

torch.max(a)

tensor(0.0557)

a.max()

tensor(0.0557)

2.2 torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)

返回一个namedtuple(values, indices), values是在指定dim下,input中每行的最大值;indices是最大值所在索引。

如果keepdim=True,output的size与input size保持一致(此种情况除外: input的维度为1,即dim=1)

如果keepdim=False,dim会被torch.squeeze缩小/压缩,导致output tensors的维度为1,少于input的维度。

案例1:

a = torch.randn(4,4 ) # 随机生成4行4列的数据

a

输出:

tensor([[ 1.1982, -0.2496, -0.3671, -1.2475], [-1.6641, 0.6409, 0.9440, -0.1829], [ 0.9641, -0.1747, -1.1281, 0.4016], [ 0.3706, 0.8722, -1.1174, -0.5317]])

torch.max(a, dim=0) # 横轴方向

输出:

torch.return_types.max( values=tensor([1.1982, 0.8722, 0.9440, 0.4016]), indices=tensor([0, 3, 1, 2]))

torch.max(a, dim=0, keepdim=True) # 横轴方向,keepdim=True

输出:

torch.return_types.max( values=tensor([[1.1982, 0.8722, 0.9440, 0.4016]]), indices=tensor([[0, 3, 1, 2]]))

torch.max(a, dim=1) # 纵轴方向

输出:

torch.return_types.max( values=tensor([1.1982, 0.9440, 0.9641, 0.8722]), indices=tensor([0, 2, 0, 1]))

torch.max(a, dim=1, keepdim=True) # 纵轴方向,keepdim=True

输出:

torch.return_types.max( values=tensor([[1.1982], [0.9440], [0.9641], [0.8722]]), indices=tensor([[0], [2], [0], [1]]))

第二部分:torch.nonzero()

1. 官网参考链接 :

https://pytorch.org/docs/stable/generated/torch.nonzero.html#torch-nonzero

2. 方法

torch.nonzero(input, *, out=None, as_tuple=False)

-> LongTensor or tuple of LongTensors

返回值: 默认返回一个2-D的tensor,包含非零值的索引。

官网解释:

torch.nonzero(..., as_tuple=False) (default) returns a 2-D tensor where each row is the index for a nonzero value.

torch.nonzero(..., as_tuple=True) returns a tuple of 1-D index tensors, allowing for advanced indexing, so x[x.nonzero(as_tuple=True)] gives all nonzero values of tensor x. Of the returned tuple, each index tensor contains nonzero indices for a certain dimension.

案例1:

torch.nonzero(torch.tensor([1,1,1,0,1]))

输出:

tensor([[0], [1], [2], [4]])

torch.nonzero(torch.tensor([1,1,1,0,1]), as_tuple=True) # 设置as_tuple

输出:

(tensor([0, 1, 2, 4]),)

案例2:

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],

[0.0, 0.0, 0.4, 0.0],

[0.0, 0.0, 1.2, 0.0],

[0.0, 0.0, 0.0,-0.4]]))

输出:

tensor([[0, 0], [1, 2], [2, 2], [3, 3]])

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],

[0.0, 0.0, 0.4, 0.0],

[0.0, 0.0, 1.2, 0.0],

[0.0, 0.0, 0.0,-0.4]]), as_tuple=True) # 设置as_tuple

# 返回值: 指定行索引,包含非零值

输出:

(tensor([0, 1, 2, 3]), tensor([0, 2, 2, 3]))

案例3:

t = torch.tensor([[0.6, 1.5, 2.3, 3.7],

[2.4, 0.0, 0.4, 1.8],

[5.1, 0.0, 1.2, 3.4],

[6.3, 4.8, 0.0,-0.4]])

t

输出:

tensor([[ 0.6000, 1.5000, 2.3000, 3.7000], [ 2.4000, 0.0000, 0.4000, 1.8000], [ 5.1000, 0.0000, 1.2000, 3.4000], [ 6.3000, 4.8000, 0.0000, -0.4000]])

t[:, 2:] > 0

输出:

tensor([[ True, True], [ True, True], [ True, True], [False, False]])

(t[:, 2:] > 0).nonzero(as_tuple=False)

# 返回一个2-D的tensor,包含非零值的索引

输出:

tensor([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]])

(t[:, 2:] > 0).nonzero(as_tuple=False).T # 转置

输出:

tensor([[0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 0, 1]])

i, j = (t[:, 2:] > 0).nonzero(as_tuple=False).T

输出:

i # 行索引

tensor([0, 0, 1, 1, 2, 2])

j # 列索引

tensor([0, 1, 0, 1, 0, 1])

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-05-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 唐国梁Tommy 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档