首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >为什么jax.numpy.dot()在CPU上的运行速度比numpy.dot()慢?

为什么jax.numpy.dot()在CPU上的运行速度比numpy.dot()慢?
EN

Stack Overflow用户
提问于 2020-08-31 13:54:45
回答 1查看 1.4K关注 0票数 2

我想使用JAX来加速CPU上的numpy代码,稍后在GPU上。下面是在本地计算机上运行的示例代码(只有CPU):

代码语言:javascript
运行
复制
import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)

start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2=np.random.normal((size,size))

start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

我收到警告,时间费用如下:

代码语言:javascript
运行
复制
/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s

我在这里做错什么了吗?JAX也应该加速CPU上的numpy代码吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-08-31 17:24:02

Jax和Numpy之间可能存在性能差异,但在最初的文章中,时间差异主要归结于数组创建中的一个错误。Jax使用的数组形状为3000x3000,而Numpy使用的数组是长度为2的一维数组。numpy.random.normal的第一个参数是loc (即采样所用高斯的平均值)。关键字参数size=应用于指示数组的形状。

代码语言:javascript
运行
复制
numpy.random.normal(loc=0.0, scale=1.0, size=None)

一旦进行了此更改,Jax和Numpy之间的性能就不那么不同了。

代码语言:javascript
运行
复制
import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

一次运行的输出是

代码语言:javascript
运行
复制
Time of jnp: 2.3315 s
Time of np: 2.8811 s

在测量定时性能时,应该收集多次运行,因为函数的性能是时间的扩展,而不是单个值。这可以通过Python标准库timeit.timeit函数或IPython和朱庇特笔记本中的%timeit魔术来完成。

代码语言:javascript
运行
复制
import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

在Numpy中,32位浮点数似乎有一个优化的点操作。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63672151

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档