前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用Tensorflow实现数组的部分替换

使用Tensorflow实现数组的部分替换

作者头像
石晓文
发布2018-07-25 14:36:03
3.6K0
发布2018-07-25 14:36:03
举报

简单描述一下场景:对于一个二维的整型张量,假设每一行是一堆独立的数,但是对于每一行的数,都有一个设定好的最小值的。我们需要做的是,对于每一行,找到第一次小于最小值的位置,并将该位置起直到行末部分的数字替换为0。是不是有点抽象?我们来举个例子,假设我们的二维整型张量为:

[[5 4 3 0 1]
 [2 3 0 4 2]
 [2 3 5 4 2]]

我们设定的每行最小值为:

[[3],[2],[2]]

则我们最终想要的结果是:

[[5 4 3 0 0]
 [2 3 0 0 0]
 [2 3 5 4 2]]

解释一下,第一行最小值为2,index=3的位置是0,首次小于最小值,因此最后两位变成0,其他位置保持不变。对于其他两行来说也是一样的操作。

看似很简单?以下的实现方案可能比较笨重,如果大家有更好的方法,欢迎留言或者私信微信(sxw2251),咱们一起交流!

tensorflow不能对张量进行直接赋值操作,如果你尝试修改一个tensor中的内容,会报下面的错误:

TypeError: 'Tensor' object does not support item assignment

不能赋值操作大大限制了我的操作啊!不过,经过不懈的研究,上面的需求还是解决了!我们一起来看看实现步骤!

get_shape函数 我们先定义下面的函数,该函数可以返回一个tensor的形状,即使我们的tensor定义时某一维的形状定义为None:

def get_shape(tensor):
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.unstack(tf.shape(tensor))
  dims = [s[1] if s[0] is None else s[0]
          for s in zip(static_shape, dynamic_shape)]
  return dims

定义输入 我们有两个输入,一个是原始的二维张量,另一个是每一行的最小值:

choose = tf.placeholder(tf.int64,[None,5])
minValue = tf.placeholder(tf.int64,[None,1])

feed_dict = {
    choose:[[5,4,3,0,1],[2,3,0,4,2],[2,3,5,4,2]],
    minValue:[[3],[2],[2]]}

得到每行第一个小于最小值的位置的索引 这里,我们首先判断每个位置的数是否小于最小值,如果小于最小值,返回1,大于等于最小值,返回0,那么使用arg_max函数就可以返回第一个小于最小值的位置的索引:

x = tf.tile(tf.reshape(tf.arg_max(tf.cast(choose<minValue,tf.int64),1),(-1,1)),[1,5])

输出如下,第一行得到的索引是3,第二行得到的索引是2,第三行得到的索引是0:

[[3 3 3 3 3]
 [2 2 2 2 2]
 [0 0 0 0 0]]

这里很容易忽略一种情况,返回是0的情况,此时我们无法判断是全部都大于等于最小值还是0索引对应的值小于最小值。因此我们还需要一个辅助条件,计算每行有多少个数是小于设定的最小值的:

y = tf.tile(tf.reduce_sum(tf.cast(choose<minValue,tf.int64),axis=1,keep_dims=True),[1,5])

对于上面的计算,如果该行所有值都大于等于最小值,结果是0,否则,结果大于0,输出如下:

[[2 2 2 2 2]
 [1 1 1 1 1]
 [0 0 0 0 0]]

得到最终结果

由于我们主要是根据索引去操作的,因此我们为每一个数创建一个索引,以便于我们通过索引进行数据的选择:

y = tf.tile(tf.reduce_sum(tf.cast(choose<minValue,tf.int64),axis=1,keep_dims=True),[1,5])
index = tf.tile(tf.expand_dims(tf.range(5,dtype=tf.int64),0),[get_shape(choose)[0],1])

输出如下:

[[0 1 2 3 4]
 [0 1 2 3 4]
 [0 1 2 3 4]]

激动人心的时刻到了,经过上面两步,我们已经万事俱备了,接下来,我们要做的事,就是根据索引之间的大小关系,要么从原数组里面选数,要么选择0。

result = tf.where(index<x,choose,tf.zeros_like(choose))

得到的结果是:

[[5 4 3 0 0]
 [2 3 0 0 0]
 [0 0 0 0 0]]

可以看到,前两行的结果是对的,但是第三行的结果是错的,这时候就需要我们刚才得到的辅助条件对结果进行修正了:

result = tf.where(index<x,choose,tf.zeros_like(choose)) + tf.where(tf.equal(y,0),choose,tf.zeros_like(choose))

得到的结果如下:

[[5 4 3 0 0]
 [2 3 0 0 0]
 [2 3 5 4 2]]

跟预想的一样,大功告成! 如果还有简单的方法实现上面的需求,欢迎留言哟!

推荐阅读:强化学习系列

实战深度强化学习DQN-理论和实践

DQN三大改进(一)-Double DQN

DQN三大改进(二)-Prioritised replay

DQN三大改进(三)-Dueling Network

深度强化学习-Policy Gradient基本实现

深度强化学习-Actor-Critic算法原理和实现

深度强化学习-DDPG算法原理和实现

对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析

有关作者:

石晓文,中国人民大学信息学院在读研究生,美团外卖算法实习生

简书ID:石晓文的学习日记(https://www.jianshu.com/u/c5df9e229a67)

天善社区:https://www.hellobi.com/u/58654/articles

腾讯云:https://cloud.tencent.com/developer/user/1622140

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-07-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小小挖掘机 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档