tensorflow 在加载大型的embedding模型参数时，会遇到cannot be larger than 2GB

这种问题是，对于每一个变量 variable 由于是基于protobuf存在这大小限制(2G),这个时候，我们需要将embedding拆开，拆分成N等分，来使得每一个

variable都在2G以下;

``` 1 # !/usr/bin/env/python
2 # coding=utf-8
3 import tensorflow as tf
4 import numpy as np
5
6 input_ids = tf.placeholder(dtype=tf.int32, shape=[None,None])
7
8 num_shards = 3
9 weights = []
10 weights_shape = np.arange(27).reshape(9, 3)
11 # assert weights_shape[0] % num_shards == 0
12 num_shards_len = (weights_shape.shape[0]) / num_shards
13 assert  (weights_shape.shape[0]) % num_shards ==0
14 begin_ = 0
15 ends_ = num_shards_len
16 for i in range(0, num_shards):
17     if (i + 1) * num_shards_len < weights_shape.shape[0]:
18         begin_ = i * num_shards_len
19         if i + 1 == num_shards:
20             ends_ = weights_shape.shape[0]
21         else:
22             ends_ = (i + 1) * num_shards_len
23     else:
24         begin_ = i * num_shards_len
25         ends_ = weights_shape.shape[0]
26     weights_i = tf.get_variable("words-%02d" % i,
27                                 initializer=tf.constant(weights_shape[begin_: ends_, ]))
28     weights.append(weights_i)
29
30 input_embedding = tf.nn.embedding_lookup(weights, input_ids,partition_strategy="div")
31
32 sess = tf.InteractiveSession()
33 sess.run(tf.global_variables_initializer())
34 print(sess.run(weights))
35
36 print(sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))```

结果为:

```[array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]), array([[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]), array([[18, 19, 20],
[21, 22, 23],
[24, 25, 26]])]
[[[ 3  4  5]
[ 6  7  8]]

[[ 9 10 11]
[ 0  1  2]]

[[24 25 26]
[ 6  7  8]]

[[15 16 17]
[ 3  4  5]]]```

0 条评论

• mxnet运行时遇到问题及解决方法

1.训练好模型之后，进行预测时出现这种错误： 1 mxnet.base.MXNetError: [15:05:50] src/ndarray/ndarray.c...

• HDUOJ-----I NEED A OFFER!

I NEED A OFFER! Time Limit : 2000/1000ms (Java/Other)   Memory Limit : 65536/327...

• HDUOJ -----1864 最大报销额(动态规划)

最大报销额 Time Limit: 1000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Jav...

• Tensorflow卷积实现原理+手写python代码实现卷积

从一个通道的图片进行卷积生成新的单通道图的过程很容易理解，对于多个通道卷积后生成多个通道的图理解起来有点抽象。本文以通俗易懂的方式讲述卷积，并辅以图片解释，能快...

• Qt显示视频流——nginx+rtmp搭建直播服务器(二)

上次介绍的是使用ffmpeg推流，这次介绍的是使用nginx + rtmp搭建直播服务器。

• 超实用：小团队如何从零搭建一个自动化运维体系？

如下图，现在行业内各巨头自动化运维架构的最终样子大家都知道了，但是如何根据自己团队当前的情况一步步向这个目标演进？

• .htaccess文件常用功能总结 【原创】

.htaccess文件常用功能总结 Write By CS逍遥剑仙 我的主页: www.csxiaoyao.com GitHub: githu...

• linux日志循环

操作系统（Windows，Unix），应用一般都会记录日志，方便使用者常看系统或应用使用情况，或者排查故障。

• 创业公司如何组建技术团队

创业已经三年，作为青橙科技CTO兼半个HR，组建技术团队是我最重要的工作之一。不少朋友向我询问经验，在和他们的交流之后，我觉得有必要将我的一些想法与感悟记录下来...

• 百年经典：法约尔对未来工程师的建议

你们将幸福地想到自己终于是有用之才了，你们有理由希望通过劳动获得令人尊重的地位。 　　你们将来需要的素质并非完全是今天让你们名列前茅的那些东西。比如健康，行...