tf.nn.embedding_lookup记录

我觉得这张图就够了,实际上tf.nn.embedding_lookup的作用就是找到要寻找的embedding data中的对应的行下的vector。

tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)

官方文档位置,其中,params是我们给出的,可以通过: 1.tf.get_variable("item_emb_w", [self.item_count, self.embedding_size])等方式生产服从[0,1]的均匀分布或者标准分布 2.tf.convert_to_tensor转化我们现有的array 然后,ids是我们要找的params中对应位置。

举个例子:

import numpy as np
import tensorflow as tf
data = np.array([[[2],[1]],[[3],[4]],[[6],[7]]])
data = tf.convert_to_tensor(data)
lk = [[0,1],[1,0],[0,0]]
lookup_data = tf.nn.embedding_lookup(data,lk)
init = tf.global_variables_initializer()

先让我们看下不同数据对应的维度:

In [76]: data.shape
Out[76]: (3, 2, 1)
In [77]: np.array(lk).shape
Out[77]: (3, 2)
In [78]: lookup_data
Out[78]: <tf.Tensor 'embedding_lookup_8:0' shape=(3, 2, 2, 1) dtype=int64>

这个是怎么做到的呢?关键的部分来了,看下图:

lk中的值,在要寻找的embedding数据中下找对应的index下的vector进行拼接。永远是look(lk)部分的维度+embedding(data)部分的除了第一维后的维度拼接。很明显,我们也可以得到,lk里面值是必须要小于等于embedding(data)的最大维度减一的

以上的结果就是:

In [79]: data
Out[79]:
array([[[2],
        [1]],

       [[3],
        [4]],

       [[6],
        [7]]])

In [80]: lk
Out[80]: [[0, 1], [1, 0], [0, 0]]

# lk[0]也就是[0,1]对应着下面sess.run(lookup_data)的结果恰好是把data中的[[2],[1]],[[3],[4]]

In [81]: sess.run(lookup_data)
Out[81]:
array([[[[2],
         [1]],

        [[3],
         [4]]],


       [[[3],
         [4]],

        [[2],
         [1]]],


       [[[2],
         [1]],

        [[2],
         [1]]]])

最后,partition_strategy是用于当len(params) > 1,params的元素分割不能整分的话,则前(max_id + 1) % len(params)多分一个id. 当partition_strategy = 'mod'的时候,13个ids划分为5个分区:[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]],也就是是按照数据列进行映射,然后再进行look_up操作。 当partition_strategy = 'div'的时候,13个ids划分为5个分区:[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]],也就是是按照数据先后进行排序标序,然后再进行look_up操作。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏HansBug's Lab

算法模板——线段树4(区间加+区间乘+区间覆盖值+区间求和)

实现功能——1:区间加法 2:区间乘法 3:区间覆盖值 4:区间求和 这是个四种常见线段树功能的集合版哦。。。么么哒(其实只要协调好三种tag的关系并不算太难—...

28430
来自专栏HTML5学堂

Javascript中的Label语句

HTML5学堂:在JavaScript中,我们可能很少会去用到 Label 语句,但是熟练的应用 Label 语句,尤其是在嵌套循环中熟练应用 break, c...

42270
来自专栏python3

python语句-中断循环-continue,break

continue的作用是:从continue语句开始到循环结束,之间所有的语句都不执行,直接从一下次循环重新开始

14230
来自专栏ACM算法日常

Max Sum(优化)- HDU 1003

Given a sequence a[1],a[2],a[3]......a[n], your job is to calculate the max sum ...

8730
来自专栏漫漫深度学习路

python画图:matplotlib(1)

python matplotlib matplotlib是python中用来绘图的一个库,提供非常强大的绘图功能。 安装 pip install matplot...

36070
来自专栏算法channel

Tensorflow|Tensor, 与Numpy比较,Constant

本教程参考stanford.edu-cs20si 01 Operations分类预览 ? 02 Tensor 1 0-d tensor, or "scala...

43870
来自专栏深度学习之tensorflow实战篇

tensorflow(一)windows 10 python3.6安装tensorflow1.4与基本概念解读

一.安装 目前用了tensorflow、deeplearning4j两个深度学习框架, tensorflow 之前一直支持到python 3.5,目前以更...

43340
来自专栏wOw的Android小站

[Tensorflow] 在Android运行TensorFlow模型

以下代码来自于TensorFlowObjectDetectionAPIModel.java

70210
来自专栏深度学习之tensorflow实战篇

tensorflow(一)windows 10 64位安装tensorflow1.4与基本概念解读tf.global_variables_initializer

一.安装 目前用了tensorflow、deeplearning4j两个深度学习框架, tensorflow 之前一直支持到python 3.5,目前以更新...

38560
来自专栏应兆康的专栏

100个Numpy练习【3】

翻译:YingJoy 网址: https://www.yingjoy.cn/ 来源: https://github.com/rougier/numpy-100...

456100

扫码关注云+社区

领取腾讯云代金券