首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torch.gather() 和torch.sactter_()的用法简析

torch.gather() 和torch.sactter_()的用法简析

作者头像
TeeyoHuang
发布2019-05-25 22:42:06
2K0
发布2019-05-25 22:42:06
举报

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1433787

torch.gather(input, dim, index, out=None)torch.scatter_(dim, index, src)是一对作用相反的方法

先来看torch.gather, 核心操作其实就是这样:

outik = inputindex[i][j][k]k # if dim == 0

outik = inputiindexik]k # if dim == 1

outik = inputi[indexik] # if dim == 2

是对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index

官方文档给的例子是:

>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
 1  1
 4  3
[torch.FloatTensor of size 2x2]

具体过程就是这里的input = [1,2,3,4], index = [0,0,1,0], dim = 1, 则

out0 = input0 index0 ] = input0 = 1

out0 = input0 index0 ] = input0 = 1

out1 = input1 index1 ] = input1 = 4

out1 = input1 index1 ] = input1 = 3

torch.scatter_(dim, index, src)

核心操作:

self index[i][j][k] k = srcik # if dim == 0

self i indexik ] k = srcik # if dim == 1

self i [ indexik ] = srcik # if dim == 2

这个就是对于src(或者说input)指定位置上的值,去分配给output对应索引位置,根据是index,所以其实把src放在左边更容易理解,官方给的例子如下:

x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

此例中,src就是x,index就是[0, 1, 2, 0, 0, 2, 0, 0, 1, 2], dim=0

我们把src写在左边,把self写在右边,这样好理解一些,

但要注意是把src的值赋给self,所以用箭头指过去:

0.4319 = Src0 ----->self index[0][0] ----> self0

0.6500 = Src0 ----->self index[0][1] ----> self1

0.4080 = Src0 ----->self index[0][2] ----> self2

0.8760 = Src0 ----->self index[0][3] ----> self0

0.2355 = Src0 ----->self index[0][4] ----> self0

0.2609 = Src1 ----->self index[1][0] ----> self2

0.4711 = Src1 ----->self index[1][1] ----> self0

0.8486 = Src1 ----->self index[1][2] ----> self0

0.8573 = Src1 ----->self index[1][3] ----> self1

0.1029 = Src1 ----->self index[1][4] ----> self2

则我们把src也就是 x的每个值都成功的分配了出去,然后我们再把self对应位置填好

剩下的未得到分配的位置,就填0补充。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年08月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档