使用Tensorflow实现数组的部分替换

简单描述一下场景:对于一个二维的整型张量,假设每一行是一堆独立的数,但是对于每一行的数,都有一个设定好的最小值的。我们需要做的是,对于每一行,找到第一次小于最小值的位置,并将该位置起直到行末部分的数字替换为0。是不是有点抽象?我们来举个例子,假设我们的二维整型张量为:

[[5 4 3 0 1]
 [2 3 0 4 2]
 [2 3 5 4 2]]

我们设定的每行最小值为:

[[3],[2],[2]]

则我们最终想要的结果是:

[[5 4 3 0 0]
 [2 3 0 0 0]
 [2 3 5 4 2]]

解释一下,第一行最小值为2,index=3的位置是0,首次小于最小值,因此最后两位变成0,其他位置保持不变。对于其他两行来说也是一样的操作。

看似很简单?以下的实现方案可能比较笨重,如果大家有更好的方法,欢迎留言或者私信微信(sxw2251),咱们一起交流!

tensorflow不能对张量进行直接赋值操作,如果你尝试修改一个tensor中的内容,会报下面的错误:

TypeError: 'Tensor' object does not support item assignment

不能赋值操作大大限制了我的操作啊!不过,经过不懈的研究,上面的需求还是解决了!我们一起来看看实现步骤!

get_shape函数 我们先定义下面的函数,该函数可以返回一个tensor的形状,即使我们的tensor定义时某一维的形状定义为None:

def get_shape(tensor):
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.unstack(tf.shape(tensor))
  dims = [s[1] if s[0] is None else s[0]
          for s in zip(static_shape, dynamic_shape)]
  return dims

定义输入 我们有两个输入,一个是原始的二维张量,另一个是每一行的最小值:

choose = tf.placeholder(tf.int64,[None,5])
minValue = tf.placeholder(tf.int64,[None,1])

feed_dict = {
    choose:[[5,4,3,0,1],[2,3,0,4,2],[2,3,5,4,2]],
    minValue:[[3],[2],[2]]}

得到每行第一个小于最小值的位置的索引 这里,我们首先判断每个位置的数是否小于最小值,如果小于最小值,返回1,大于等于最小值,返回0,那么使用arg_max函数就可以返回第一个小于最小值的位置的索引:

x = tf.tile(tf.reshape(tf.arg_max(tf.cast(choose<minValue,tf.int64),1),(-1,1)),[1,5])

输出如下,第一行得到的索引是3,第二行得到的索引是2,第三行得到的索引是0:

[[3 3 3 3 3]
 [2 2 2 2 2]
 [0 0 0 0 0]]

这里很容易忽略一种情况,返回是0的情况,此时我们无法判断是全部都大于等于最小值还是0索引对应的值小于最小值。因此我们还需要一个辅助条件,计算每行有多少个数是小于设定的最小值的:

y = tf.tile(tf.reduce_sum(tf.cast(choose<minValue,tf.int64),axis=1,keep_dims=True),[1,5])

对于上面的计算,如果该行所有值都大于等于最小值,结果是0,否则,结果大于0,输出如下:

[[2 2 2 2 2]
 [1 1 1 1 1]
 [0 0 0 0 0]]

得到最终结果

由于我们主要是根据索引去操作的,因此我们为每一个数创建一个索引,以便于我们通过索引进行数据的选择:

y = tf.tile(tf.reduce_sum(tf.cast(choose<minValue,tf.int64),axis=1,keep_dims=True),[1,5])
index = tf.tile(tf.expand_dims(tf.range(5,dtype=tf.int64),0),[get_shape(choose)[0],1])

输出如下:

[[0 1 2 3 4]
 [0 1 2 3 4]
 [0 1 2 3 4]]

激动人心的时刻到了,经过上面两步,我们已经万事俱备了,接下来,我们要做的事,就是根据索引之间的大小关系,要么从原数组里面选数,要么选择0。

result = tf.where(index<x,choose,tf.zeros_like(choose))

得到的结果是:

[[5 4 3 0 0]
 [2 3 0 0 0]
 [0 0 0 0 0]]

可以看到,前两行的结果是对的,但是第三行的结果是错的,这时候就需要我们刚才得到的辅助条件对结果进行修正了:

result = tf.where(index<x,choose,tf.zeros_like(choose)) + tf.where(tf.equal(y,0),choose,tf.zeros_like(choose))

得到的结果如下:

[[5 4 3 0 0]
 [2 3 0 0 0]
 [2 3 5 4 2]]

跟预想的一样,大功告成! 如果还有简单的方法实现上面的需求,欢迎留言哟!

推荐阅读:强化学习系列

实战深度强化学习DQN-理论和实践

DQN三大改进(一)-Double DQN

DQN三大改进(二)-Prioritised replay

DQN三大改进(三)-Dueling Network

深度强化学习-Policy Gradient基本实现

深度强化学习-Actor-Critic算法原理和实现

深度强化学习-DDPG算法原理和实现

对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析

有关作者:

石晓文,中国人民大学信息学院在读研究生,美团外卖算法实习生

简书ID:石晓文的学习日记(https://www.jianshu.com/u/c5df9e229a67)

天善社区:https://www.hellobi.com/u/58654/articles

腾讯云:https://cloud.tencent.com/developer/user/1622140

原文发布于微信公众号 - 小小挖掘机(wAIsjwj)

原文发表时间:2018-07-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据结构与算法

BZOJ4872: [Shoi2017]分手是祝愿

Description Zeit und Raum trennen dich und mich. 时空将你我分开。B 君在玩一个游戏,这个游戏由 n 个灯和 n...

2945
来自专栏程序员互动联盟

程序员必须要掌握的十大经典算法

算法一:快速排序算法 快速排序是由东尼·霍尔所发展的一种排序算法。在平均状况下,排序 n 个项目要Ο(n log n)次比较。在最坏状况下则需要Ο(n2)次比较...

41213
来自专栏老九学堂

程序员必须知道的十大基础实用算法及讲解!

最近社群很多的小伙伴们对算法进行了激烈的讨论与学习,今天老九君就给大家介绍一些编程语言里的基础算法,提高小伙伴们的算法知识及编程里对算法的运用。 我们一起来看看...

3575
来自专栏Java 源码分析

枚举

​ 枚举就是尝试所有的可能性,尤其是当我们在确定一个问题是不是的这一类问题中尤其有用,例如说给一堆数,让我我们判断他们是不是素数,或者素数的数量的时候,这...

3136
来自专栏华章科技

程序员必须知道的10大基础实用算法及其讲解

快速排序是由东尼·霍尔所发展的一种排序算法。在平均状况下,排序n个项目要Ο(nlogn)次比较。在最坏状况下则需要Ο(n2)次比较,但这种状况并不常见。事实上,...

942
来自专栏osc同步分享

算法基础

分治法的基本思想: 将一个规模为 n 的问题分解为 k 各规模较小的子问题, 这些子问题互相独立且与原问题是同类型问题。 递归地解这些子问题, 然后把各个子问题...

3949
来自专栏CDA数据分析师

数据分析师不可不知的10大基础实用算法及其讲解

算法一:快速排序算法 快速排序是由东尼·霍尔所发展的一种排序算法。在平均状况下,排序 n 个项目要Ο(n log n)次比较。在最坏状况下则需要Ο(n2)次比较...

2178
来自专栏聊聊技术

原 初学算法-快速排序与线性时间选择(De

3886
来自专栏游戏开发那些事

【随笔】游戏程序开发必知的10大基础实用算法及其讲解

快速排序是由东尼·霍尔所发展的一种排序算法。在平均状况下,排序 n 个项目要Ο(n logn)次比较。在最坏状况下则需要Ο(n2)次比较,但这种状况并不常见。事...

963
来自专栏CSDN技术头条

程序员必须知道的十大基础实用算法及其讲解

算法一:快速排序算法 快速排序是由东尼·霍尔所发展的一种排序算法。在平均状况下,排序 n 个项目要Ο(nlogn) 次比较。在最坏状况下则需要Ο(n2) 次比较...

2075

扫码关注云+社区

领取腾讯云代金券