专栏首页mltensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB

tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB

      这种问题是,对于每一个变量 variable 由于是基于protobuf存在这大小限制(2G),这个时候,我们需要将embedding拆开,拆分成N等分,来使得每一个

variable都在2G以下; 

 1 # !/usr/bin/env/python
 2 # coding=utf-8
 3 import tensorflow as tf
 4 import numpy as np
 5 
 6 input_ids = tf.placeholder(dtype=tf.int32, shape=[None,None])
 7 
 8 num_shards = 3
 9 weights = []
10 weights_shape = np.arange(27).reshape(9, 3)
11 # assert weights_shape[0] % num_shards == 0
12 num_shards_len = (weights_shape.shape[0]) / num_shards
13 assert  (weights_shape.shape[0]) % num_shards ==0
14 begin_ = 0
15 ends_ = num_shards_len
16 for i in range(0, num_shards):
17     if (i + 1) * num_shards_len < weights_shape.shape[0]:
18         begin_ = i * num_shards_len
19         if i + 1 == num_shards:
20             ends_ = weights_shape.shape[0]
21         else:
22             ends_ = (i + 1) * num_shards_len
23     else:
24         begin_ = i * num_shards_len
25         ends_ = weights_shape.shape[0]
26     weights_i = tf.get_variable("words-%02d" % i,
27                                 initializer=tf.constant(weights_shape[begin_: ends_, ]))
28     weights.append(weights_i)
29 
30 input_embedding = tf.nn.embedding_lookup(weights, input_ids,partition_strategy="div")
31 
32 sess = tf.InteractiveSession()
33 sess.run(tf.global_variables_initializer())
34 print(sess.run(weights))
35 
36 print(sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))

 结果为:

[array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]]), array([[ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17]]), array([[18, 19, 20],
       [21, 22, 23],
       [24, 25, 26]])]
[[[ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [ 0  1  2]]

 [[24 25 26]
  [ 6  7  8]]

 [[15 16 17]
  [ 3  4  5]]]

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • mxnet运行时遇到问题及解决方法

    1.训练好模型之后,进行预测时出现这种错误: 1 mxnet.base.MXNetError: [15:05:50] src/ndarray/ndarray.c...

    Gxjun
  • HDUOJ-----I NEED A OFFER!

    I NEED A OFFER! Time Limit : 2000/1000ms (Java/Other)   Memory Limit : 65536/327...

    Gxjun
  • HDUOJ -----1864 最大报销额(动态规划)

    最大报销额 Time Limit: 1000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Jav...

    Gxjun
  • Tensorflow卷积实现原理+手写python代码实现卷积

    从一个通道的图片进行卷积生成新的单通道图的过程很容易理解,对于多个通道卷积后生成多个通道的图理解起来有点抽象。本文以通俗易懂的方式讲述卷积,并辅以图片解释,能快...

    superhua
  • Qt显示视频流——nginx+rtmp搭建直播服务器(二)

    上次介绍的是使用ffmpeg推流,这次介绍的是使用nginx + rtmp搭建直播服务器。

    用户5908113
  • 超实用:小团队如何从零搭建一个自动化运维体系?

    如下图,现在行业内各巨头自动化运维架构的最终样子大家都知道了,但是如何根据自己团队当前的情况一步步向这个目标演进?

    小小科
  • .htaccess文件常用功能总结 【原创】

    .htaccess文件常用功能总结 Write By CS逍遥剑仙 我的主页: www.csxiaoyao.com GitHub: githu...

    CS逍遥剑仙
  • linux日志循环

    操作系统(Windows,Unix),应用一般都会记录日志,方便使用者常看系统或应用使用情况,或者排查故障。

    zero000
  • 创业公司如何组建技术团队

    创业已经三年,作为青橙科技CTO兼半个HR,组建技术团队是我最重要的工作之一。不少朋友向我询问经验,在和他们的交流之后,我觉得有必要将我的一些想法与感悟记录下来...

    tyrchen
  • 百年经典:法约尔对未来工程师的建议

    你们将幸福地想到自己终于是有用之才了,你们有理由希望通过劳动获得令人尊重的地位。   你们将来需要的素质并非完全是今天让你们名列前茅的那些东西。比如健康,行...

    机器人网

扫码关注云+社区

领取腾讯云代金券

玩转腾讯云 有奖征文活动