首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Jax中的一种最小二乘损失函数

Jax中的一种最小二乘损失函数
EN

Stack Overflow用户
提问于 2021-10-16 08:32:39
回答 1查看 102关注 0票数 2

我有一个简单的损失函数,如下所示

代码语言:javascript
运行
复制
        def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))

我想对参数r进行优化,并使用一些静态参数xy来计算残差。所有相关参数都是DeviceArrays

为了实现这一点,我尝试了以下操作

代码语言:javascript
运行
复制
        @partial(jax.jit, static_argnums=(1, 2))
        def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))

但是我得到了这个错误

代码语言:javascript
运行
复制
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.

我从#6233了解到这是设计出来的,但我想知道这里的变通方法是什么,因为这似乎是一个非常常见的用例,其中你有一些固定的(输入,输出)训练数据对和一些自由变量。

谢谢你的建议!

编辑:这是我尝试使用jax.jit时得到的错误

代码语言:javascript
运行
复制
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.`
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-10-16 14:23:37

这听起来像是你把静态参数看作是“不会在计算之间改变的值”。在JAX的JIT中,可以更好地将静态参数视为"hashable编译时常量“。在您的例子中,您没有hashable编译时常量;您有数组,所以您可以不使用静态参数进行JIT编译:

代码语言:javascript
运行
复制
@jit
def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))

如果您确实想让JAX机器知道您的数组是常量的,可以通过闭包或分部传递它们来实现;例如:

代码语言:javascript
运行
复制
from functools import partial

def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))

但是,对于您正在进行的计算类型,其中常量是由JAX数组函数操作的数组,这两种方法产生的XLA代码基本相同,因此您也可以使用更简单的方法。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69593968

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档