首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >双向LSTM合并模式说明

双向LSTM合并模式说明
EN

Stack Overflow用户
提问于 2020-07-20 16:04:23
回答 2查看 2.9K关注 0票数 4

当使用双向LSTM进行序列分类时,我想了解更多关于合并模式的细节,特别是对于"Concat“合并模式,这对我来说仍然是相当不清楚的。

根据我对这个方案的理解:

在将前向层和后向层的合并结果传递到sigmoid函数中后,计算输出y_t。对于"add“、"mul”和"average“合并模式,这似乎相当直观,但我不明白当选择”concat“合并模式时,输出y_t是如何计算的。实际上,在这种合并模式下,我们现在在sidmoid函数之前有了一个向量而不是单个值。

EN

回答 2

Stack Overflow用户

发布于 2020-07-20 19:20:07

在双LSTM中,您将有一个LSTM在输入(例如X)上从左向右展开(例如LSTM1),另一个LSTM从右向左展开(假设您的输入大小(X.shape)为长度,其中每个要素的长度/时间步数/否:of

  • n:Batch size
  • tsequence unrollings)
  • f:No:Of time-step

:of

  1. LSTM1/X.shape/no:of
    • n:Batch size
    • tsequence unrollings)
    • f:No:Of time-step

  1. 假设我们有一个具有单个Bi-LSTM的模型,其定义如下

代码语言:javascript
复制
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(t, f)))

在这种情况下,

  1. 将返回大小为n X t X 10的输出,而LSTM2将返回大小为n X t X 10
  2. Now的输出。您可以选择如何使用merge_mode

在每个时间步组合LSTM2和n X t X 10的输出

sum:在每个时间步将LSTM1输出添加到LSTM2。即。n X t X 10 of LSTM1 + n X t X 10 of LSTM2 =大小n X t X 10的输出

mul:在每个时间步将LSTM1输出按元素相乘到LSTM2,这将导致输出大小为n X t X 10

concat:在每个时间步以元素方式将LSTM1输出连接到LSTM2,这将导致输出大小为n X t X 10*2

ave:每个时间步LSTM1输出到LSTM2的元素平均值,这将导致输出大小为n X t X 10

None:将LSTM1和LSTM2输出作为列表返回

基于merge_mode合并输出后,不应用激活函数。如果您想要应用激活,则必须在模型中明确地将其定义为层。

测试代码

代码语言:javascript
复制
model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 15), merge_mode='concat'))
assert model.layers[-1].output_shape == (None, 5, 20)

model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 15), merge_mode='sum'))
assert model.layers[-1].output_shape == (None, 5, 10)

model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 15), merge_mode='mul'))
assert model.layers[-1].output_shape == (None, 5, 10)

注意:

您不能在序列模型中使用merge_mode=None,因为每一层都应该返回一个张量,但是None返回一个列表,所以您不能将其堆叠在模型中。但是你可以在keras的函数式API中使用它。

票数 7
EN

Stack Overflow用户

发布于 2020-07-20 16:09:38

这很简单。假设您的前向LSTM层返回类似于[0.1, 0.2, 0.3]的状态,而后向LSTM层返回[0.4, 0.5, 0.6]。然后连接(为了简洁)是[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],它被进一步传递到激活层。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62991082

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档