我在Tensorflow 1.9.0中定义了一个自定义损失函数(由于项目限制无法升级)。我有以下变量,经过特征值分解后得到:
# eigw.shape = (?, x)
# eigv.shape = (?, x, y)现在,我想计算eigw的argmax,这样
amax = tf.argmax(eigw, axis=1, output_type=tf.int32)
# amax.shape = (?,)我想用amax中给出的值来索引eigv,这样
# result.shape = (?, y)我该如何做到这一点?我尝试直接访问它,但这样做我遇到了形状不具有相同排名的问题。另外,我尝试过使用tf.while_loop,但我是tf的新手,因此我没有成功。
我还有其他选择吗?怎样才能最容易地解决这个问题呢?
谢谢
发布于 2020-02-22 01:26:23
在您的特定情况下,您可以使用任何TensorFlow函数来收集轴上的最大值,而不是索引。
max_value = tf.math.reduce_max(eigw, axis=1)您可以在文档中看到任何其他参数。由于没有更多关于tesnorlfow.org的TF1.9文档,我可以找到r1.15,它仍然使用静态图形。https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/reduce_max
https://stackoverflow.com/questions/60343199
复制相似问题