做像这样的事情
import numpy as np
a = np.random.rand(10**4, 10**4)
b = np.dot(a, a)
使用多核,并且运行良好。
不过,a
中的元素是64位浮点数(在32位平台中是32位吗?),我想将8位整数数组相乘。不过,请尝试以下几点:
a = np.random.randint(2, size=(n, n)).astype(np.int8)
结果导致点积不使用多核,因此在我的PC上运行速度慢了大约1000倍。
array: np.random.randint(2, size=shape).astype(dtype)
dtype shape %time (average)
float32 (2000, 2000) 62.5 ms
float32 (3000, 3000) 219 ms
float32 (4000, 4000) 328 ms
float32 (10000, 10000) 4.09 s
int8 (2000, 2000) 13 seconds
int8 (3000, 3000) 3min 26s
int8 (4000, 4000) 12min 20s
int8 (10000, 10000) It didn't finish in 6 hours
float16 (2000, 2000) 2min 25s
float16 (3000, 3000) Not tested
float16 (4000, 4000) Not tested
float16 (10000, 10000) Not tested
我知道NumPy使用BLAS,它不支持整数,但如果我使用SciPy BLAS包装器,即。
import scipy.linalg.blas as blas
a = np.random.randint(2, size=(n, n)).astype(np.int8)
b = blas.sgemm(alpha=1.0, a=a, b=a)
计算是多线程的。现在,对于浮点32,blas.sgemm
的运行时间与np.dot
完全相同,但对于非浮点数,它会将所有内容转换为float32
并输出浮点数,这是np.dot
所不做的。(此外,b
现在是F_CONTIGUOUS
顺序,这是一个较小的问题)。
因此,如果我想做整数矩阵乘法,我必须执行以下操作之一:
np.dot
,很高兴我保留了8位的sgemm
,并使用了4倍的内存。np.float16
,只使用了2倍的内存,但需要注意的是,np.dot
在float16阵列上比在float32阵列上慢得多,比int8慢得多。我可以遵循选项4吗?这样的库存在吗?
免责声明:我实际上正在运行NumPy + MKL,但我已经在vanilly NumPy上尝试了类似的测试,得到了类似的结果。
https://stackoverflow.com/questions/35101312
复制相似问题