import jax.numpy as jnp向量和数组是jnp.array(dtype=jnp.int32)
我有一个带有形状x, d, y的数组
[[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]]]和向量x = [2 0 3], y = [ 2 0 1], d = [0 0 1]
我想通过索引来获得这样的东西,但是我尝试了,但我不知道如何使用jax.numpy。
[[[0 0 2],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 3 0],
[0 0 0]]]编辑:我想指定,我想把数字从x和它的索引放在数组中,但只有当x>0时。我试过用布尔面具。就像这样
mask = x > 0
array = array.at[mask, d, y].set(array[mask, d, y] + x)发布于 2022-04-15 12:17:16
你有一个三维数组,所以你可以用三个索引数组来索引它.由于您希望d和y与第二维度和第三维度相关联,因此需要为第一个维度创建另一个索引数组:
import jax.numpy as jnp
arr = jnp.zeros((3, 3, 3), dtype='int32')
x = jnp.array([2, 0, 3])
y = jnp.array([2, 0, 1])
d = jnp.array([0, 0, 1])
i = jnp.arange(len(x))
mask = x > 0
out = arr.at[i[mask], d[mask], y[mask]].set(x[mask])
print(out)
# [[[0 0 2]
# [0 0 0]
# [0 0 0]]
# [[0 0 0]
# [0 0 0]
# [0 0 0]]
# [[0 0 0]
# [0 3 0]
# [0 0 0]]]在这种情况下,无论您是否使用掩码(即arr.at[i, d, y].set(x)将给出相同的结果),结果都是相同的,但是因为您的问题显式地指定您只想使用值x > 0,所以我包含了它。
https://stackoverflow.com/questions/71880876
复制相似问题