前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Python入门教程(五):Numpy计算之广播

Python入门教程(五):Numpy计算之广播

作者头像
数据万花筒
发布2020-12-30 14:24:14
6380
发布2020-12-30 14:24:14
举报
文章被收录于专栏:数据万花筒数据万花筒

广播(broadcasting)是通用函数另一个非常有用的功能,它能够操纵不同大小和形状的数组,这就是我们所说的广播。

01

广播简介

对于同样大小的数组,二元运算符是对相应元素逐个计算,如例1所示。

广播允许这些二元运算符可以用于不同大小的数组。

代码语言:javascript
复制
例1:
import numpy as np
a = np.array([0, 1, 2])
b = np.array([5, 5, 5])
a + b
# array([5, 6, 7])

如例2所示,可以简单地将一个标量(可以认为是一个零维数组)和一个数组相加。这个操作,我们可以认为是将数值5扩展或者重复至数组[5,5,5],然后执行加法。Numpy广播功能的好处是,这种对值的重复实际上没有发生,但是这是一种很好理解的广播模型。

代码语言:javascript
复制
例2:
# 
a + 5
# array([5, 6, 7])

我们也可以把这个原理拓展到更高维度的数组,下面例子展示了一个一维数组和一个二维数组相加的结果。在例3中一个二维数组被拓展了或者被广播了。他沿着第二个维度扩展,拓展到匹配M数组的形状。

代码语言:javascript
复制
例3:
M = np.ones((3, 3))
M

# array([[ 1.,  1.,  1.],
#        [ 1.,  1.,  1.],
#       [ 1.,  1.,  1.]])

M + a
# array([[ 1.,  2.,  3.],
#       [ 1.,  2.,  3.],
#       [ 1.,  2.,  3.]])

以上数组理解起来还比较容易,更复杂的情况涉及到对两个数组的同时广播,如例4所示。

代码语言:javascript
复制
例4:
a = np.arange(3)
b = np.arange(3)[:, np.newaxis]

print(a)
print(b)

# [0 1 2]
# [[0]
#  [1]
#  [2]]

a + b
# array([[0, 1, 2],
#       [1, 2, 3],
#       [2, 3, 4]])

该例子中,我们对a,b都进行了拓展匹配到一个公共的形状,下图中浅色的盒子表示广播的值。

代码语言:javascript
复制

02

广播的规则

Numpy的广播遵循一组严格的规则,设定这组规则是为了决定两个数组之间的操作,其规则如下:

规则1:如果两个数组的维度不相同,那么小维度数组的形状将会在最左边补1.

规则2:如果两个数组的形状在任何一个维度上都不匹配,那么数组的形状会沿着维度为1的维度拓展以匹配另外一个数组形状。

规则3:如果两个数组的形状在任何一个维度上都不匹配并且没有任何一个维度等于1,那么会引发异常。

广播示例1:

将一个二维数组和一个一维数组相加。

代码语言:javascript
复制
M = np.ones((2, 3))
a = np.arange(3)
# 查看两数组的形状
# M.shape = (2, 3)
# a.shape = (3,)


# 根据规则1,数组a维度更小,所以在其左边补1
# M.shape -> (2, 3)
# a.shape -> (1, 3)

# 根据规则2,第一个维度不匹配,因此拓展这个维度以匹配数组。
# M.shape -> (2, 3)
# a.shape -> (2, 3)

# 两个数组维度匹配了,两个数组相加
M + a
# array([[ 1.,  2.,  3.],
#        [ 1.,  2.,  3.]])

‍广播示例2:‍

下面这个例子是两个维度都需要广播。

代码语言:javascript
复制
a = np.arange(3).reshape((3, 1))
b = np.arange(3)
# 查看两个数组的维度


# - ``a.shape = (3, 1)``
# - ``b.shape = (3,)``

# 根据规则1,用1将b的形状补全

# - ``a.shape -> (3, 1)``
# - ``b.shape -> (1, 3)``

# 根据规则2,更新数组的维度来相互匹配

# - ``a.shape -> (3, 3)``
# - ``b.shape -> (3, 3)``

# 因为结果匹配,所以两个形状是兼容的,可以看到如下效果:
a + b
# array([[0, 1, 2],
#        [1, 2, 3],
#        [2, 3, 4]])

广播示例3:

下面这个例子是两个数组不兼容的示例。

和第一个示例相比,这里的M是转置的。

代码语言:javascript
复制
# 首先,我们先来看一下数组的形状
M = np.ones((3, 2))
a = np.arange(3)
# M.shape = (3, 2)
# a.shape = (3,)

# 根据规则1,我们需要用1将b的形状补齐。
# M.shape -> (3, 2)
# a.shape -> (1, 3)

# 根据规则2,a数组的第一个维度进行拓展以匹配到M的维度。
# M.shape -> (3, 2)
# a.shape -> (3, 3)

# 根据规则3进行判断,最终形状还是不匹配,因此两个数组是不兼容的,当我们执行运算时,会得到如下的结果:
M + a
# ---------------------------------------------------------------------------
# ValueError                                Traceback (most recent call last)
# <ipython-input-13-9e16e9f98da6> in <module>()
# ----> 1 M + a

# ValueError: operands could not be broadcast together with shapes (3,2) (3,)

这时候,你可能会像通过在a数组的右边补上1,而不是左边补上1,让a和M的维度变得兼容。但是这不被广播的规则所允许。这种灵活性在某些场景中可能会有用,但它可能会导致结果模糊。如果你希望实现右边补全,可以通过变形数组来实现。

代码语言:javascript
复制
a[:, np.newaxis].shape
# (3, 1)
M + a[:, np.newaxis]
# array([[ 1.,  1.],
#        [ 2.,  2.],
#        [ 3.,  3.]])

另外需要注意的是,这里仅用到了+运算符,而这些广播规则对于任意二进制通用函数都是使用的。例如,logaddexp(a,b)函数,比起简单的方法,该函数计算log(exp(a)+exp(b))更加准确。

代码语言:javascript
复制
np.logaddexp(M, a[:, np.newaxis])
# array([[ 1.31326169,  1.31326169],
#        [ 1.69314718,  1.69314718],
#        [ 2.31326169,  2.31326169]])

03

广播的实践

广播在实际操作中用的很多,下面我们通过几个简单的例子进行说明。

1.数组归一化

假设你有一个10个观察值的数组,每个观察值包括3个数值,按照惯例,我们将用一个10*3的数组存放该数据。我们可以计算每个特征值的均值,计算方法是利用mean函数沿着第一个维度聚合。

代码语言:javascript
复制
X = np.random.random((10, 3))
Xmean = X.mean(0)
Xmean
# array([ 0.53514715,  0.66567217,  0.44385899])

现在从X数组的元素中减去这个均值,实现归一化(该操作是一个广播操作)。

为了进一步核对我们的处理是否正确,可以检查归一化的数组的均值是否接近0。

代码语言:javascript
复制
X_centered = X - Xmean
X_centered.mean(0)
# array([  2.22044605e-17,  -7.77156117e-17,  -1.66533454e-17])

2.画一个二维函数

广播的另一个非常有用的地方在于,它能基于二维函数显示图像,我们定义一个函数z=f(x,y),可以用广播沿着数值区间计算该函数。

代码语言:javascript
复制
# x and y have 50 steps from 0 to 5
x = np.linspace(0, 5, 50)
y = np.linspace(0, 5, 50)[:, np.newaxis]

z = np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(z, origin='lower', extent=[0, 5, 0, 5],
           cmap='viridis')
plt.colorbar();
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-11-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 数据万花筒 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 01
  • 02
  • 03
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档