首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >对top_k输出使用scatter_nd

对top_k输出使用scatter_nd
EN

Stack Overflow用户
提问于 2018-10-26 03:11:20
回答 2查看 422关注 0票数 2

我一直在尝试做一些看似简单的事情,但没有成功。我有一个(?,4)张量,其中每一行将是01之间的4个浮点数。我想用一个新的张量替换它,其中每一行只有前2个条目,其他地方都是零。

使用(2, 4)的示例

代码语言:javascript
复制
source = [ [0.1, 0.2, 0.5, 0.6],
           [0.8, 0.7, 0.2, 0.1] ]

result = [ [0.0, 0.0, 0.5, 0.6],
           [0.8, 0.7, 0.0, 0.0] ]

我尝试在源代码上使用 top_k ,然后使用 scatter_nd 和top_k返回的索引,但在scatter_nd中确实有4个小时的形状和排名错误。

我已经准备好放弃了,但我想我应该先在这里寻求帮助。我在这里发现了几个密切相关的问题,但我无法概括我的案例中的信息。

我刚刚尝试的另一种方法是:

代码语言:javascript
复制
tensor = tf.constant( [ [0.1, 0.2, 0.8], [0.1, 0.2, 0.7] ])
values, indices = tf.nn.top_k(tensor, 1)
elems = (tensor, values)
masked_a = tf.map_fn( 
           lambda a : tf.where( tf.greater_equal(a[0], a[1]), a[0], 
           tf.zeros_like(a[0]) ), 
           elems)

但这一条给出了以下错误:

代码语言:javascript
复制
ValueError: The two structures don't have the same number of elements.
First structure (2 elements): (tf.float32, tf.float32)
Second structure (1 elements): Tensor("map/while/Select:0", shape=(3,), dtype=float32)

我是TensorFlow的新手,所以如果我错过了一些简单的或不清楚的东西,很抱歉。

谢谢!

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-10-26 06:09:51

您可以通过将行索引附加到top_k返回的索引来使用tf.scatter_nd完成此操作。

代码语言:javascript
复制
import tensorflow as tf

source = tf.constant([
    [0.1, 0.2, 0.5, 0.6],
    [0.8, 0.7, 0.2, 0.1]])

# get indices of top k
k = 2
top_k, top_k_inds = tf.nn.top_k(source, k, )

# indices are only columns, we will stack 
# it so the row indice is also there and
# make tensor of row numbers ie.
# [[0, 0],
#  [1, 1],
#  ...
num_rows = tf.shape(source)[0]
row_range = tf.range(num_rows)
row_tensor = tf.tile(row_range[:,None], (1, k))

# stack along the final dimension, as this is what
# scatter_nd uses as the indices
top_k_row_col_indices = tf.stack([row_tensor, top_k_inds], axis=2)

# to mask off everything, we will multiply the top_k by
# 1. so all the updates are just 1
updates = tf.ones([num_rows, k], dtype=tf.float32)

# build the mask
zero_mask = tf.scatter_nd(top_k_row_col_indices, updates, [num_rows, 4])

with tf.Session() as sess:
    zeroed = source*zero_mask
    print(zeroed.eval())

这应该会打印出来

代码语言:javascript
复制
[[0.  0.  0.5 0.6]
[0.8 0.7 0.  0. ]]
票数 1
EN

Stack Overflow用户

发布于 2019-04-08 10:48:38

只需粘贴几行代码:)

代码语言:javascript
复制
import tensorflow as tf

def attach_indice(tensor, top_k = None):
    flatty = tf.reshape(tensor, [-1])
    orig_shape = tf.shape(tensor)
    length = tf.shape(flatty)[0]
    if top_k is not None:
        orig_shape = orig_shape[:-1] # dim for top_k
        length //= top_k
    indice = tf.unravel_index(tf.range(length), orig_shape)
    indice = tf.transpose(indice)
    if indice.dtype != tensor.dtype:
        indice = tf.cast(indice, tensor.dtype)
    if top_k is not None:
        _dims = len(tensor.shape) - 1 # indice of indice
        shape = [1 for _ in range(_dims)]
        shape[-1] *= top_k
        indice = tf.reshape(tf.tile(indice, shape), [-1, _dims])
    return tf.concat([indice, flatty[:, None]], -1)

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# tf.enable_eager_execution()

from time import time

top_k = 3
shape = [50, 40, 100]

q = tf.random_uniform(shape)

# fast: 4.376221179962158 (GPU) / 2.483684778213501 (CPU)
v, k = tf.nn.top_k(q, top_k)
k = attach_indice(k, top_k)
s = tf.scatter_nd(k, tf.reshape(v, [-1]), shape)

# very slow: 281.82796931266785 (GPU) / 35.163344860076904 (CPU)
# s = tf.map_fn(lambda v__k__: tf.map_fn(lambda v_k_: tf.scatter_nd(v_k_[1][:, None], v_k_[0], [shape[-1]]), v__k__, q.dtype), tf.nn.top_k(q, top_k), q.dtype)

start = time()
with tf.Session() as sess:
    for _ in range(1000):
        sess.run(s)
print('time', time() - start)
票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52996505

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档