首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >如何将LSTM与Dense连接起来?

如何将LSTM与Dense连接起来?
EN

Stack Overflow用户
提问于 2020-01-26 21:48:16
回答 1查看 41关注 0票数 0

当尝试将LSTM与Dense连接时,它给出一个错误(在尝试训练时):

代码语言:javascript
复制
input = Input(shape=(x_train.shape[1], None))
X = Embedding(num_words, max_article_len)(input)
X = LSTM(128, return_sequences=True, dropout = 0.5)(X)
X = LSTM(128)(X)
X = Dense(32, activation='softmax')(X)

model = Model(inputs=[input], outputs=[X])
...
>>> ValueError: Error when checking target: expected dense to have shape (32,) but got array with shape (1,)

我尝试了不同的连接选项,但错误重复出现:

代码语言:javascript
复制
X, h, c = LSTM(128, return_sequences=False, return_state=True, dropout = 0.5)(X)
X = Dense(32, activation='softmax')(X)
>>> ValueError: Error when checking target: expected dense to have shape (32,) but got array with shape (1,)

functional API / Sequential上有任何解决方案选项吗?

数据转换代码:

代码语言:javascript
复制
train = pd.read_csv('train.csv')
articles = train['text']
y_train = train['lang']

num_words = 50000
max_article_len = 20

tokenizer = Tokenizer(num_words=num_words)
tokenizer.fit_on_texts(articles)

sequences = tokenizer.texts_to_sequences(articles)
x_train = pad_sequences(sequences, maxlen=max_article_len, padding='post')

x_train.shape
>>> (18974, 100)
y_train.shape
>>> (18974,)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-01-26 21:54:59

最后一个参数必须设置为False

代码语言:javascript
复制
X = LSTM(128, return_sequences=True, dropout = 0.5)(X)
X = LSTM(128, return_sequences=False)(X)

如果你仍然有问题,那么一定是你的输入形状有问题。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59918948

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档