首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在形成时具有不可理解的静态参数的问题

在形成时具有不可理解的静态参数的问题
EN

Stack Overflow用户
提问于 2022-10-14 06:32:34
回答 1查看 48关注 0票数 1

我有一个向量-jacobian积,我想计算。

函数func包含四个参数,最后两个参数是静态的:

代码语言:javascript
运行
复制
def func(variational_params, e, A, B):
    ...
    return model_params, dlogp, ...

函数jit非常精细

代码语言:javascript
运行
复制
func_jitted = jit(func, static_argnums=(2, 3))

素数是variational_params,余切是dlogp (函数的第二个输出)。

天真地计算向量-jacobian积(通过形成jacobian)工作得很好:

代码语言:javascript
运行
复制
jacobian_func = jacobian(func_jitted, argnums=0, has_aux=True)
jacobian_jitted = jit(jacobian_func, static_argnums=(2, 3))
jac, func_output = jacobian_jitted(variational_params, e, A, B)
naive_vjp = func_output.T @ jac 

当尝试以有效的方式形成vjp时,

代码语言:javascript
运行
复制
f_eval, vjp_function, aux_output = vjp(func_jitted, variational_params, e, A, B, has_aux=True)

我得到以下错误:

代码语言:javascript
运行
复制
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.ad.JVPTracer'> for function func is non-hashable.

我有点困惑,因为函数func跳得很好.没有将static_argnums添加到vjp函数的选项,因此我不太确定这意味着什么。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-10-14 23:28:23

对于像jit这样的高级转换API,JAX通常提供一种类似于static_argnumsargnums的机制,允许对静态变量和动态变量进行规范。

对于较低级别的转换例程(如jvpvjp ),不提供这些机制,但仍然可以通过传递部分评估的函数来完成相同的任务。例如:

代码语言:javascript
运行
复制
from functools import partial

f_eval, vjp_function, aux_output = vjp(partial(func_jitted, A=A, B=B), variational_params, e, has_aux=True)

这就是转换参数(如argnumsstatic_argnums )是如何在幕后实现的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74065210

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档