首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >JAX jax.lax.switch的意外行为

JAX jax.lax.switch的意外行为
EN

Stack Overflow用户
提问于 2022-11-22 11:51:04
回答 2查看 27关注 0票数 1

我在jax.lax.switch中看到了一种意想不到的行为。

代码语言:javascript
运行
复制
def fun_a():
    print('a')
    
def fun_b():
    print('b')
    
def fun_c():
    print('c')

functions_list=[fun_a,fun_b,fun_c]

然后打电话

代码语言:javascript
运行
复制
jax.lax.switch(0,functions_list)

返回

代码语言:javascript
运行
复制
a
b
c

我只希望看到“一个”印刷。

EN

回答 2

Stack Overflow用户

发布于 2022-11-22 14:18:50

这是因为打印是一个副作用,您可能会有意外的错误切换它。在Jax常见问题中有更多信息,其中有一个打印失败jax.grad的例子。在这种情况下,函数应该返回要打印的值。但是,字符串不是有效的jax类型,开关只支持数值。例如,您可以尝试这样的方法:

代码语言:javascript
运行
复制
def fun_a():
    return ord('a')  # convert 'a' to int (= 97)


def fun_b():
    return ord('b')


def fun_c():
    return ord('c')


functions_list = [fun_a, fun_b, fun_c]
out = jax.lax.switch(0, functions_list)

print(chr(out))  # 'a'
票数 0
EN

Stack Overflow用户

发布于 2022-11-27 13:08:37

考虑到JAX编译器的工作方式,这是预期的行为:它期望纯函数,而且您的函数并不是纯的,因为打印是一种副作用。

如果希望在转换后的JAX函数中按预期进行打印,则可以使用jax.debug.print。例如:

代码语言:javascript
运行
复制
import jax

def fun_a():
    jax.debug.print('a')
    
def fun_b():
    jax.debug.print('b')
    
def fun_c():
    jax.debug.print('c')

functions_list=[fun_a,fun_b,fun_c]
jax.lax.switch(0,functions_list)

输出:

代码语言:javascript
运行
复制
a
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74532184

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档