首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >TensorFlow: numpy.repeat()替代方法

TensorFlow: numpy.repeat()替代方法
EN

Stack Overflow用户
提问于 2016-02-12 19:28:08
回答 9查看 35.9K关注 0票数 43

我想以成对的方式比较神经网络的预测值yp,所以我使用(回到我的旧numpy实现):

代码语言:javascript
复制
idx = np.repeat(np.arange(len(yp)), len(yp))
jdx = np.tile(np.arange(len(yp)), len(yp))
s = yp[[idx]] - yp[[jdx]]

这基本上创建了一个索引网格,然后我使用它。idx=[0,0,0,1,1,1,...]jdx=[0,1,2,0,1,2...]。我不知道是否有更简单的方式来做这件事。

无论如何,TensorFlow有一个tf.tile(),但它似乎缺少一个tf.repeat()

代码语言:javascript
复制
idx = np.repeat(np.arange(n), n)
v2 = v[idx]

然后我得到了错误:

代码语言:javascript
复制
TypeError: Bad slice index [  0   0   0 ..., 215 215 215] of type <type 'numpy.ndarray'>

使用TensorFlow常量进行索引也不起作用:

代码语言:javascript
复制
idx = tf.constant(np.repeat(np.arange(n), n))
v2 = v[idx]

-

代码语言:javascript
复制
TypeError: Bad slice index Tensor("Const:0", shape=TensorShape([Dimension(46656)]), dtype=int64) of type <class 'tensorflow.python.framework.ops.Tensor'>

我的想法是将我的RankNet实现转换为TensorFlow。

EN

回答 9

Stack Overflow用户

回答已采纳

发布于 2016-02-13 00:11:16

您可以结合使用tf.tile()tf.reshape()来实现np.repeat()的效果

代码语言:javascript
复制
idx = tf.range(len(yp))
idx = tf.reshape(idx, [-1, 1])    # Convert to a len(yp) x 1 matrix.
idx = tf.tile(idx, [1, len(yp)])  # Create multiple columns.
idx = tf.reshape(idx, [-1])       # Convert back to a vector.

您可以使用tf.tile()简单地计算jdx

代码语言:javascript
复制
jdx = tf.range(len(yp))
jdx = tf.tile(jdx, [len(yp)])

对于索引,您可以尝试使用tf.gather()yp张量中提取不连续的切片:

代码语言:javascript
复制
s = tf.gather(yp, idx) - tf.gather(yp, jdx)
票数 43
EN

Stack Overflow用户

发布于 2019-05-31 11:29:34

根据tf api documenttf.keras.backend.repeat_elements()np.repeat()执行相同的工作。例如,

代码语言:javascript
复制
x = tf.constant([1, 3, 3, 1], dtype=tf.float32)
rep_x = tf.keras.backend.repeat_elements(x, 5, axis=0)
# result: [1. 1. 1. 1. 1. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 1. 1. 1. 1. 1.]
票数 7
EN

Stack Overflow用户

发布于 2018-11-14 18:36:33

对于一维张量,我制作了这个函数

代码语言:javascript
复制
def tf_repeat(y,repeat_num)   
        return tf.reshape(tf.tile(tf.expand_dims(y,axis=-1),[1,repeat_num]),[-1]) 
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/35361467

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档