我正在寻找一种优化的方法来计算二维数组与三维数组的每个切片的元素乘法(使用numpy)。
例如:
w = np.array([[1,5], [4,9], [12,15]]) y = np.ones((3,2,3))
我想得到一个与y
形状相同的3d数组的结果。
不允许使用*运算符进行广播。在我的例子中,第三个维度非常长,并且for循环并不方便。
发布于 2018-10-11 07:10:55
给定的数组
import numpy as np
w = np.array([[1,5], [4,9], [12,15]])
print(w)
[[ 1 5]
[ 4 9]
[12 15]]
和
y = np.ones((3,2,3))
print(y)
[[[ 1. 1. 1.]
[ 1. 1. 1.]]
[[ 1. 1. 1.]
[ 1. 1. 1.]]
[[ 1. 1. 1.]
[ 1. 1. 1.]]]
我们可以直接将数组相乘,
z = ( y.transpose() * w.transpose() ).transpose()
print(z)
[[[ 1. 1. 1.]
[ 5. 5. 5.]]
[[ 4. 4. 4.]
[ 9. 9. 9.]]
[[ 12. 12. 12.]
[ 15. 15. 15.]]]
我们可能会注意到,这会产生与np.einsum('ij,ijk->ijk',w,y)相同的结果,可能只需要更少的工作量和开销。
https://stackoverflow.com/questions/52749624
复制相似问题