首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在压缩参数上使用Jax?

如何在压缩参数上使用Jax?
EN

Stack Overflow用户
提问于 2022-06-05 17:55:57
回答 2查看 455关注 0票数 1

我有下面的示例代码,它适用于常规的map

代码语言:javascript
复制
def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]

list(map(f, zip(xs, ys)))

# returns:
[DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32)]

我怎样才能用jax.vmap代替?最天真的是:

代码语言:javascript
复制
jax.vmap(f)(zip(xs, ys))

但这给出了:

代码语言:javascript
复制
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-06-05 19:10:30

对于使用jax.vmap,不需要压缩变量。你可以在下面写你想写的东西:

代码语言:javascript
复制
import jax.numpy as jnp
from jax import vmap

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = jnp.zeros((4,3))
ys = jnp.zeros((4,2))
vmap(f)((xs, ys))

输出:

代码语言:javascript
复制
DeviceArray([0., 0., 0., 0.], dtype=float32)
票数 2
EN

Stack Overflow用户

发布于 2022-06-05 22:16:27

vmap被设计成在默认情况下映射多个变量,因此不需要zip。此外,它只能映射数组轴,而不能映射列表或元组的元素。因此,编写示例的一个更规范的方法是将列表转换为数组,并执行如下操作:

代码语言:javascript
复制
def g(x, y):
  return x.sum() + y.sum()

xs_arr = jnp.asarray(xs)
ys_arr = jnp.asarray(ys)

jax.vmap(g)(xs_arr, ys_arr)
# DeviceArray([0., 0., 0., 0.], dtype=float32)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72509839

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档