首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在JAX中加速嵌套for-循环

在JAX中加速嵌套for-循环
EN

Stack Overflow用户
提问于 2022-01-03 15:56:00
回答 1查看 687关注 0票数 1

我想使用JAX的jit方法加速下面示例中嵌套的for-循环。但是,编译需要很长时间,编译后的运行时甚至比不使用jit的版本还要慢。

我是否正确地使用jit?JAX中还有其他我应该在这里使用的特性吗?

代码语言:javascript
运行
复制
import time
import jax.numpy as jnp
from jax import jit
from jax import random

key = random.PRNGKey(seed=0)

width = 32
height = 64

w = random.normal(key=key, shape=(height, width))

def forward():
    a = jnp.zeros(shape=(height, width + 1))

    for i in range(height):
        a = a.at[i, 0].add(1.0)

    for j in range(width):
        for i in range(1, height-1):
            z = a[i-1, j] * w[i-1, j] \
                + a[i, j] * w[i, j] \
                + a[i+1, j] * w[i+1, j]
            a = a.at[i, j+1].set(z)

t0 = time.time()
forward()
print(time.time()-t0)

feedforward_jit = jit(forward)

t0 = time.time()
feedforward_jit()
print(time.time()-t0)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-04 23:02:31

对你的问题的简短回答是:为了优化你的循环,你应该尽你所能从你的程序中删除这些循环。

JAX (如NumPy)是一种建立在数组操作基础上的语言,每当您在数组的维度上使用循环时,JAX (如NumPy)将比您可能希望的要慢。在JIT编译过程中尤其如此: JAX将在将操作发送到XLA之前将循环扁平化,而XLA编译时间尺度大致相当于发送给它的操作数量的平方,因此嵌套循环是快速创建非常的慢速编译的一种很好的方法。

那么,如何避免这些循环呢?首先,让我们重新定义您的函数,以便它接受输入并返回输出(考虑到JAX的死代码消除和异步分派,我认为您的初始基准没有告诉您您认为它们是什么;有关一些技巧,请参阅基准JAX代码 ):

代码语言:javascript
运行
复制
def forward(w):
  height, width = w.shape
  a = jnp.zeros(shape=(height, width + 1))

  for i in range(height):
    a = a.at[i, 0].add(1.0)

  for j in range(width):
    for i in range(1, height-1):
      z = (a[i-1, j] * w[i-1, j]
           + a[i, j] * w[i, j]
           + a[i+1, j] * w[i+1, j])
      a = a.at[i, j+1].set(z)
  return a

第一个循环是可以用一行向量化更新:a = a.at[:, 0].set(1)替换的情况。查看下一个块的内部循环,代码似乎会沿着每一列进行卷积。让我们使用jnp.convolve更有效地完成这个任务。使用这两个优化结果如下:

代码语言:javascript
运行
复制
def forward2(w):
  height, width = w.shape
  a = jnp.zeros((height, width + 1)).at[:, 0].set(1)
  kernel = jnp.ones(3)
  for j in range(width):
    conv = jnp.convolve(a[:, j] * w[:, j], kernel, mode='valid')
    a = a.at[1:-1, j + 1].set(conv)
  return a

接下来,让我们看一下宽度上的循环。这里更复杂,因为每次迭代都取决于最后一个迭代的结果。我们可以用lax.scan来表达这一点,这是JAX内置的控制流算子之一。你可以这样做:

代码语言:javascript
运行
复制
def forward3(w):
  def body(carry, w):
    conv = jnp.convolve(carry * w, kernel, mode='valid')
    out = jnp.zeros_like(w).at[1:-1].set(conv)
    return out, out
  init = jnp.ones(w.shape[0])
  kernel = jnp.ones(3)
  return jnp.vstack([
      init, lax.scan(body, jnp.ones(w.shape[0]), w.T)[1]]).T

我们可以迅速证实,这三种方法提供了相同的产出:

代码语言:javascript
运行
复制
width = 32
height = 64
w = random.normal(key=key, shape=(height, width))

result1 = forward(w)
result2 = forward2(w)
result3 = forward3(w)

assert jnp.allclose(result1, result2)
assert jnp.allclose(result2, result3)

使用IPython的%time魔术,我们可以大致了解每种方法的计算时间,这里是CPU后端(注意使用block_until_ready()来解释JAX的异步调度):

代码语言:javascript
运行
复制
%time forward(w).block_until_ready()
# CPU times: user 23 s, sys: 248 ms, total: 23.3 s
# Wall time: 22.9 s

%time forward2(w).block_until_ready()
# CPU times: user 117 ms, sys: 866 µs, total: 118 ms
# Wall time: 118 ms

%time forward3(w).block_until_ready()
# CPU times: user 93.2 ms, sys: 2.96 ms, total: 96.1 ms
# Wall time: 94 ms

您可以在JAX.html#control-flow上阅读更多关于JAX和控制流的信息。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70568316

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档