首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >用jax.lax.scan重写循环

用jax.lax.scan重写循环
EN

Stack Overflow用户
提问于 2022-09-01 05:47:14
回答 1查看 433关注 0票数 1

我很难理解JAX文档。有人能告诉我如何用jax.lax.scan重写这样简单的代码吗?

代码语言:javascript
运行
复制
numbers = numpy.array( [ [3.0, 14.0], [15.0, -7.0], [16.0, -11.0] ])
evenNumbers = 0
for row in numbers:
      for n in row:
         if n % 2 == 0:
            evenNumbers += 1
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-09-01 09:23:09

假设一个解决方案应该演示这些概念,而不是优化所示的示例,那么jax.lax.scanned函数必须与预期的签名匹配,并且任何动态条件都必须替换为jax.lax.cond。下面的代码是我能想到的最接近原版的代码,但请注意,我绝不是一个jaxpert。

代码语言:javascript
运行
复制
import jax
import jax.numpy as jnp

def f(carry, row):

    even = 0
    for n in row:
        even += jax.lax.cond(n % 2 == 0, lambda: 1, lambda: 0)

    return carry + even, even

numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
jax.lax.scan(f, 0, numbers)

输出

代码语言:javascript
运行
复制
(DeviceArray(2, dtype=int32, weak_type=True),
 DeviceArray([1, 0, 1], dtype=int32, weak_type=True))
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73564732

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档