我有一个简单的损失函数,如下所示
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))我想对参数r进行优化,并使用一些静态参数x和y来计算残差。所有相关参数都是DeviceArrays。
为了实现这一点,我尝试了以下操作
@partial(jax.jit, static_argnums=(1, 2))
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))但是我得到了这个错误
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.我从#6233了解到这是设计出来的,但我想知道这里的变通方法是什么,因为这似乎是一个非常常见的用例,其中你有一些固定的(输入,输出)训练数据对和一些自由变量。
谢谢你的建议!
编辑:这是我尝试使用jax.jit时得到的错误
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.`发布于 2021-10-16 14:23:37
这听起来像是你把静态参数看作是“不会在计算之间改变的值”。在JAX的JIT中,可以更好地将静态参数视为"hashable编译时常量“。在您的例子中,您没有hashable编译时常量;您有数组,所以您可以不使用静态参数进行JIT编译:
@jit
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))如果您确实想让JAX机器知道您的数组是常量的,可以通过闭包或分部传递它们来实现;例如:
from functools import partial
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))但是,对于您正在进行的计算类型,其中常量是由JAX数组函数操作的数组,这两种方法产生的XLA代码基本相同,因此您也可以使用更简单的方法。
https://stackoverflow.com/questions/69593968
复制相似问题