我有一个向量-jacobian积,我想计算。
函数func
包含四个参数,最后两个参数是静态的:
def func(variational_params, e, A, B):
...
return model_params, dlogp, ...
函数jit非常精细
func_jitted = jit(func, static_argnums=(2, 3))
素数是variational_params
,余切是dlogp
(函数的第二个输出)。
天真地计算向量-jacobian积(通过形成jacobian)工作得很好:
jacobian_func = jacobian(func_jitted, argnums=0, has_aux=True)
jacobian_jitted = jit(jacobian_func, static_argnums=(2, 3))
jac, func_output = jacobian_jitted(variational_params, e, A, B)
naive_vjp = func_output.T @ jac
当尝试以有效的方式形成vjp
时,
f_eval, vjp_function, aux_output = vjp(func_jitted, variational_params, e, A, B, has_aux=True)
我得到以下错误:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.ad.JVPTracer'> for function func is non-hashable.
我有点困惑,因为函数func
跳得很好.没有将static_argnums
添加到vjp
函数的选项,因此我不太确定这意味着什么。
发布于 2022-10-14 23:28:23
对于像jit
这样的高级转换API,JAX通常提供一种类似于static_argnums
或argnums
的机制,允许对静态变量和动态变量进行规范。
对于较低级别的转换例程(如jvp
和vjp
),不提供这些机制,但仍然可以通过传递部分评估的函数来完成相同的任务。例如:
from functools import partial
f_eval, vjp_function, aux_output = vjp(partial(func_jitted, A=A, B=B), variational_params, e, has_aux=True)
这就是转换参数(如argnums
和static_argnums
)是如何在幕后实现的。
https://stackoverflow.com/questions/74065210
复制相似问题