首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >重塑打破了Keras模型

重塑打破了Keras模型
EN

Stack Overflow用户
提问于 2019-03-08 10:28:13
回答 1查看 76关注 0票数 0

我正在研究一个单词嵌入模型,当我试图添加一个最终的Reshape时,这个模型总是崩溃。下面是带有抛出的Reshape的模型:

ValueError: total size of new array must be unchanged

我搞不懂为什么这些尺寸不能加起来。

代码语言:javascript
运行
复制
embedding_size = 50
input_size = 46
# Both inputs are 1-dimensional
ingredients = Input(
     name='ingredients',
    shape=(input_size,)
)
documents = Input(
    name='documents',
    shape=(input_size,)
)


ingredients_embedding = Embedding(name='ingredients_embedding',
                                  input_dim=training_size,
                                  output_dim=embedding_size)(ingredients)

# Embedding the document (shape is (None, 46, 50))
document_embedding = Embedding(name='documents_embedding',
                               input_dim=training_size,
                               output_dim=embedding_size)(documents)

# Merge the layers with a dot product along the second axis (shape is (None, 46, 46))
merged = Dot(name='dot_product', normalize=True, axes=2)([ingredients_embedding, document_embedding])

# ~ This like breaks ~
# Reshape to be a single number (shape will be (None, 1))
merged = Reshape(target_shape=(1,))(merged) # <-- ValueError: total size of new array must be unchanged


m = Model(inputs=[ingredients, documents], outputs=merged)
m.compile(optimizer='Adam', loss='mse')

return m
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-03-08 11:18:59

将文档表示为向量的一种典型方法是沿着句子维度对document_embedding矩阵求和。

代码语言:javascript
运行
复制
from keras.layers import Input, Embedding, Dot, Reshape, Lambda
from keras.models import Model
import keras.backend as K

embedding_size = 50
input_size = 46

ingredients = Input(
    name='ingredients',
    shape=(input_size,)
)
documents = Input(
    name='documents',
    shape=(input_size,)
)

ingredients_embedding = Embedding(name='ingredients_embedding',
                                  input_dim=input_size,
                                  output_dim=embedding_size)(ingredients)

document_embedding = Embedding(name='documents_embedding',
                               input_dim=input_size,
                               output_dim=embedding_size)(documents)

#sum over the sentence dimension
ingredients_embedding = Lambda(lambda x: K.sum(x, axis=-2))(ingredients_embedding)
#sum over the sentence dimension
document_embedding = Lambda(lambda x: K.sum(x, axis=-2))(document_embedding)

merged = Dot(name='dot_product', normalize=True, axes=-1)([ingredients_embedding, document_embedding])

merged = Reshape(target_shape=(1,))(merged) 

m = Model(inputs=[ingredients, documents], outputs=merged)
m.compile(optimizer='Adam', loss='mse')
m.summary()

document_embedding的形状是(None, input_size, embedding_size),所以-2是倒数第二个轴,也就是句子维度的轴。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55055858

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档