首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >如何在Keras中基于另一个数组进行索引

如何在Keras中基于另一个数组进行索引
EN

Stack Overflow用户
提问于 2018-06-02 01:17:31
回答 1查看 923关注 0票数 1

我有两个张量ab,它们的形状如下

>>K.int_shape(a)
(None, 5 , 2)
>>K.int_shape(b)
(None, 5)

我想得到的是一个张量c

>>K.int_shape(c)
(None, 2)

这样沿轴0选择b中最大元素的索引,并使用该索引沿轴1索引a

示例-假设我有

a = np.array([[[2, 7],
    [6, 5],
    [9, 9],
    [4, 2],
    [5, 9]],

   [[8, 1],
    [8, 8],
    [3, 9],
    [9, 2],
    [9, 1]],

   [[3, 9],
    [6, 4],
    [5, 7],
    [5, 2],
    [5, 6]],

   [[7, 5],
    [9, 9],
    [9, 5],
    [9, 8],
    [5, 7]],

   [[6, 3],
    [1, 7],
    [3, 6],
    [8, 2],
    [3, 2]],

   [[6, 4],
    [5, 9],
    [8, 6],
    [5, 2],
    [5, 2]],

   [[2, 6],
    [6, 5],
    [3, 1],
    [6, 2],
    [6, 4]]])

而且我有

b = np.array([[ 0.27,  0.25,  0.23,  0.06,  0.19],
[ 0.3 ,  0.13,  0.17,  0.2 ,  0.2 ],
[ 0.08,  0.04,  0.40,  0.36,  0.12],
[ 0.3 ,  0.33,  0.11,  0.07,  0.19],
[ 0.15,  0.21,  0.30,  0.12,  0.22],
[ 0.3 ,  0.13,  0.23,  0.1 ,  0.23],
[ 0.26,  0.35 ,  0.25 ,  0.07,  0.07]])

我对c的期望是什么

c = np.zeros((7,2))
for i in range(7):
    ind = np.argmax(b[i, :])
    c[i, :] = a[i, ind, :]
c
array([[ 2.,  7.],
   [ 8.,  1.],
   [ 5.,  7.],
   [ 9.,  9.],
   [ 3.,  6.],
   [ 6.,  4.],
   [ 6.,  5.]])
EN

回答 1

Stack Overflow用户

发布于 2018-06-02 01:54:06

使用Tensorflow作为后端(我对Theano了解不多),使用tf.gather_nd()

import keras.backend as K
import tensorflow as tf

# `a` and `b` the numpy arrays defined in the question
A = tf.constant(a)
B = tf.constant(b)

# Obtaining your max indices over axis 1, which will be used as indices for axis 1 of A:
col_ind = K.argmax(B, axis=1)

# Creating row range, which will be used as indices for axis 0 of A:
row_ind = K.arange(col_ind.shape[0], dtype='int64')

# Stacking the indices together:
ind = K.stack((row_ind, col_ind), axis=-1)

# Gathering the results:
c = tf.gather_nd(A, ind) # no equivalent I know in K, and no idea about theano...

with tf.Session() as sess:
    print(c.eval())
    # [[2 7]
    #  [8 1]
    #  [5 7]
    #  [9 9]
    #  [3 6]
    #  [6 4]
    #  [6 5]]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50648271

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档