如何使用tensorflow做张量排序和字符串拼接?

本文,将总结一下最近使用tensorflow中遇到的两个小需求:张量排序和字符串拼接,咱们一起来学习一下,嘻嘻!

1、张量排序

tensorflow是没有类似于python中sorted或者np.sort方法的,如果在流中使用这两个方法,是会报错的!那么我们如果想要在graph中实现对张量的排序,该如何做呢!我觉得可以使用top_k函数!

tf.nn.top_k

函数如下:

tf.nn.top_k(input, k, name=None)

这个函数的作用是返回 input 中每行最大的 k 个数(如果想要实现排序,k设置成数组长度即可),并且返回它们所在位置的索引。因此,返回的是一个tuple,我们用下标索引0取出排序后的结果。

看下面的例子:

choose = tf.placeholder(tf.int64,[None,5])
sortresult = tf.nn.top_k(choose,5,sorted=True)
sortresultarr =  tf.nn.top_k(choose,5,sorted=True)[0]


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {
        choose:[[5,4,3,0,1],[2,3,0,4,2],[2,3,5,4,2]]
    }
    print(sess.run(sortresult,feed_dict=feed_dict))
    print(sess.run(sortresultarr,feed_dict = feed_dict))

返回的结果如下:

TopKV2(values=array([[5, 4, 3, 1, 0],
       [4, 3, 2, 2, 0],
       [5, 4, 3, 2, 2]]), indices=array([[0, 1, 2, 4, 3],
       [3, 1, 0, 4, 2],
       [2, 3, 1, 0, 4]], dtype=int32))
[[5 4 3 1 0]
 [4 3 2 2 0]
 [5 4 3 2 2]]

2、字符串拼接

实现字符串拼接,如果给出的是数字型的tensor,我们首先要将数字转换成字符串,这里使用tf.as_string方法。

sortresultarr =  tf.as_string(tf.nn.top_k(choose,5,sorted=True)[0])

输出如下:

[[b'5' b'4' b'3' b'1' b'0']
 [b'4' b'3' b'2' b'2' b'0']
 [b'5' b'4' b'3' b'2' b'2']]

也许你可能会使用tf.cast方法,不好意思,我们在将int64位转换成string时,报错了:

sortresultarr =  tf.cast(tf.nn.top_k(choose,5,sorted=True)[0],tf.string)
error:Cast int64 to string is not supported

转换成字符串之后,字符串拼接我们可以查到两种方法:tf.reduce_join和tf.string_join。我们分别来试验下这两种方法。

tf.string_join

tf.string_join(
    inputs,
    separator='',
    name=None
)

该方法将给定的字符串张量列表中的字符串连接成一个张量。如果我们直接把刚才的结果放入到函数中,报错了:

sortresultstr = tf.string_join(sortresultarr,separator=",")

#ERROR
TypeError: Expected list for 'inputs' argument to 'string_join' Op, not <tf.Tensor 'AsString:0' shape=(?, 5) dtype=string>.

因为函数要求输入的是一个list,而非一个张量,那好,我们就放入一个list,比如我们将结果的前两行放入:

sortresultstr = tf.string_join([sortresultarr[0],sortresultarr[1]],separator=",")

这次没有报错,而是返回了一个有趣的结果:

[b'5,4' b'4,3' b'3,2' b'1,2' b'0,0']

可以看到,它将我们传入的list中,按位进行了拼接,是不是很有趣!不过这并不是我们想要的答案,如果想要按行进行拼接,应该使用reduce_join函数。

tf.reduce_join

reduce_join(
    inputs,
    axis=None,
    keep_dims=False,
    separator='',
    name=None,
    reduction_indices=None
)

解释一下几个重要的参数: inputs:string类型的Tensor。要加入的输入。所有减少的指数必须为非零的大小。 axis:拼接的维度。 keep_dims:可选的bool。默认为False。如果为True,则保留维度减小的长度1。 separator:可选的string。默认为""。加入时要使用的分隔符。

看下面的例子:

sortresultstr = tf.reduce_join(sortresultarr,axis=1,keep_dims=True,separator=",")

结果如下:

[[b'5,4,3,1,0']
 [b'4,3,2,2,0']
 [b'5,4,3,2,2']]

参考文献

1、https://www.w3cschool.cn/tensorflow_python/tensorflow_python-zku82hj1.html 2、https://www.w3cschool.cn/tensorflow_python/tensorflow_python-ukns2mo5.html 3、https://blog.csdn.net/wuguangbin1230/article/details/72820627

推荐阅读:强化学习系列

实战深度强化学习DQN-理论和实践

DQN三大改进(一)-Double DQN

DQN三大改进(二)-Prioritised replay

DQN三大改进(三)-Dueling Network

深度强化学习-Policy Gradient基本实现

深度强化学习-Actor-Critic算法原理和实现

深度强化学习-DDPG算法原理和实现

对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析

用Deep Recurrent Q Network解决部分观测问题!

有关作者:

石晓文,中国人民大学信息学院在读研究生,美团外卖算法实习生

简书ID:石晓文的学习日记(https://www.jianshu.com/u/c5df9e229a67)

天善社区:https://www.hellobi.com/u/58654/articles

腾讯云:https://cloud.tencent.com/developer/user/1622140

原文发布于微信公众号 - 小小挖掘机(wAIsjwj)

原文发表时间:2018-07-07

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏派森公园

最简单的NP-Hard问题

1748
来自专栏java一日一条

java float double精度为什么会丢失?浅谈java的浮点数精度问题

由于对float或double 的使用不当,可能会出现精度丢失的问题。问题大概情况可以通过如下代码理解:

872
来自专栏MelonTeam专栏

OpenGL学习笔记(二)——渲染管线&着色语言

导语 :渲染管线(渲染流水线),一般由显示芯片(GPU)内部处理图形信号的并行处理单元组成。这些并行处理单元两两之间相互独立。不同的型号硬件上独立处理单元的数量...

3438
来自专栏mukekeheart的iOS之旅

算法——(转)动态规划入门

动态规划相信大家都知道,动态规划算法也是新手在刚接触算法设计时很苦恼的问题,有时候觉得难以理解,但是真正理解之后,就会觉得动态规划其实并没有想象中那么难。网上也...

1781
来自专栏技术总结

算法(2)

2319
来自专栏武培轩的专栏

剑指Offer-连续子数组的最大和

题目描述 在古老的一维模式识别中,常常需要计算连续子向量的最大和,当向量全为正数的时候,问题很好解决。但是,如果向量中包含负数,是否应该包含某个负数,并期望旁边...

2782
来自专栏MyBlog

数值分析读书笔记(5)数值逼近问题(I)----插值极其数值计算

当给定插值函数是多项式函数的时候, 我们可以产生一种插值的方案, 下面介绍一下Lagrange插值

1611
来自专栏老九学堂

【学习】Java微课堂之switch语句

知识点: ? ? 扩展知识介绍 Java随机数类Random介绍 Java实用工具类库中的类java.util.Random提供了产生各种类型随机数的方法。它可...

3585
来自专栏数据结构与算法

洛谷T21776 子序列

题目描述 你有一个长度为 nn 的数列 ,这个数列由 0,10,1 组成,进行 mm 个的操作: 1~l~r1 l r :把数列区间 [l, r][l,r] ...

3928
来自专栏ATYUN订阅号

NumPy中einsum的基本介绍

einsum函数是NumPy的中最有用的函数之一。由于其强大的表现力和智能循环,它在速度和内存效率方面通常可以超越我们常见的array函数。但缺点是,可能需要一...

5103

扫码关注云+社区

领取腾讯云代金券