首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >numba的高效平方欧几里德距离编码会比numpy的高效编码慢吗?

numba的高效平方欧几里德距离编码会比numpy的高效编码慢吗?
EN

Stack Overflow用户
提问于 2018-06-04 15:48:52
回答 1查看 711关注 0票数 1

我修改了(Why this numba code is 6x slower than numpy code?)中最有效的代码,使其能够处理x1为(n,m)

代码语言:javascript
复制
@nb.njit(fastmath=True,parallel=True)
def euclidean_distance_square_numba_v5(x1, x2):
    res = np.empty((x1.shape[0], x2.shape[0]), dtype=x2.dtype)
    for a_idx in nb.prange(x1.shape[0]):
        for o_idx in range(x2.shape[0]):
            val = 0.
            for i_idx in range(x2.shape[1]):
                tmp = x1[a_idx, i_idx] - x2[o_idx, i_idx]
                val += tmp * tmp 
            res[a_idx, o_idx] = val 
    return res

然而,它仍然不比效率更高的numpy版本更高效:

代码语言:javascript
复制
def euclidean_distance_square_einsum(x1, x2):
    return np.einsum('ij,ij->i', x1, x1)[:, np.newaxis] + np.einsum('ij,ij->i', x2, x2) - 2*np.dot(x1, x2.T)

输入为

代码语言:javascript
复制
a = np.zeros((1000000,512), dtype=np.float32)
b = np.zeros((100, 512), dtype=np.float32)

我得到的numba代码的时间是2.4723422527313232,numpy代码的时间是0.8260958194732666。

EN

回答 1

Stack Overflow用户

发布于 2018-06-05 06:36:20

是的,这是意料之中的。

你必须知道的第一件事是: dot-product是numpy-version的主力,这里针对的是稍微小一点的数组:

代码语言:javascript
复制
>>> def only_dot(x1, x2):
        return - 2*np.dot(x1, x2.T)

>>> a = np.zeros((1000,512), dtype=np.float32)
>>> b = np.zeros((100, 512), dtype=np.float32)

>>> %timeit(euclidean_distance_square_einsum(a,b))
6.08 ms ± 312 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit(euclidean_only_dot(a,b))
5.25 ms ± 330 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

也就是说,85%的时间都花在了它上面。

当你看你的numba代码时,这看起来有点奇怪/不同寻常/更复杂的矩阵-矩阵-乘法版本-例如,一个人可以看到相同的三个循环。

因此,基本上,您正在尝试击败一个最好的优化算法之一。下面是somebody trying to do it and failing的例子。我的安装使用的是英特尔的MKL版本,它必须比默认的实现更复杂,可以在here上找到。

有时,在享受了所有的乐趣之后,人们不得不承认自己的“重新发明的轮子”不如最先进的轮子……但只有这样,人们才能真正欣赏它的性能。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50675705

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档