前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB

tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB

作者头像
Gxjun
发布2018-12-14 14:55:11
2.6K1
发布2018-12-14 14:55:11
举报
文章被收录于专栏:mlmlml

      这种问题是,对于每一个变量 variable 由于是基于protobuf存在这大小限制(2G),这个时候,我们需要将embedding拆开,拆分成N等分,来使得每一个

variable都在2G以下; 

 1 # !/usr/bin/env/python
 2 # coding=utf-8
 3 import tensorflow as tf
 4 import numpy as np
 5 
 6 input_ids = tf.placeholder(dtype=tf.int32, shape=[None,None])
 7 
 8 num_shards = 3
 9 weights = []
10 weights_shape = np.arange(27).reshape(9, 3)
11 # assert weights_shape[0] % num_shards == 0
12 num_shards_len = (weights_shape.shape[0]) / num_shards
13 assert  (weights_shape.shape[0]) % num_shards ==0
14 begin_ = 0
15 ends_ = num_shards_len
16 for i in range(0, num_shards):
17     if (i + 1) * num_shards_len < weights_shape.shape[0]:
18         begin_ = i * num_shards_len
19         if i + 1 == num_shards:
20             ends_ = weights_shape.shape[0]
21         else:
22             ends_ = (i + 1) * num_shards_len
23     else:
24         begin_ = i * num_shards_len
25         ends_ = weights_shape.shape[0]
26     weights_i = tf.get_variable("words-%02d" % i,
27                                 initializer=tf.constant(weights_shape[begin_: ends_, ]))
28     weights.append(weights_i)
29 
30 input_embedding = tf.nn.embedding_lookup(weights, input_ids,partition_strategy="div")
31 
32 sess = tf.InteractiveSession()
33 sess.run(tf.global_variables_initializer())
34 print(sess.run(weights))
35 
36 print(sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))

 结果为:

[array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]]), array([[ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17]]), array([[18, 19, 20],
       [21, 22, 23],
       [24, 25, 26]])]
[[[ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [ 0  1  2]]

 [[24 25 26]
  [ 6  7  8]]

 [[15 16 17]
  [ 3  4  5]]]
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018-11-21 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档