前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >我对torch中的gather函数的一点理解

我对torch中的gather函数的一点理解

作者头像
树枝990
发布2020-08-20 01:20:38
8800
发布2020-08-20 01:20:38
举报
文章被收录于专栏:拇指笔记拇指笔记

官方文档的解释

代码语言:javascript
复制
torch.gather(input,dim,index,out=None) → Tensor
torch.gather(input, dim, index, out=None) → Tensor
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:
    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2
    Parameters:
        input (Tensor) – The source tensor        dim (int) – The axis along which to index        index (LongTensor) – The indices of elements to gather        out (Tensor, optional) – Destination tensor
    Example:
    >>> t = torch.Tensor([[1,2],[3,4]])    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))     1  1     4  3    [torch.FloatTensor of size 2x2]

举个例子

代码语言:javascript
复制
import torch
a = torch.Tensor([[1,2],                 [3,4]])
b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]]))#1. 取各个元素行号:[(0,y)(0,y)][(1,y)(1,y)]#2. 取各个元素值做行号:[(0,0)(0,0)][(1,1)(1,0)]#3. 根据得到的索引在输入中取值#[1,1],[4,3]
c = torch.gather(a,0,torch.LongTensor([[0,0],[1,0]]))#1. 取各个元素列号:[(x,0)(x,1)][(x,0)(x,1)]#2. 取各个元素值做行号:[(0,0)(0,1)][(1,0)(0,1)]#3. 根据得到的索引在输入中取值#[1,2],[3,2]

原理解释

假设输入与上同;index=B;输出为C B中每个元素分别为b(0,0)=0,b(0,1)=0 b(1,0)=1,b(1,1)=0

如果dim=0(列) 则取B中元素的列号,如:b(0,1)的1 b(0,1)=0,所以C中的c(0,1)=输入的(0,1)处元素2

如果dim=1(行) 则取B中元素的列号,如:b(0,1)的0 b(0,1)=0,所以C中的c(0,1)=输入的(0,0)处元素1

总结如下:输出 元素 在 输入张量 中的位置为:输出元素位置取决于同位置的index元素 dim=1时,取同位置的index元素的行号做行号,该位置处index元素做列号 dim=0时,取同位置的index元素的列号做列号,该位置处index元素做行号。

最后根据得到的索引在输入中取值

index类型必须为LongTensor gather最终的输出变量与index同形。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-03-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 拇指笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 官方文档的解释
  • 举个例子
  • 原理解释
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档