首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >为TensorFlow中的每一行获取具有不同掩码的掩码argmax

为TensorFlow中的每一行获取具有不同掩码的掩码argmax
EN

Stack Overflow用户
提问于 2017-08-06 00:08:42
回答 2查看 1.5K关注 0票数 1

我有一个形状为Nx7的张量,它看起来像这样:

代码语言:javascript
复制
[0.97863993   0.64479575 -0.202357    0.94678476  0.0080051   0.44507797 0.47864
 0.05914348  -0.72649432  0.193803    0.47295245  0.8381458   0.30449861 0.46783]

我有另一个相同形状的张量,它是一个布尔掩码:

代码语言:javascript
复制
[True  False True  True  False True  False
 False True  False False True  False False]

我想要得到第一个张量中每一行的argmax,但只得到掩码为True的那些元素的argmax,所以基本上是以下数组的argmax:

代码语言:javascript
复制
[0.97863993  X          -0.202357    0.94678476  X           0.44507797 X
 X          -0.72649432  X           X           0.8381458   X          X]

因此,它应该成为:

代码语言:javascript
复制
[0
 4]

这在TensorFlow中是可能的吗?我正在尝试使用tf.boolean_mask解决这个问题,但我不知道如何处理掩码中具有不同数量的True值的不同行。

TF格式的输入代码:

代码语言:javascript
复制
mask = tf.placeholder(shape=[None, 7], dtype=tf.bool)
val = tf.placeholder(shape=[None, 7], dtype=tf.float32)

arg_max = ???

请注意,我还希望正确处理负值(否则,Ishant Mrinal提出的方法将有效)。

EN

回答 2

Stack Overflow用户

发布于 2017-08-06 00:25:57

将布尔数组转换为浮点数组

代码语言:javascript
复制
# mask = tf.placeholder(shape=[None, 7], dtype=tf.bool)
# mask = tf.cast(mask, dtype=tf.float32)
mask = tf.placeholder(shape=[None, 7], dtype=tf.float32)
val = tf.placeholder(shape=[None, 7], dtype=tf.float32)
argmax = tf.argmax(tf.multiply(val, mask), axis=1)
sess.run(argmax, {val: your_val_array, mask: 2*mask_bool_array.astype(float)-1 })
票数 0
EN

Stack Overflow用户

发布于 2017-08-06 00:38:57

要模拟带掩码的argmax,可以将掩码外的值设置为-inf,例如:

代码语言:javascript
复制
masked_val = tf.minimum(val, (2* tf.to_float(mask) - 1) * np.inf)
masked_arg_max = tf.argmax(masked_val, axis=1)

或者,要计算masked_val,您可以使用

代码语言:javascript
复制
masked_val = tf.where(mask, val, -tf.ones_like(val) * np.inf)

这可以说是更清晰,但可能会浪费内存。

对于一个带面具的argmin,你会做相反的事情:

代码语言:javascript
复制
masked_val = tf.maximum(val, (1 - 2* tf.to_float(mask)) * np.inf)
masked_arg_min = tf.argmin(masked_val, axis=1)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45523919

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档