我有下面的示例代码,它适用于常规的map
def f(x_y):
x, y = x_y
return x.sum() + y.sum()
xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]
list(map(f, zip(xs, ys)))
# returns:
[DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32)]我怎样才能用jax.vmap代替?最天真的是:
jax.vmap(f)(zip(xs, ys))但这给出了:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())发布于 2022-06-05 19:10:30
对于使用jax.vmap,不需要压缩变量。你可以在下面写你想写的东西:
import jax.numpy as jnp
from jax import vmap
def f(x_y):
x, y = x_y
return x.sum() + y.sum()
xs = jnp.zeros((4,3))
ys = jnp.zeros((4,2))
vmap(f)((xs, ys))输出:
DeviceArray([0., 0., 0., 0.], dtype=float32)发布于 2022-06-05 22:16:27
vmap被设计成在默认情况下映射多个变量,因此不需要zip。此外,它只能映射数组轴,而不能映射列表或元组的元素。因此,编写示例的一个更规范的方法是将列表转换为数组,并执行如下操作:
def g(x, y):
return x.sum() + y.sum()
xs_arr = jnp.asarray(xs)
ys_arr = jnp.asarray(ys)
jax.vmap(g)(xs_arr, ys_arr)
# DeviceArray([0., 0., 0., 0.], dtype=float32)https://stackoverflow.com/questions/72509839
复制相似问题