类似于: def compute(x): y = very_expensive_function(x)
return y 但是,每个数组x[i]具有不同的长度我可以很容易地解决这个问题,方法是用尾随零填充数组,使它们都具有相同的长度N,并且vmap(compute)可以应用于具有形状(batch_size, N)的批处理。但是,这样做会导致还会在每个数组x[i]的尾随零上调用very_expens
我试图得到输出w.r.t的二阶导数,这是用亚麻建立的神经网络的输入。该网络的结构如下:import jaximport flax.linen as nnbatch = jnp.ones((32, 3)) #Dummy input to Initialize the NN
params = model.init(jax