我正在寻找一种方法在tensorflow中有一个条件打印节点,使用下面的样例代码行,其中每10个循环计数,它应该在控制台中打印一些东西。但它对我不起作用。有人能给点建议吗?
谢谢,哈米德雷扎
epsilon = tf.cond(tf.constant(counter % 10 == 0, dtype=tf.bool), true_fn=lambda:tf.Print(epsilon, [counter, epsilon], 'batch: ', summarize=10), false_fn=lambda:epsilon)发布于 2020-01-16 19:58:36
我遇到了一个类似的问题,并用下面这个丑陋的解决方案解决了它。
我使用的是tf.print而不是'tf.Print‘,因为它是解密的:
因为true_func和false_func应该返回相同的类型和形状,所以我返回一个没有意义的常量。
def true_func():
    printfunc = tf.print(TENSOR_TO_PRINT,summarize=-1)
    with tf.control_dependencies([printfunc]):
        return tf.constant(1) 
def false_func():
    return tf.constant(1)
shouldprint = tf.cond(YOUR_CONDITION ,true_fn = true_func ,false_fn=false_func)然后,您可以运行shouldprint操作或将其添加为其他操作的依赖项。
https://stackoverflow.com/questions/50830960
复制相似问题