首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflow中张量对象在map_fn中具有多参数和一个返回值的自定义函数

Tensorflow中张量对象在map_fn中具有多参数和一个返回值的自定义函数
EN

Stack Overflow用户
提问于 2021-04-01 18:32:02
回答 1查看 434关注 0票数 0

我有两个张量t1和t2 (shape=(64,64,3),dtype=tf.float64)。我想要执行一个自定义函数"func“,它以两个张量作为输入,并返回一个新的张量。

代码语言:javascript
运行
复制
@tf.function
def func(a):
  t1 = a[0]
  t2 = a[1]
  
  return tf.add(t1, t2)

我使用tensorflow的map_fn为输入的每个元素执行函数。

代码语言:javascript
运行
复制
t = tf.map_fn(fn=func, elems=(t1, t2), dtype=(tf.float64, tf.float64))
tf.print(t)

用于测试的样本输入张量是,

代码语言:javascript
运行
复制
t1 = tf.constant([[1.1, 2.2, 3.3],
                  [4.4, 5.5, 6.6]])
t2 = tf.constant([[7.7, 8.8, 9.9],
                  [10.1, 11.11, 12.12]])

我不能用两个参数来使用map_fn。在tf.stack中尝试过,也可以打开堆栈,但也不起作用。知道怎么做吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-02 04:23:37

"map_fn“的"elems”参数解压缩沿轴0传递给它的参数。因此,为了在自定义函数中传递多个张量,

  1. 我们必须将它们叠加在一起。
  2. 沿轴0添加一个额外的维度。

代码语言:javascript
运行
复制
# t1 and t2 has shape [2, 3]
val = tf.stack([t1, t2]) # shape is now [2, 2, 3]
val = tf.expand_dims(val, axis=0) # shape is now [1, 2, 2, 3]
t = tf.map_fn(fn=func, elems=val, dtype=tf.float64)

此外,"map_fn“的"dtype”应该是函数的返回类型。例如,在这种情况下,应该是tf.float64。如果函数返回一个元组,那么dtype也将是一个tuple。

代码语言:javascript
运行
复制
@tf.function
def func(a): # a has shape [2, 2, 3]
  t1 = a[0] # shape [2, 3]
  t2 = a[1] # shape [2, 3]

  return tf.add(t1, t2)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66909817

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档