首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何处理tensorflow中的大型(>2GB)嵌入查找表?

如何处理tensorflow中的大型(>2GB)嵌入查找表?
EN

Stack Overflow用户
提问于 2017-10-12 15:03:04
回答 3查看 3.4K关注 0票数 9

当我使用预先训练过的单词向量进行LSTM分类时,我想知道如何处理tensorflow中大于2gb的嵌入查找表。

为了做到这一点,我尝试让嵌入查找表像下面的代码,

data = tf.nn.embedding_lookup(vector_array, input_data)

得到了这个值错误。

ValueError: Cannot create a tensor proto whose content is larger than 2GB

代码中的变量vector_array是numpy数组,它包含大约1400万个唯一的标记和每个字的100个维字向量。

谢谢你的帮助

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2018-01-18 10:37:50

对我来说,公认的答案似乎行不通。虽然没有错误,但结果很糟糕(与通过直接初始化进行的较小的嵌入相比),我怀疑嵌入只是使用tf.Variable()初始化的常数0。

只使用没有额外变量的占位符

代码语言:javascript
运行
复制
self.Wembed = tf.placeholder(
    tf.float32, self.embeddings.shape,
    name='Wembed')

然后,在图的每个session.run()上提供嵌入似乎是可行的。

票数 4
EN

Stack Overflow用户

发布于 2017-10-12 16:45:29

您需要将其复制到tf变量中。在StackOverflow:Using a pre-trained word embedding (word2vec or Glove) in TensorFlow中,这个问题有一个很好的答案

我就是这样做的:

代码语言:javascript
运行
复制
embedding_weights = tf.Variable(tf.constant(0.0, shape=[embedding_vocab_size, EMBEDDING_DIM]),trainable=False, name="embedding_weights") 
embedding_placeholder = tf.placeholder(tf.float32, [embedding_vocab_size, EMBEDDING_DIM])
embedding_init = embedding_weights.assign(embedding_placeholder)
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True)) 
sess.run(embedding_init, feed_dict={embedding_placeholder: embedding_matrix})

然后,您可以使用embedding_weights变量执行查找(记住要存储单词索引映射)。

更新:使用变量并不是必需的,但它允许您将变量保存到以后使用,这样您就不必再重新执行整个操作(在我的笔记本上加载非常大的嵌入式时,这需要一段时间)。如果这并不重要,您可以使用尼可拉斯·施内尔建议的占位符。

票数 7
EN

Stack Overflow用户

发布于 2018-07-18 18:23:27

对于我来说,使用Tf1.8时使用大型嵌入feed_dict的速度太慢了,这可能是因为Niklas提到的问题。

最后,我得到了以下代码:

代码语言:javascript
运行
复制
embeddings_ph = tf.placeholder(tf.float32, wordVectors.shape, name='wordEmbeddings_ph')
embeddings_var = tf.Variable(embeddings_ph, trainable=False, name='wordEmbeddings')
embeddings = tf.nn.embedding_lookup(embeddings_var,input_data)
.....
sess.run(tf.global_variables_initializer(), feed_dict={embeddings_ph:wordVectors})
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46712934

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档