首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >用形状(x,y)的二维布尔掩码掩蔽形状(x,y,z)的三维张量

用形状(x,y)的二维布尔掩码掩蔽形状(x,y,z)的三维张量
EN

Stack Overflow用户
提问于 2018-09-04 16:23:11
回答 1查看 56关注 0票数 3

我有以下张量,形状的output (2,6,2):

代码语言:javascript
运行
复制
[[[0.4 0.2]
  [0.7 0.5]
  [0.4 0.1]
  [0.5 0.4]
  [0.9 0.7]
  [0.2 0.9]]

 [[0.6 0.6]
  [0.3 0.5]
  [0.7 0.2]
  [0.8 0.1]
  [0.3 0.5]
  [0.4 0.7]]]

并具有以下布尔掩码张量,形状为(2,6)的mask

代码语言:javascript
运行
复制
mask = tf.sequence_mask(lengths=[3, 4] maxlen=6)

[[ True  True  True False False False]
 [ True  True  True  True False False]]

如何使用mask (或它的调整)应用masked_output = tf.boolean_mask(output, masks),结果如下:

代码语言:javascript
运行
复制
[[[0.4 0.2]
  [0.7 0.5]
  [0.4 0.1]
  [0.0 0.0]
  [0.0 0.0]
  [0.0 0.0]]

 [[0.6 0.6]
  [0.3 0.5]
  [0.7 0.2]
  [0.8 0.1]
  [0.0 0.0]
  [0.0 0.0]]]

编辑

下面的变化,但它似乎仍然是一种混乱的方式。如有其他建议,敬请见谅。

代码语言:javascript
运行
复制
mask = tf.sequence_mask(lengths=[[3, 3], [4, 4]] maxlen=6)
mask = tf.transpose(mask, [0, 2, 1])
EN

回答 1

Stack Overflow用户

发布于 2018-09-04 16:26:41

也许不是最有效的方法,但有效

代码语言:javascript
运行
复制
flat_mask = tf.reshape(mask, shape=(-1,))
expanded_flat_mask = tf.tile(flat_mask, [2])
expanded_mask = tf.reshape(expanded_flat_mask, shape=(2, 6, 2))
masked_output = tf.boolean_mask(output, expanded_mask)

或者另一种

代码语言:javascript
运行
复制
masked_output = output * tf.cast(tf.expand_dims(mask, axis=-1), output.dtype)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52170485

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档