在tensorflow 2.0中,我有两个整数张量(tf.uint8),我们称它们为A和B。张量A的秩是任意的,而B是一元的。我正在寻找的结果是得到一个bools张量C (tf.bool),这样:
(对于该示例,假设A为3级)
中存在Ai,j,k
(i,j,k是这里使用的指数,只是为了澄清这个概念)
总结而言,我需要检查A的元素是否在B中,并创建一个掩码(C),说明A的哪些元素在B中,哪些元素不在B中。
可视化示例(实际上不是代码,只是研究行为的可视化表示):
A = [[1,2,3],
[4,5,6]]
B = [1,5]
C = [[True, False, False],
[False, True, False]]
发布于 2019-12-26 22:33:21
你可以做下面的事。我找不到一种用矢量化的方式来解决这个问题的方法,因为你想让它在任意大小的A上工作,但是只要B不是很长,它就会工作得很好。
A = tf.constant([[1,2,3],[4,5,6]])
B = tf.constant([1,5])
C = tf.math.greater(tf.reduce_sum(tf.map_fn(lambda b: tf.cast(tf.math.equal(A,b), tf.int32), B), axis=0),0)
https://stackoverflow.com/questions/59492964
复制相似问题