我在jax.lax.switch中看到了一种意想不到的行为。
def fun_a():
print('a')
def fun_b():
print('b')
def fun_c():
print('c')
functions_list=[fun_a,fun_b,fun_c]
然后打电话
jax.lax.switch(0,functions_list)
返回
a
b
c
我只希望看到“一个”印刷。
发布于 2022-11-22 14:18:50
这是因为打印是一个副作用,您可能会有意外的错误切换它。在Jax常见问题中有更多信息,其中有一个打印失败jax.grad
的例子。在这种情况下,函数应该返回要打印的值。但是,字符串不是有效的jax类型,开关只支持数值。例如,您可以尝试这样的方法:
def fun_a():
return ord('a') # convert 'a' to int (= 97)
def fun_b():
return ord('b')
def fun_c():
return ord('c')
functions_list = [fun_a, fun_b, fun_c]
out = jax.lax.switch(0, functions_list)
print(chr(out)) # 'a'
发布于 2022-11-27 13:08:37
考虑到JAX编译器的工作方式,这是预期的行为:它期望纯函数,而且您的函数并不是纯的,因为打印是一种副作用。
如果希望在转换后的JAX函数中按预期进行打印,则可以使用jax.debug.print
。例如:
import jax
def fun_a():
jax.debug.print('a')
def fun_b():
jax.debug.print('b')
def fun_c():
jax.debug.print('c')
functions_list=[fun_a,fun_b,fun_c]
jax.lax.switch(0,functions_list)
输出:
a
https://stackoverflow.com/questions/74532184
复制相似问题