首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >返回稀疏张量每行的top_k项

返回稀疏张量每行的top_k项
EN

Stack Overflow用户
提问于 2020-07-21 20:48:02
回答 1查看 219关注 0票数 1

对于密集的张量,我们可以使用tf.nn.topk来找到最后一维的k个最大条目的值和索引。

对于稀疏张量,我希望有效地获得每行的前n个项目,而不会将稀疏张量转换为密集张量。

EN

回答 1

Stack Overflow用户

发布于 2020-07-22 02:16:39

这有点棘手,但这里有一些工作(假设2D稀疏张量,尽管我认为对于更多的外部维度应该是相同的)。这个想法是首先对整个稀疏张量进行排序(而不是使其变得密集),然后对第一列进行切片。要做到这一点,我需要像np.lexsort这样的东西,据我所知,TensorFlow本身并没有提供这样的东西--然而,tf.sparse.reorder实际上做了一些类似词法排序的事情,所以我做了另一个中间稀疏张量来利用这一点。

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

np.random.seed(0)
# Input data
k = 3
r = np.random.randint(10, size=(6, 8))
r[np.random.rand(*r.shape) < .5] = 0
sp = tf.sparse.from_dense(r)
print(tf.sparse.to_dense(sp).numpy())
# [[0 0 0 0 0 0 3 0]
#  [2 4 0 6 8 0 0 6]
#  [7 0 0 1 5 9 8 9]
#  [4 0 0 3 0 0 0 3]
#  [8 1 0 3 3 7 0 1]
#  [0 0 0 0 7 0 0 7]]

# List of value indices
n = tf.size(sp.values, out_type=sp.indices.dtype)
r = tf.range(n)
# Sort values
s = tf.dtypes.cast(tf.argsort(sp.values, direction='DESCENDING'), sp.indices.dtype)
# Find destination index of each sorted value
si = tf.scatter_nd(tf.expand_dims(s, 1), r, [n])
# Abuse sparse tensor functionality to do lexsort with column and destination index
sp2 = tf.sparse.SparseTensor(indices=tf.stack([sp.indices[:, 0], si], axis=1),
                             values=r,
                             dense_shape=[sp.dense_shape[0], n])
sp2 = tf.sparse.reorder(sp2)
# Build top-k result
row = sp.indices[:, 0]
# Make column indices
d = tf.dtypes.cast(row[1:] - row[:-1] > 0, r.dtype)
m = tf.pad(r[1:] * d, [[1, 0]])
col = r - tf.scan(tf.math.maximum, m)
# Get only up to k elements per row
m = col < k
row_m = tf.boolean_mask(row, m)
col_m = tf.boolean_mask(col, m)
idx_m = tf.boolean_mask(sp2.values, m)
# Make result
scatter_idx = tf.stack([row_m, col_m], axis=-1)
scatter_shape = [sp.dense_shape[0], k]
# Use -1 for rows with less than k values
# (0 is ambiguous)
values = tf.tensor_scatter_nd_update(-tf.ones(scatter_shape, sp.values.dtype),
                                     scatter_idx, tf.gather(sp.values, idx_m))
indices = tf.tensor_scatter_nd_update(-tf.ones(scatter_shape, sp.indices.dtype),
                                      scatter_idx, tf.gather(sp.indices[:, 1], idx_m))
print(values.numpy())
# [[ 3 -1 -1]
#  [ 8  6  6]
#  [ 9  9  8]
#  [ 4  3  3]
#  [ 8  7  3]
#  [ 7  7 -1]]
print(indices.numpy())
# [[ 6 -1 -1]
#  [ 4  3  7]
#  [ 5  7  6]
#  [ 0  3  7]
#  [ 0  5  3]
#  [ 4  7 -1]]

编辑:这是另一种可能性,如果你的张量在所有行中都非常稀疏,那么它可能会工作得很好。这个想法是将所有稀疏的张量值“压缩”到第一列中(就像前面已经为sp3做的代码片段一样),然后将其变成一个密集的张量,并照常应用top-k。需要注意的是,索引将引用压缩张量,因此如果您想要获得关于初始稀疏张量的正确索引,则必须采取另一步。

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

np.random.seed(0)
# Input data
k = 3
r = np.random.randint(10, size=(6, 8))
r[np.random.rand(*r.shape) < .8] = 0
sp = tf.sparse.from_dense(r)
print(tf.sparse.to_dense(sp).numpy())
# [[0 0 0 0 0 0 3 0]
#  [0 4 0 6 0 0 0 0]
#  [0 0 0 0 5 0 0 9]
#  [0 0 0 0 0 0 0 0]
#  [8 0 0 0 0 7 0 0]
#  [0 0 0 0 7 0 0 0]]

# Build "condensed" sparse tensor
n = tf.size(sp.values, out_type=sp.indices.dtype)
r = tf.range(n)
# Make indices
row = sp.indices[:, 0]
d = tf.dtypes.cast(row[1:] - row[:-1] > 0, r.dtype)
m = tf.pad(r[1:] * d, [[1, 0]])
col = r - tf.scan(tf.math.maximum, m)
# At least as many columns as k
ncols = tf.maximum(tf.math.reduce_max(col) + 1, k)
sp2 = tf.sparse.SparseTensor(indices=tf.stack([row, col], axis=1),
                             values=sp.values,
                             dense_shape=[sp.dense_shape[0], ncols])
# Get in dense form
condensed = tf.sparse.to_dense(sp2)
# Top-k (indices do not correspond to initial sparse matrix)
values, indices = tf.math.top_k(condensed, k)
print(values.numpy())
# [[3 0 0]
#  [6 4 0]
#  [9 5 0]
#  [0 0 0]
#  [8 7 0]
#  [7 0 0]]

# Now get the right indices
sp3 = tf.sparse.SparseTensor(indices=tf.stack([row, col], axis=1),
                             values=sp.indices[:, 1],
                             dense_shape=[sp.dense_shape[0], ncols])
condensed_idx = tf.sparse.to_dense(sp3)
actual_indices = tf.gather_nd(condensed_idx, tf.expand_dims(indices, axis=-1),
                              batch_dims=1)
print(actual_indices.numpy())
# [[6 0 0]
#  [3 1 0]
#  [7 4 0]
#  [0 0 0]
#  [0 5 0]
#  [4 0 0]]

不过,我不确定这样做是否会更快。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63014913

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档