首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >为什么numpy的einsum比numpy的内置函数快?

为什么numpy的einsum比numpy的内置函数快?
EN

Stack Overflow用户
提问于 2013-08-22 02:31:42
回答 3查看 14.5K关注 0票数 76

让我们从三个dtype=np.double数组开始。使用用icc编译并链接到英特尔mkl的numpy 1.7.1在英特尔CPU上执行计时。使用不带mklgcc编译的具有numpy 1.6.1的AMD cpu也被用来验证时序。请注意,计时与系统大小几乎成线性关系,并不是由于numpy函数if语句产生的小开销造成的,这些差异将以微秒而不是毫秒为单位显示:

arr_1D=np.arange(500,dtype=np.double)
large_arr_1D=np.arange(100000,dtype=np.double)
arr_2D=np.arange(500**2,dtype=np.double).reshape(500,500)
arr_3D=np.arange(500**3,dtype=np.double).reshape(500,500,500)

首先让我们看一下np.sum函数:

np.all(np.sum(arr_3D)==np.einsum('ijk->',arr_3D))
True

%timeit np.sum(arr_3D)
10 loops, best of 3: 142 ms per loop

%timeit np.einsum('ijk->', arr_3D)
10 loops, best of 3: 70.2 ms per loop

权力:

np.allclose(arr_3D*arr_3D*arr_3D,np.einsum('ijk,ijk,ijk->ijk',arr_3D,arr_3D,arr_3D))
True

%timeit arr_3D*arr_3D*arr_3D
1 loops, best of 3: 1.32 s per loop

%timeit np.einsum('ijk,ijk,ijk->ijk', arr_3D, arr_3D, arr_3D)
1 loops, best of 3: 694 ms per loop

外部产品:

np.all(np.outer(arr_1D,arr_1D)==np.einsum('i,k->ik',arr_1D,arr_1D))
True

%timeit np.outer(arr_1D, arr_1D)
1000 loops, best of 3: 411 us per loop

%timeit np.einsum('i,k->ik', arr_1D, arr_1D)
1000 loops, best of 3: 245 us per loop

以上所有这些都是np.einsum的两倍快。这些应该是苹果与苹果的比较,因为所有的东西都是特定于dtype=np.double的。我希望在这样的操作中速度会更快:

np.allclose(np.sum(arr_2D*arr_3D),np.einsum('ij,oij->',arr_2D,arr_3D))
True

%timeit np.sum(arr_2D*arr_3D)
1 loops, best of 3: 813 ms per loop

%timeit np.einsum('ij,oij->', arr_2D, arr_3D)
10 loops, best of 3: 85.1 ms per loop

对于np.innernp.outernp.kronnp.sum,Einsum的速度似乎至少是axes选择的两倍。主要的例外是np.dot,因为它从BLAS库中调用DGEMM。那么,为什么np.einsum比其他同等的numpy函数要快呢?

DGEMM的完整性案例:

np.allclose(np.dot(arr_2D,arr_2D),np.einsum('ij,jk',arr_2D,arr_2D))
True

%timeit np.einsum('ij,jk',arr_2D,arr_2D)
10 loops, best of 3: 56.1 ms per loop

%timeit np.dot(arr_2D,arr_2D)
100 loops, best of 3: 5.17 ms per loop

主要的理论来自@sebergs的评论,即np.einsum可以利用SSE2,但numpy的ufuncs直到numpy 1.8才能使用(参见change log)。我相信这是正确的答案,但还无法确认。一些有限的证据可以通过改变输入数组的数据类型和观察速度差异以及不是每个人都观察到相同的计时趋势的事实来找到。

EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/18365073

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档