前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tf.nn.embedding_lookup函数

tf.nn.embedding_lookup函数

作者头像
周小董
发布2019-03-25 10:00:06
1.7K0
发布2019-03-25 10:00:06
举报
文章被收录于专栏:python前行者

我觉得这张图就够了,实际上tf.nn.embedding_lookup的作用就是找到要寻找的embedding data中的对应的行下的vector。

tf.nn.embedding_lookup(params, ids, partition_strategy=‘mod’, name=None, validate_indices=True, max_norm=None)

参数说明:

  • params: 表示完整的嵌入张量,或者除了第一维度之外具有相同形状的P个张量的列表,表示经分割的嵌入张量
  • ids: 一个类型为int32或int64的Tensor,包含要在params中查找的id
  • partition_strategy: 指定分区策略的字符串,如果len(params)> 1,则相关。当前支持“div”和“mod”。 默认为“mod”
  • name: 操作名称(可选)
  • validate_indices:  是否验证收集索引
  • max_norm: 如果不是None,嵌入值将被l2归一化为max_norm的值

官方文档位置,其中,params是我们给出的,是服从[0,1]的均匀分布或者标准分布,tf.convert_to_tensor转化我们现有的array 然后,ids是我们要找的params中对应位置。

举个例子:

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
data = np.array([[[2],[1]],[[3],[4]],[[6],[7]]])
data = tf.convert_to_tensor(data)
lk = [[0,1],[1,0],[0,0]]
lookup_data = tf.nn.embedding_lookup(data,lk)
init = tf.global_variables_initializer()

先让我们看下不同数据对应的维度:

代码语言:javascript
复制
In [76]: data.shape
Out[76]: (3, 2, 1)
In [77]: np.array(lk).shape
Out[77]: (3, 2)
In [78]: lookup_data
Out[78]: <tf.Tensor 'embedding_lookup_8:0' shape=(3, 2, 2, 1) dtype=int64>

这个是怎么做到的呢?关键的部分来了,看下图:

lk中的值,在要寻找的embedding数据中下找对应的index下的vector进行拼接。永远是look(lk)部分的维度+embedding(data)部分的除了第一维后的维度拼接。很明显,我们也可以得到,lk里面值是必须要小于等于embedding(data)的最大维度减一的

以上的结果就是:

代码语言:javascript
复制
In [79]: data
Out[79]:
array([[[2],
        [1]],

       [[3],
        [4]],

       [[6],
        [7]]])

In [80]: lk
Out[80]: [[0, 1], [1, 0], [0, 0]]

# lk[0]也就是[0,1]对应着下面sess.run(lookup_data)的结果恰好是把data中的[[2],[1]],[[3],[4]]

In [81]: sess.run(lookup_data)
Out[81]:
array([[[[2],
         [1]],

        [[3],
         [4]]],


       [[[3],
         [4]],

        [[2],
         [1]]],


       [[[2],
         [1]],

        [[2],
         [1]]]])
代码语言:javascript
复制
import tensorflow as tf;
import numpy as np;
 
c = np.random.random([10,1])
b = tf.nn.embedding_lookup(c, [1, 3])
 
with tf.Session() as sess:
	sess.run(tf.initialize_all_variables())
	print(sess.run(b))
	print(c)

输出:

代码语言:javascript
复制
[[ 0.77505197]
 [ 0.20635818]]

[[ 0.23976515]
 [ 0.77505197]
 [ 0.08798201]
 [ 0.20635818]
 [ 0.37183035]
 [ 0.24753178]
 [ 0.17718483]
 [ 0.38533808]
 [ 0.93345168]
 [ 0.02634772]]

分析:输出为张量的第一和第三个元素。

  • 如果ids是多行:
代码语言:javascript
复制
import tensorflow as tf
import numpy as np

a = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]
a = np.asarray(a)
idx1 = tf.Variable([0, 2, 3, 1], tf.int32)
idx2 = tf.Variable([[0, 2, 3, 1], [4, 0, 2, 2]], tf.int32)
out1 = tf.nn.embedding_lookup(a, idx1)
out2 = tf.nn.embedding_lookup(a, idx2)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(out1))
    print(out1)
    print('==================')
    print(sess.run(out2))
    print(out2)

输出:

代码语言:javascript
复制
[[ 0.1  0.2  0.3]
 [ 2.1  2.2  2.3]
 [ 3.1  3.2  3.3]
 [ 1.1  1.2  1.3]]
Tensor("embedding_lookup:0", shape=(4, 3), dtype=float64)
==================
[[[ 0.1  0.2  0.3]
  [ 2.1  2.2  2.3]
  [ 3.1  3.2  3.3]
  [ 1.1  1.2  1.3]]

 [[ 4.1  4.2  4.3]
  [ 0.1  0.2  0.3]
  [ 2.1  2.2  2.3]
  [ 2.1  2.2  2.3]]]
Tensor("embedding_lookup_1:0", shape=(2, 4, 3), dtype=float64)

参考:https://blog.csdn.net/uestc_c2_403/article/details/72779417 https://www.jianshu.com/p/abea0d9d2436 https://www.cnblogs.com/gaofighting/p/9625868.html https://blog.csdn.net/yangfengling1023/article/details/82910951

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019年03月18日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档