首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >tensorflow级联张量逐行

tensorflow级联张量逐行
EN

Stack Overflow用户
提问于 2017-12-06 22:03:52
回答 1查看 1.5K关注 0票数 1

我有一个BasicLSTMCell,它被输入到一个static_rnn中,它可以显示sent_max_len = 2的时间。输入是batch_size = 2的一批句子。每个单词都有一个embed_size = 6。输出如下所示:

代码语言:javascript
运行
复制
       e1           e2           e3            e4          e5           e6
[[-0.01236094, -0.00423804, -0.01091367,  0.00286771, -0.00911439, -0.00964547]  s1.w1      
[-0.0316297 ,  0.00904978, -0.02972977, -0.00720989,  0.00432076,  0.00946513]] s2.w1
                                       t1

[[-0.03660333,  0.00613474, -0.03758694, -0.0070029 , -0.00036427, 0.00386676], s1.w2        
[-0.04840172,  0.01757939, -0.05444464, -0.01508901,  0.01216465, 0.01938537]] s2.w2
                                        t2
s1 = sentence 1
w1 = word 1
t1 = timestep 1
e1 = embedding 1

我想把输出连接起来。但问题是,当我沿着tf.concat沿axis=0执行axis=0时,它获取t1的输出并与t2连接,如下所示:

代码语言:javascript
运行
复制
[[-0.01236094 -0.00423804 -0.01091367  0.00286771 -0.00911439 -0.00964547]  s1.w1
 [-0.0316297   0.00904978 -0.02972977 -0.00720989  0.00432076  0.00946513]  s2.w1
 [-0.03660333  0.00613474 -0.03758694 -0.0070029  -0.00036427  0.00386676]  s1.w2
 [-0.04840172  0.01757939 -0.05444464 -0.01508901  0.01216465  0.01938537]  s2.w2

但我想像这样连接起来:

代码语言:javascript
运行
复制
[[-0.01236094 -0.00423804 -0.01091367  0.00286771 -0.00911439 -0.00964547]  s1.w1
 [-0.03660333  0.00613474 -0.03758694 -0.0070029  -0.00036427  0.00386676]  s1.w2
 [-0.0316297   0.00904978 -0.02972977 -0.00720989  0.00432076  0.00946513]  s2.w1
 [-0.04840172  0.01757939 -0.05444464 -0.01508901  0.01216465  0.01938537]  s2.w2

因为这个concated_output通过其他层,我的最后一个predicted_outputactual_output比较,actual_output看起来如下所示:

代码语言:javascript
运行
复制
[[s1.w1.actualOutput]
 [s1.w2.actualOutput]
 [s2.w1.actualOutput]
 [s2.w2.actualOutput]
]

显然,我不能沿着axis=1连接,因为即使它将以正确的顺序给出输出,单词嵌入也将被合并。

EN

Stack Overflow用户

发布于 2017-12-07 05:25:20

我将使用tf.gather_nd,在这里您需要提供索引来收集给定张量中的项。例如:

代码语言:javascript
运行
复制
data1 = tf.constant(
    [
        [[1,1,1],[2,2,2]],
        [[3,3,3],[4,4,4]]

    ]
)
indices = tf.constant([
    [[0,0], [1,0]],
    [[0,1], [1,1]]
])
result = tf.gather_nd(data1, indices)

将给予:

代码语言:javascript
运行
复制
[[[1 1 1]
  [3 3 3]]

 [[2 2 2]
 [4 4 4]]]

然后,您可以使用concat和axis=0将张量转换为所需的格式。

票数 2
EN
查看全部 1 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47684197

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档