前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torch.scatter_

torch.scatter_

作者头像
烤粽子
发布2021-11-03 14:21:47
6680
发布2021-11-03 14:21:47
举报

torch.scatter_

Tensor.scatter_(dim, index, src, reduce=None) → Tensor Parameters

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
  • src (Tensor or float) – the source element(s) to scatter. 要填进去的元素
  • reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'. 用的相对较少。

直接看例子,

代码语言:javascript
复制
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
# 从这个例子出发来简单说明:首先dim=0,意味着需要沿着axis=0的方向进行操作,即index每一列逐渐增大,按列找到对应的索引号,然后按顺序把src中的元素填进去。
请添加图片描述
请添加图片描述
代码语言:javascript
复制
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])
# dim=1, 按行找到对应的index,按顺序把src中的元素填进去

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),1.23)
tensor([[2.0000, 2.0000, 1.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 1.2300]])
# dim=1, 按行找到对应的index,按顺序把src中的元素填进去,不用管原来的位置是什么数字。
请添加图片描述
请添加图片描述

**注意:**index可以不用满,src按顺序填充。

代码语言:javascript
复制
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.0000*1.23, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.000*1.23]])
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
# dim=1, 按行找到对应的index,按顺序把src中的元素乘上去

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 2.0000+1.23, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.000+1.23]])
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])
# dim=1, 按行找到对应的index,按顺序把src中的元素加上去上去

Reference:

[1] TORCH.TENSOR.SCATTER_

tps://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html)

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2021-11-02 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • torch.scatter_
  • Reference:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档