首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Keras TimeDistributed层实际上是做什么的?

Keras TimeDistributed层实际上是做什么的?
EN

Stack Overflow用户
提问于 2022-03-22 13:24:33
回答 1查看 626关注 0票数 1

给定一个时间序列,我有一个多步预测任务,在这个任务中,我想预测与给定时间序列中的时间步骤相同的次数。如果我有以下模式:

代码语言:javascript
运行
复制
input1 = Input(shape=(n_timesteps, n_channels))
lstm = LSTM(units=100, activation='relu')(input1)
outputs = Dense(n_timesteps, activation="softmax")(lstm)
model = Model(inputs=input1, outputs=outputs)
model.compile(loss="mse", optimizer="adam",
              metrics=["accuracy"])

稠密层上的n_timesteps意味着我将有n_timesteps预测。但是,如果我将密集层封装在一个TimeDistributed中(或者在LSTM层中等效地设置return_sequences=True ),那么单元的数量仍然必须是n_timesteps还是1,因为对于TimeDistributed,我将对序列中的所有时间步骤应用密集层。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-29 05:20:28

根据您发布的示例,TimeDistributed实际上将对每个时间步骤应用一个具有softmax激活函数的Dense层:

代码语言:javascript
运行
复制
import tensorflow as tf

n_timesteps = 10
n_channels = 30
input1 = tf.keras.layers.Input(shape=(n_timesteps, n_channels))
lstm = tf.keras.layers.LSTM(units=100, activation='relu', return_sequences=True)(input1)
outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(n_channels, activation="softmax"))(lstm)
model = tf.keras.Model(inputs=input1, outputs=outputs)

请注意,每个完全连接的层都等于n_channels的大小,以便给每个通道在时间步骤n上被预测的公平机会。

如果您正在处理多标签问题,您可以尝试如下所示:

代码语言:javascript
运行
复制
import tensorflow as tf

n_timesteps = 10
features = 3
input1 = tf.keras.layers.Input(shape=(n_timesteps, features))
lstm = tf.keras.layers.LSTM(units=100, activation='relu', return_sequences=False)(input1)
outputs = tf.keras.layers.Dense(n_timesteps, activation="sigmoid")(lstm)
model = tf.keras.Model(inputs=input1, outputs=outputs)

x = tf.random.normal((1, n_timesteps, features))
y = tf.random.uniform((1, n_timesteps), dtype=tf.int32, maxval=2)

print(x)
print(y)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(x, y, epochs=2)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71572843

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档