首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何解释tf.map_fn的结果?

如何解释tf.map_fn的结果?
EN

Stack Overflow用户
提问于 2017-09-07 20:47:03
回答 1查看 12.3K关注 0票数 8

看一下代码:

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

elems = tf.ones([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

输出为:

代码语言:javascript
运行
复制
(array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64))

我看不懂输出,谁能告诉我?

更新

elems是一个张量,所以应该沿着-0轴将它解包,我们将得到[[1,1,1],[1,1,1]],然后map_fn[[1,1,1],[1,1,1]]传递给lambda x:(x,x,x),这意味着x=[[1,1,1],[1,1,1]],我认为map_fn的输出是

代码语言:javascript
运行
复制
[[[1,1,1],[1,1,1]],
 [[1,1,1],[1,1,1]],
 [[1,1,1],[1,1,1]]]

输出形状是[3,2,3]shape(2,3)列表

但实际上,输出是张量列表,每个张量的形状是[1,2,3]

或者换句话说:

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

为什么输出是

代码语言:javascript
运行
复制
(array([1, 2, 3], dtype=int64), 
 array([2, 4, 6], dtype=int64), 
 array([-1, -2, -3], dtype=int64))

而不是

代码语言:javascript
运行
复制
(array([1, 2, -1], dtype=int64), 
 array([2, 4, -2], dtype=int64), 
 array([3, 6, -3], dtype=int64))

这两个问题是一样的。

Update2

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

elems = [tf.constant([1,2,3],dtype=tf.int64)]
alternates = tf.map_fn(lambda x: x, elems, dtype=tf.int64)
with tf.Session() as sess:
    print(sess.run(alternates))

elems是一个张量列表,所以根据接口,tf.constant([1,2,3],dtype=tf.int64)将沿着0轴解包,所以map_fn将作为[x for x in [1,2,3]]工作,但实际上它会引发一个错误。

代码语言:javascript
运行
复制
ValueError: The two structures don't have the same nested structure. First struc
ture: <dtype: 'int64'>, second structure: [<tf.Tensor 'map/while/TensorArrayRead
V3:0' shape=() dtype=int64>].

怎么了?

update3

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

输出为

代码语言:javascript
运行
复制
(array([1, 2, 3], dtype=int64), array([1, 2, 3], dtype=int64))

似乎elems没有被拆开,为什么?

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: [x], elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

它将引发一个错误

代码语言:javascript
运行
复制
TypeError: The two structures don't have the same sequence type. First structure
 has type <class 'tuple'>, while second structure has type <class 'list'>.

谁能告诉我tf.map_fn是如何工作的?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-09-07 21:03:51

第一,

代码语言:javascript
运行
复制
elems = tf.ones([1,2,3],dtype=tf.int64)

elems是一个形状为1x2x3的3维张量,即:

代码语言:javascript
运行
复制
[[[1, 1, 1],
  [1, 1, 1]]]

然后,

代码语言:javascript
运行
复制
alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))

alternates是与elems形状相同的三个张量的元组,每个张量都是根据给定的函数构建的。由于该函数只返回一个重复输入三次的元组,这意味着这三个张量将与elems相同。如果函数是lambda x: (x, 2 * x, -x),那么第一个输出张量将与elems相同,第二个输出张量将是elems的两倍,而第三个输出张量则相反。

在所有这些情况下,最好使用常规操作而不是tf.map_fn;但是,可能存在这样的情况:您有一个接受N维张量的函数,而您有一个N+1的张量要应用于此张量。

更新:

可以这么说,我认为您正在以“相反的方式”思考tf.map_fn。张量中的元素或行数与函数中的输出数之间不存在一对一的对应关系;实际上,您可以传递一个函数,返回一个包含任意多个元素的元组。

以你的最后一个例子为例:

代码语言:javascript
运行
复制
elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))

tf.map_fn首先在第一个轴上拆分elems,即拆分为123,并将函数应用于其中的每一个,得到:

代码语言:javascript
运行
复制
(1, 2, -1)
(2, 4, -2)
(3, 6, -3)

请注意,正如我所说的,这些元组中的每一个都可以有您想要的任意多个元素。现在,最终输出是在相同位置中连接结果产生的;因此,您将获得:

代码语言:javascript
运行
复制
[1, 2, 3]
[2, 4, 6]
[-1, -2, -3]

同样,如果函数生成具有更多元素的元组,您将获得更多的输出张量。

更新2:

关于你的新例子:

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

documentation说:

此方法还允许fn的多极性元素和输出。如果elems是张量的(可能是嵌套的)列表或元组,那么这些张量中的每一个都必须具有匹配的第一个(解包)维度。fn的签名可以匹配元素的结构。也就是说,如果elems是(t1,[t2,t3,t4,t5]),那么fn的合适签名是: fn = lambda (t1,[t2,t3,t4,t5]):。

这里,elems是两个张量的元组,根据需要在第一维中具有相同的大小。tf.map_fn一次接受每个输入张量的一个元素(因此是两个元素的元组),并对其应用给定的函数,这应该返回与您在dtypes中传递的相同的结构(也是两个元素的元组);如果不提供dtypes,则预期的输出与输入相同(同样,是两个元素的元组,因此在本例中dtypes是可选的)。不管怎样,它是这样的:

代码语言:javascript
运行
复制
f((1, 1)) -> (1, 1)
f((2, 2)) -> (2, 2)
f((3, 3)) -> (3, 3)

这些结果组合在一起,连接结构中所有相应的元素;在这种情况下,第一个位置的所有数字产生第一个输出,第二个位置的所有数字产生第二个输出。最后,结果是请求的结构(两个元素的元组)填充了这些连接:

代码语言:javascript
运行
复制
([1, 2, 3], [1, 2, 3])
票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46096767

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档