
tf.scatter_nd()

狼啸风云
Modified 2022-09-03 19:32:09

Scatter updates into a new tensor according to indices.

```python
tf.scatter_nd(
    indices,
    updates,
    shape,
    name=None
)
```

Creates a new tensor by applying sparse updates to individual values or slices within a tensor of the given shape, according to indices. The tensor starts out as zeros for numeric types and as empty strings for string types. This operator is the inverse of the tf.gather_nd operator, which extracts values or slices from a given tensor. The inverse relationship holds exactly when indices contains no duplicates.
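The inverse relationship can be sketched in plain Python. The sketch is restricted to the element case, where each index vector is as long as the output rank, and to unique indices. The helpers `zeros`, `scatter_elements` and `gather_elements` are hypothetical illustrations, not TensorFlow's implementation.

```python
def zeros(shape):
    """Build a nested list of zeros with the given shape (list of ints)."""
    if len(shape) == 1:
        return [0] * shape[0]
    return [zeros(shape[1:]) for _ in range(shape[0])]

def scatter_elements(indices, updates, shape):
    """Hypothetical element-wise scatter: each index vector has len(shape) entries."""
    out = zeros(shape)
    for idx, value in zip(indices, updates):
        target = out
        for i in idx[:-1]:
            target = target[i]
        target[idx[-1]] += value  # duplicate indices would accumulate here
    return out

def gather_elements(params, indices):
    """Hypothetical element-wise gather: read params at each full index vector."""
    result = []
    for idx in indices:
        value = params
        for i in idx:
            value = value[i]
        result.append(value)
    return result

indices = [[0, 1], [2, 0], [1, 2]]
updates = [5, 7, 9]
out = scatter_elements(indices, updates, [3, 3])
print(out)                            # [[0, 5, 0], [0, 0, 9], [7, 0, 0]]
print(gather_elements(out, indices))  # [5, 7, 9]
```

Gathering at the same indices recovers updates. With a repeated index the round trip would break, because the scatter sums the repeated updates into one position.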

This operation is similar to tf.tensor_scatter_nd_add (named tensor_scatter_add in older releases), except that the tensor is zero-initialized. Calling tf.scatter_nd(indices, updates, shape) is identical to tf.tensor_scatter_nd_add(tf.zeros(shape, updates.dtype), indices, updates).

If indices contains duplicates, then their updates are accumulated (summed).
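The following rank-1 sketch in plain Python illustrates the two statements above. First, scatter_nd behaves like a scatter-add into a tensor of zeros. Second, updates at duplicate indices are summed. The function names are hypothetical stand-ins, not TensorFlow APIs, and only index vectors of length 1 are handled.

```python
def tensor_scatter_add_1d(tensor, indices, updates):
    """Hypothetical rank-1 stand-in for tf.tensor_scatter_nd_add.

    Adds each update into a copy of `tensor` at the given position.
    """
    out = list(tensor)
    for (i,), u in zip(indices, updates):
        out[i] += u
    return out

def scatter_nd_1d(indices, updates, size):
    """Hypothetical rank-1 stand-in for tf.scatter_nd: scatter-add into zeros."""
    return tensor_scatter_add_1d([0] * size, indices, updates)

# Index 1 appears twice, so its two updates are summed: 10 + 20 = 30.
result = scatter_nd_1d([[1], [3], [1]], [10, 5, 20], 5)
print(result)  # [0, 30, 0, 5, 0]
```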

WARNING: The order in which updates are applied is nondeterministic, so the output will be nondeterministic if indices contains duplicates. Floating-point addition is not associative, so numbers summed in a different order may yield slightly different results.
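The numerical issue behind this warning can be shown with plain Python floats. No TensorFlow is involved. Adding the same three values in two different groupings gives results that differ in the last bit, which is what can happen when duplicate updates are accumulated in varying order.

```python
# Floating-point addition is not associative, so summing the same
# duplicate updates in a different order can change the last bits.
a, b, c = 0.1, 0.2, 0.3
left_to_right = (a + b) + c  # 0.6000000000000001
right_to_left = a + (b + c)  # 0.6
print(left_to_right == right_to_left)  # False
```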

indices is an integer tensor containing indices into a new tensor of shape shape. The last dimension of indices can be at most the rank of shape:

```
indices.shape[-1] <= shape.rank
```

The last dimension of indices holds the index vectors. Each index vector addresses one of two things in the output:

- a single element, if indices.shape[-1] = shape.rank;
- a whole slice, if indices.shape[-1] < shape.rank. The slice spans the remaining dimensions shape[indices.shape[-1]:].

updates is a tensor with shape

```
indices.shape[:-1] + shape[indices.shape[-1]:]
```
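The shape rule above can be written as a small plain-Python helper. The helper is hypothetical and not part of TensorFlow. It takes the shapes as lists of ints and returns the shape that updates must have, checking the rank constraint first. The two calls correspond to the element example and the slice example that follow.

```python
def expected_updates_shape(indices_shape, shape):
    """Return the updates shape required by the rule above (hypothetical helper).

    indices_shape and shape are plain lists of ints.
    """
    index_depth = indices_shape[-1]
    if index_depth > len(shape):
        raise ValueError("indices.shape[-1] must be <= rank of shape")
    # Batch dimensions of indices, followed by the trailing dimensions of
    # the output that each index vector leaves unaddressed (the slice shape).
    return list(indices_shape[:-1]) + list(shape[index_depth:])

print(expected_updates_shape([4, 1], [8]))        # [4]
print(expected_updates_shape([2, 1], [4, 4, 4]))  # [2, 4, 4]
```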

The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.

In Python, this scatter operation would look like this:

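```python
import tensorflow as tf

# Four index vectors of length 1: each addresses one element of the rank-1 output.
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
# Positions not named in indices stay zero.
scatter = tf.scatter_nd(indices, updates, shape)
# Eager execution (TF2); TF1 code evaluated this with tf.Session().run(scatter).
print(scatter.numpy())
```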

The resulting tensor would look like this:

```
[0, 11, 0, 10, 9, 0, 0, 12]
```

We can also insert entire slices of a higher-rank tensor all at once. For example, suppose we want to insert two slices into the first dimension of a rank-3 tensor, using two matrices of new values.

In Python, this scatter operation would look like this:

```python
import tensorflow as tf

# Each index vector has length 1, so it selects a whole 4x4 slice
# along the first dimension of the rank-3 output.
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)  # slices 1 and 3 stay zero
print(scatter.numpy())  # eager execution (TF2)
```

The resulting tensor would look like this:

```
[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
 [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
```
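Both examples can be reproduced with a small pure-Python reference sketch of the semantics described above. The sketch is not TensorFlow's kernel. It covers element updates, slice updates, summing of duplicates and, mirroring the CPU behaviour, an error on an out-of-bounds index. It assumes that indices is a non-empty list of equal-length index vectors (a 2-D indices tensor) and that updates already has the shape required by the shape rule. `scatter_nd_ref` and its helpers are hypothetical names.

```python
def _zeros(shape):
    """Nested list of zeros with the given shape."""
    if len(shape) == 1:
        return [0] * shape[0]
    return [_zeros(shape[1:]) for _ in range(shape[0])]

def _accumulate(target, value):
    """Add nested list `value` into nested list `target` element-wise, in place."""
    for i, v in enumerate(value):
        if isinstance(v, list):
            _accumulate(target[i], v)
        else:
            target[i] += v

def scatter_nd_ref(indices, updates, shape):
    """Hypothetical pure-Python sketch of tf.scatter_nd for indices of shape [N, depth]."""
    depth = len(indices[0])
    if depth > len(shape):
        raise ValueError("indices.shape[-1] must be <= rank of shape")
    out = _zeros(shape)
    for idx, upd in zip(indices, updates):
        # Bounds check, mirroring the CPU behaviour (error on a bad index).
        for axis, i in enumerate(idx):
            if not 0 <= i < shape[axis]:
                raise IndexError(f"index {idx} out of bounds for shape {shape}")
        parent = out
        for i in idx[:-1]:
            parent = parent[i]
        if depth == len(shape):
            parent[idx[-1]] += upd             # element update
        else:
            _accumulate(parent[idx[-1]], upd)  # slice update
    return out

print(scatter_nd_ref([[4], [3], [1], [7]], [9, 10, 11, 12], [8]))
# [0, 11, 0, 10, 9, 0, 0, 12]

block = [[5] * 4, [6] * 4, [7] * 4, [8] * 4]
result = scatter_nd_ref([[0], [2]], [block, block], [4, 4, 4])
print(result)  # slices 0 and 2 equal block, slices 1 and 3 are all zeros
```

The sketch adds into zeros rather than assigning. This choice makes duplicate indices sum, as the text above describes for tf.scatter_nd.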

Note that on CPU, if an out-of-bounds index is found, an error is returned. On GPU, if an out-of-bounds index is found, that index is silently ignored.
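The difference between the two devices can be modelled with a rank-1 scatter in plain Python. The `ignore_out_of_bounds` flag is a hypothetical parameter, not a tf.scatter_nd argument; TensorFlow picks the behaviour based on the device the op runs on. The sketch also does not reproduce TensorFlow's actual error type.

```python
def scatter_1d(indices, updates, size, ignore_out_of_bounds=False):
    """Hypothetical rank-1 scatter that models both documented behaviours.

    ignore_out_of_bounds=False mimics the CPU kernel (raise an error);
    ignore_out_of_bounds=True mimics the GPU kernel (skip the bad index).
    """
    out = [0] * size
    for (i,), u in zip(indices, updates):
        if not 0 <= i < size:
            if ignore_out_of_bounds:
                continue  # GPU-like: drop this update
            raise IndexError(f"index {i} is not in [0, {size})")
        out[i] += u
    return out

# Index 9 does not fit in a tensor of 4 elements; the GPU-like mode drops it.
print(scatter_1d([[1], [9]], [5, 7], 4, ignore_out_of_bounds=True))  # [0, 5, 0, 0]
```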

Args:

  • indices: A Tensor. Must be one of the following types: int32, int64. Index tensor.
  • updates: A Tensor. Updates to scatter into output.
  • shape: A Tensor. Must have the same type as indices. 1-D. The shape of the resulting tensor.
  • name: A name for the operation (optional).

Returns:

  • A Tensor. Has the same type as updates.


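This article is part of the Tencent Cloud Self-Media Sharing Program and is shared from the author's personal site/blog.
In case of infringement, please contact cloudcommunity@tencent.com for removal.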

