broadcast是numpy中array的一个重要操作。
首先,broadcast只适用于加减。
然后,broadcast执行的时候,如果两个array的shape不一样,会先给“短”的那一个,增加高维度“扩展”(broadcasting),比如,一个2维的array,可以是一个3维size为1的3维array。
类似于: shape(1,3,2) = shape(3,2)
最后,比较两个 array(扩展后的),按照 dimension 从低到高,比较每一个维度的 size 是否满足下面两个条件之一:
所以,举例,下列 array 是否可以进行 broadcast:
broadcast 之后的运算是怎样呢?举例说明:
a = [ [0,1,2,3], [4,5,6,7] ]
b = [1,2,3,4]
a + b = [ [1,3,5,7], [5,7,9,11] ]
或可自己运行下面代码观察:
import numpy as np
a = np.arange(12)
b = a.reshape(3,2,2)
c = np.arange(4)
d = c.reshape(2, 2)
e = np.arange(2)
print(d+b)
print(e+b)
Output:
----------------------------------------
[ 0 1 2 3 4 5 6 7 8 9 10 11]
[[[ 0 1]
[ 2 3]]
[[ 4 5]
[ 6 7]]
[[ 8 9]
[10 11]]]
[0 1 2 3]
[[0 1]
[2 3]]
[0 1]
[[[ 0 2]
[ 4 6]]
[[ 4 6]
[ 8 10]]
[[ 8 10]
[12 14]]]
[[[ 0 2]
[ 2 4]]
[[ 4 6]
[ 6 8]]
[[ 8 10]
[10 12]]]
-----------------------------------
还有下面一种特殊情况,即扩展低维度为 1 的情况下:
import numpy as np
a = np.arange(3)
b = np.arange(5)
a = a[:, np.newaxis]
print(a)
print(b)
print(a+b)
Output:
--------------
[[0]
[1]
[2]]
[0 1 2 3 4]
[[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]]
--------------