我很难理解JAX文档。有人能告诉我如何用jax.lax.scan
重写这样简单的代码吗?
numbers = numpy.array( [ [3.0, 14.0], [15.0, -7.0], [16.0, -11.0] ])
evenNumbers = 0
for row in numbers:
for n in row:
if n % 2 == 0:
evenNumbers += 1
发布于 2022-09-01 09:23:09
假设一个解决方案应该演示这些概念,而不是优化所示的示例,那么jax.lax.scan
ned函数必须与预期的签名匹配,并且任何动态条件都必须替换为jax.lax.cond
。下面的代码是我能想到的最接近原版的代码,但请注意,我绝不是一个jaxpert。
import jax
import jax.numpy as jnp
def f(carry, row):
even = 0
for n in row:
even += jax.lax.cond(n % 2 == 0, lambda: 1, lambda: 0)
return carry + even, even
numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
jax.lax.scan(f, 0, numbers)
输出
(DeviceArray(2, dtype=int32, weak_type=True),
DeviceArray([1, 0, 1], dtype=int32, weak_type=True))
https://stackoverflow.com/questions/73564732
复制相似问题