## How to index based on another array in Keras

Content sourced from Stack Overflow; translated and used under the CC BY-SA 3.0 license.

Given two Keras tensors with the following shapes:

```python
>>> K.int_shape(a)
(None, 5, 2)
>>> K.int_shape(b)
(None, 5)
```

I want to compute a tensor `c` with shape:

```python
>>> K.int_shape(c)
(None, 2)
```

For example, with these NumPy arrays:

```python
import numpy as np

a = np.array([[[2, 7],
               [6, 5],
               [9, 9],
               [4, 2],
               [5, 9]],

              [[8, 1],
               [8, 8],
               [3, 9],
               [9, 2],
               [9, 1]],

              [[3, 9],
               [6, 4],
               [5, 7],
               [5, 2],
               [5, 6]],

              [[7, 5],
               [9, 9],
               [9, 5],
               [9, 8],
               [5, 7]],

              [[6, 3],
               [1, 7],
               [3, 6],
               [8, 2],
               [3, 2]],

              [[6, 4],
               [5, 9],
               [8, 6],
               [5, 2],
               [5, 2]],

              [[2, 6],
               [6, 5],
               [3, 1],
               [6, 2],
               [6, 4]]])
```

```python
b = np.array([[0.27, 0.25, 0.23, 0.06, 0.19],
              [0.30, 0.13, 0.17, 0.20, 0.20],
              [0.08, 0.04, 0.40, 0.36, 0.12],
              [0.30, 0.33, 0.11, 0.07, 0.19],
              [0.15, 0.21, 0.30, 0.12, 0.22],
              [0.30, 0.13, 0.23, 0.10, 0.23],
              [0.26, 0.35, 0.25, 0.07, 0.07]])
```

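For each row `i`, `c` should hold the pair `a[i, j, :]`, where `j` is the position of the largest value in `b[i, :]`. In NumPy, with a loop:

```python
c = np.zeros((7, 2))
for i in range(7):
    ind = np.argmax(b[i, :])   # position of the largest value in row i of b
    c[i, :] = a[i, ind, :]     # copy the matching pair from a
print(c)
# [[2. 7.]
#  [8. 1.]
#  [5. 7.]
#  [9. 9.]
#  [3. 6.]
#  [6. 4.]
#  [6. 5.]]
```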
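The loop can also be written without Python-level iteration. The following is a minimal NumPy sketch that relies on NumPy's integer-array (advanced) indexing. Pairing `np.arange(n)` with the per-row argmax picks one pair per batch row, for any batch size. The helper name `select_by_argmax` is hypothetical. The sketch is meant only as a reference for checking tensor versions, not as a Keras solution.

```python
import numpy as np

a = np.array([[[2, 7], [6, 5], [9, 9], [4, 2], [5, 9]],
              [[8, 1], [8, 8], [3, 9], [9, 2], [9, 1]],
              [[3, 9], [6, 4], [5, 7], [5, 2], [5, 6]],
              [[7, 5], [9, 9], [9, 5], [9, 8], [5, 7]],
              [[6, 3], [1, 7], [3, 6], [8, 2], [3, 2]],
              [[6, 4], [5, 9], [8, 6], [5, 2], [5, 2]],
              [[2, 6], [6, 5], [3, 1], [6, 2], [6, 4]]])
b = np.array([[0.27, 0.25, 0.23, 0.06, 0.19],
              [0.30, 0.13, 0.17, 0.20, 0.20],
              [0.08, 0.04, 0.40, 0.36, 0.12],
              [0.30, 0.33, 0.11, 0.07, 0.19],
              [0.15, 0.21, 0.30, 0.12, 0.22],
              [0.30, 0.13, 0.23, 0.10, 0.23],
              [0.26, 0.35, 0.25, 0.07, 0.07]])


def select_by_argmax(a, b):
    """Hypothetical helper: for each row i, return a[i, argmax(b[i]), :]."""
    rows = np.arange(a.shape[0])   # 0, 1, ..., n-1 for whatever the batch size is
    cols = np.argmax(b, axis=1)    # position of the largest value in each row of b
    return a[rows, cols]           # advanced indexing: one (row, col) pair per batch row


c = select_by_argmax(a, b)
print(c)
# [[2 7]
#  [8 1]
#  [5 7]
#  [9 9]
#  [3 6]
#  [6 4]
#  [6 5]]
```

Because `rows` is built from `a.shape[0]`, the same function works on a slice such as `a[:3], b[:3]`. This mirrors the `None` batch dimension in the Keras shapes.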

### 2 Answers

The first answer builds `(row, column)` index pairs and gathers them with `tf.gather_nd`. It uses the TensorFlow 1.x `tf.Session` API:

```python
import keras.backend as K
import tensorflow as tf

# `a` and `b` are the NumPy arrays defined in the question
A = tf.constant(a)
B = tf.constant(b)

# Argmax indices over axis 1 of B; these index axis 1 of A:
col_ind = K.argmax(B, axis=1)

# Row range; these index axis 0 of A:
row_ind = K.arange(col_ind.shape[0], dtype='int64')

# Stack the indices into (row, column) pairs:
ind = K.stack((row_ind, col_ind), axis=-1)

# Gather the results:
c = tf.gather_nd(A, ind)  # no equivalent I know of in K, and no idea about Theano...

with tf.Session() as sess:
    print(c.eval())
# [[2 7]
#  [8 1]
#  [5 7]
#  [9 9]
#  [3 6]
#  [6 4]
#  [6 5]]
```
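The following NumPy sketch shows what `ind` holds. It emulates the same steps for illustration and is not TensorFlow code. Each row of `ind` is a `(batch, position)` pair. When the index's last dimension is 2, `tf.gather_nd` returns `A[batch, position, :]` for each pair. The helper `gather_nd_pairs` is hypothetical.

One caveat, offered as an assumption about use inside a model: when the batch dimension is `None`, `col_ind.shape[0]` is not known while the graph is built. A dynamic shape such as `tf.shape(col_ind)[0]` is then typically needed instead, cast to `int64` so that it can be stacked with `col_ind`.

```python
import numpy as np


def gather_nd_pairs(params, indices):
    """Hypothetical NumPy stand-in for tf.gather_nd when indices.shape[-1] == 2:
    returns params[i0, i1, ...] for every (i0, i1) row of `indices`."""
    return params[indices[:, 0], indices[:, 1]]


A = np.arange(3 * 4 * 2).reshape(3, 4, 2)   # small stand-in for `a`: (batch=3, 4, 2)
B = np.array([[0.1, 0.7, 0.1, 0.1],          # small stand-in for `b`: (batch=3, 4)
              [0.6, 0.2, 0.1, 0.1],
              [0.0, 0.1, 0.2, 0.7]])

col_ind = np.argmax(B, axis=1)                         # like K.argmax(B, axis=1)
row_ind = np.arange(col_ind.shape[0], dtype=np.int64)  # like K.arange(..., dtype='int64')
ind = np.stack((row_ind, col_ind.astype(np.int64)), axis=-1)  # (batch, 2) index pairs

c = gather_nd_pairs(A, ind)
print(ind)
# [[0 1]
#  [1 0]
#  [2 3]]
print(c)
# [[ 2  3]
#  [ 8  9]
#  [22 23]]
```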

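The second answer maps over the batch index with `K.map_fn`:

```python
import keras.backend as K

# `a` and `b` are the NumPy arrays defined in the question
A = K.constant(a)
B = K.constant(b)
mxidx = K.argmax(B, axis=1)
# For each batch index i, pick A[i, argmax(B[i]), :].
# dtype is given because the float32 output differs from the int64 inputs;
# the TensorFlow backend needs it, other backends may ignore it.
c = K.map_fn(lambda i: A[i, mxidx[i], :], K.arange(A.shape[0], dtype='int64'),
             dtype='float32')
print(K.eval(c))
# [[2. 7.]
#  [8. 1.]
#  [5. 7.]
#  [9. 9.]
#  [3. 6.]
#  [6. 4.]
#  [6. 5.]]
```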
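A different technique, not taken from either answer, avoids both `gather_nd` and `map_fn`. It turns the argmax into a one-hot mask over the positions and sums the masked pairs. Only element-wise ops and a reduction are involved, so the `None` batch dimension should not be an obstacle.

In Keras this might read `K.sum(A * K.expand_dims(K.one_hot(K.argmax(B, axis=1), 5)), axis=1)`. Treat that expression as an untested assumption. The sketch below checks the idea in NumPy, using `np.eye(n)[idx]` as the one-hot step. The helper name `select_by_onehot` is hypothetical.

```python
import numpy as np


def select_by_onehot(a, b):
    """Hypothetical helper: mask-and-sum version of c[i] = a[i, argmax(b[i]), :]."""
    n_pos = b.shape[1]                              # number of candidate positions (5 here)
    onehot = np.eye(n_pos)[np.argmax(b, axis=1)]    # (batch, n_pos): 1.0 at the argmax, else 0.0
    return (a * onehot[:, :, None]).sum(axis=1)     # zero out other positions, collapse axis 1


# First three rows of the question's data
a = np.array([[[2, 7], [6, 5], [9, 9], [4, 2], [5, 9]],
              [[8, 1], [8, 8], [3, 9], [9, 2], [9, 1]],
              [[3, 9], [6, 4], [5, 7], [5, 2], [5, 6]]])
b = np.array([[0.27, 0.25, 0.23, 0.06, 0.19],
              [0.30, 0.13, 0.17, 0.20, 0.20],
              [0.08, 0.04, 0.40, 0.36, 0.12]])

c = select_by_onehot(a, b)
print(c)
# [[2. 7.]
#  [8. 1.]
#  [5. 7.]]
```

The trade-off is extra arithmetic (every position is multiplied, then summed) in exchange for avoiding index tensors entirely.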

Timing the evaluation in IPython:

```
%timeit K.eval(c)
The slowest run took 9.76 times longer than the fastest. This could mean
that an intermediate result is being cached.
100000 loops, best of 3: 12.2 µs per loop
```