本章目录:
在前面的小节中,我们学习了如何获取和修改数组的元素或部分元素,我们可以通过简单索引(例如arr[0]
),切片(例如arr[:5]
)和布尔遮盖(例如arr[arr > 0]
)来实现。本节来介绍另外一种数组索引的方式,被称为高级索引。高级索引语法上和前面我们学习到的简单索引很像,区别只是它不是传递标量参数作为索引值,而是传递数组参数作为索引值。它能让我们很迅速的获取和修改复杂数组或子数组的元素值。
高级索引在概念层面非常简单:传递一个数组作为索引值参数,使得用户能一次性的获取或修改多个数组元素值。例如下面的数组:
import numpy as np
rand = np.random.RandomState(42)
x = rand.randint(100, size=10)
print(x)
[51 92 14 71 60 20 82 86 74 74]
假如我们需要访问其中三个不同的元素。我们可以这样做:
[x[3], x[7], x[2]]
[71, 86, 14]
还有一种方法,我们以一个数组的方式将这些元素的索引传递给数组,也可以获得相同的结果:
ind = [3, 7, 4]
x[ind]
array([71, 86, 60])
当使用高级索引时,结果数组的形状取决于索引数组的形状而不是被索引数组的形状:
ind = np.array([[3, 7],
[4, 5]]) # 索引数组是一个2x2数组,结果也将会是一个2x2数组
x[ind]
array([[71, 86],
[60, 20]])
高级索引也支持多维数组。例如:
X = np.arange(12).reshape((3, 4))
X
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
就像普通索引一样,第一个参数代表行,第二个参数代表列:
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
X[row, col]
array([ 2, 5, 11])
结果中的第一个值是x[0, 2]
,第二个值是x[1, 1]
,第三个值是x[2, 3]
。高级索引的多个维度组合方式也遵守广播的规则,请查阅在数组上计算:广播。因此,如果我们在上面的行索引数组中增加一个维度,结果将变成一个二维数组:
X[row[:, np.newaxis], col]
array([[ 2, 1, 3],
[ 6, 5, 7],
[10, 9, 11]])
这里,每个行索引都会匹配每个列的向量,就像我们在广播的算术运算中看到一样。例如:
row[:, np.newaxis] * col
array([[0, 0, 0],
[2, 1, 3],
[4, 2, 6]])
记住高级索引结果的形状是索引数组广播后的形状而不是被索引数组形状,这点非常重要。
结合我们前面学习过的索引方法,我们可以组合出更多更强大的操作:
print(X)
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
我们可以将高级索引和简单索引进行组合:
实际上这就是个广播,将标量广播成一个向量。
X[2, [2, 0, 1]]
array([10, 8, 9])
我们也可以将高级索引和切片进行组合:
X[1:, [2, 0, 1]]
array([[ 6, 4, 5],
[10, 8, 9]])
还可以将高级索引和遮盖进行组合:
mask = np.array([1, 0, 1, 0], dtype=bool)
X[row[:, np.newaxis], mask]
array([[ 0, 2],
[ 4, 6],
[ 8, 10]])
所有这些索引操作可以提供用户非常灵活的方式来获取和修改数组中的数据。
高级索引的一个通用应用场景就是从一个矩阵的行中选取子数据集。例如,我们有一个
的矩阵,代表着一个
维平面上有
个点,例如下面的二维正态分布的点集合:
mean = [0, 0]
cov = [[1, 2],
[2, 5]]
X = rand.multivariate_normal(mean, cov, 100)
X.shape
(100, 2)
我们可以在散点图上绘制这些点:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set() # 设置图表风格,seaborn
plt.scatter(X[:, 0], X[:, 1]);
下面我们使用高级索引来选择 20 个随机点。方法是先创建一个索引数组,里面的索引值是没有重复的,然后使用这个索引数组来选择点:
indices = np.random.choice(X.shape[0], 20, replace=False)
indices
array([11, 63, 29, 13, 19, 38, 27, 17, 0, 24, 14, 43, 77, 31, 15, 64, 46,
75, 67, 62])
selection = X[indices] # 使用高级索引
selection.shape
(20, 2)
下面我们来看看那些点被选中,让我们上图的基础上将选中的点圈出来:
plt.scatter(X[:, 0], X[:, 1], alpha=0.3)
plt.scatter(selection[:, 0], selection[:, 1],
facecolor='none', s=200);
这种策略经常用来划分数据集,比如用来验证统计模型正确性时需要的训练集和测试集划分,还有就是在回答统计问题时进行取样抽象。
前面我们看到高级索引能够被用来获取一个数组的部分数据,实际上它还能用来修改选中部分的数据。例如,我们手头有一个索引的数组,我们想将这些索引上的数据修改为某个值:
x = np.arange(10)
i = np.array([2, 1, 8, 4])
x[i] = 99
print(x)
[ 0 99 99 3 99 5 6 7 99 9]
我们可以使用任何赋值类型操作,例如:
x[i] -= 10
print(x)
[ 0 89 89 3 89 5 6 7 89 9]
请注意下,如果索引数组中有重复的元素的话,这种修改操作可能会导致一个潜在的意料之外的结果。例如:
x = np.zeros(10)
x[[0, 0]] = [4, 6]
print(x)
[6. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
4 跑到哪里去了呢?这个操作首先赋值x[0] = 4
,然后赋值x[0] = 6
,因此最后x[0]
的值是 6。
上面的例子还算比较清晰,再看下面这个操作:
i = [2, 3, 3, 4, 4, 4]
x[i] += 1
x
array([6., 0., 1., 1., 1., 0., 0., 0., 0., 0.])
我们期望的结果可能是x[3]
的值是 2,而x[4]
的值是 3,因为这两个元素都多次执行了加法操作。但是为何结果不是呢?这是因为x[i] += 1
是操作x[i] = x[i] + 1
的简写,而x[i] + 1
表达式的值已经计算好了,然后才被赋值给x[i]
。因此,上面的操作不会被扩展为重复的运算,而是一次的赋值操作,造成了这种难以理解的结果。
如果我们真的需要这种重复的操作怎么办?对此,NumPy(版本 1.8 以上)提供了at()
ufunc 方法可以满足这个目的,如下:
x = np.zeros(10)
np.add.at(x, i, 1)
print(x)
[0. 0. 1. 2. 3. 0. 0. 0. 0. 0.]
at()
方法不会预先计算表达式的值,而是每次运算时实时得到,方法在一个数组x
中取得特定索引i
,然后将其取得的值与最后一个参数1
进行相应计算,这里是加法add
。还有一个类似的方法是reduceat()
方法,你可以从 NumPy 的文档中阅读它的说明。
你可以使用上面的方法对数据进行高效分组,用于定义自己的直方图。例如,设想我们有 1000 个值,我们想将它们分别放入各个不同的数组分组中。我们可以使用at
函数,例如:
np.random.seed(42)
x = np.random.randn(100) # 获得一个一维100个标准正态分布值
# 得到一个自定义的数据分组,区间-5至5平均取20个点,每个区间为一个数据分组
bins = np.linspace(-5, 5, 20)
counts = np.zeros_like(bins) # counts是x数值落入区间的计数
# 使用searchsorted,得到x每个元素在bins中落入的区间序号
i = np.searchsorted(bins, x)
# 使用at和add,对x元素在每个区间的元素个数进行计算
np.add.at(counts, i, 1)
counts 现在包含着每个数据分组中元素的个数,换句话来说,就是直方图:
Matplotlib 3.1 开始,linestyle 关键字参数已经过时,后续版本会抛弃。下面代码依据最新参数更改为 drawstyle 或 ds。
# 用图表展示结果
plt.plot(bins, counts, ds='steps');
当然,如果每次要画直方图的时候,都要经过这么复杂的计算,很不方便。这也就是为什么 Matplotlib 提供了plt.hist()
方法的原因,可以用一行代码完成上面操作:
plt.hist(x, bins, histtype='step');
这个函数会创建一个和上图基本完全一样的图形。Matplotlib 使用np.histogram
函数来计算数据分组,这个函数进行的计算和我们上面的代码非常接近。我们比较一下这两个方法:
print("NumPy routine:")
%timeit counts, edges = np.histogram(x, bins)
print("Custom routine:")
%timeit np.add.at(counts, np.searchsorted(bins, x), 1)
NumPy routine:
34.8 µs ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Custom routine:
18.2 µs ± 457 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
我们自己写的一行代码比 NumPy 优化的算法要快出许多,这是因为什么?如果你深入到np.histogram
函数的源代码进行阅读(你可以通过在 IPython 中输入np.histogram??
来查阅)的时候,你会发现函数除了搜索和计数之外,还做了其他很多工作;这是因为 NumPy 的函数要更加灵活,而且当数据量变大的时候能够提供更好的性能:
x = np.random.randn(1000000)
print("NumPy routine:")
%timeit counts, edges = np.histogram(x, bins)
print("Custom routine:")
%timeit np.add.at(counts, np.searchsorted(bins, x), 1)
NumPy routine:
75.4 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Custom routine:
130 ms ± 1.77 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
上面的结果说明当涉及到算法的性能时,永远不可能是一个简单的问题。对于大数据集来说一个很高效的算法,并不一定也适用于小数据集,反之亦然(参见大 O 复杂度)。我们这里使用自己的代码实现这个算法,目的是理解上面的基本函数,后续可以使用这些函数构建自己定义的各种功能。在数据科学应用中使用 Python 编写代码的关键在于,你能掌握 NumPy 提供的很方便的函数如np.histogram
,你也能知道什么情况下适合使用它们,当需要更加定制的功能时你还能使用底层的函数自己实现相应的算法。
本节之前,我们主要关注 NumPy 中那些获取和操作数组数据的工具。本小节我们会介绍对 NumPy 数组进行排序的算法。这些算法在基础计算机科学领域是很热门的课题:如果你学习过相关的课程的话,你可能梦(或者根据你的经历,可能是噩梦)到过有关插入排序、选择排序、归并排序、快速排序、冒泡排序和其他很多很多名词。这些都是为了完成一件工作的:对数组进行排序。
例如,一个简单的选择排序会重复寻找列表中最小的值,然后和当前值进行交换,直到列表排序完成。我们可以在 Python 中用简单的几行代码完成这个算法:
import numpy as np
def selection_sort(x):
for i in range(len(x)):
swap = i + np.argmin(x[i:]) # 寻找子数组中的最小值的索引序号
(x[i], x[swap]) = (x[swap], x[i]) # 交换当前值和最小值
return x
x = np.array([2, 1, 4, 3, 5])
selection_sort(x)
array([1, 2, 3, 4, 5])
任何一个 5 年的计算机科学专业都会教你,选择排序很简单,但是对于大的数组来说运行效率就不够了。对于数组具有
个值,它需要
次循环,每次循环中需要
次比较和寻找来交换元素。大 O表示法经常用来对算法性能进行定量分析(参见大 O 复杂度),选择排序平均需要
:如果列表中的元素个数加倍,执行时间增长大约是原来的 4 倍。
甚至选择排序也远比下面这个bogo 排序算法有效地多,这是作者最喜爱的排序算法:
def bogosort(x):
while np.any(x[:-1] > x[1:]):
np.random.shuffle(x)
return x
x = np.array([2, 1, 4, 3, 5])
bogosort(x)
array([1, 2, 3, 4, 5])
这个有趣而粗笨的算法完全依赖于概率:它重复的对数组进行随机的乱序直到结果刚好是正确排序为止。这个算法平均需要
,即N乘以N的阶乘,明显的,在真实情况下,它不应该被用于排序计算。
幸运的是,Python 內建有了排序算法,比我们刚才提到那些简单的算法都要高效。我们从 Python 內建的排序开始介绍,然后再去讨论 NumPy 中为了数组优化的排序函数。
np.sort
和 np.argsort
虽然 Python 有內建的sort
和sorted
函数可以用来对列表进行排序,我们在这里不讨论它们。因为 NumPy 的np.sort
函数有着更加优秀的性能,而且也更满足我们要求。默认情况下np.sort
使用的是
快速排序排序算法,归并排序和堆排序也是可选的。对于大多数的应用场景来说,默认的快速排序都能满足要求。
对数组进行排序,返回排序后的结果,不改变原始数组的数据,你应该使用np.sort
:
x = np.array([2, 1, 4, 3, 5])
np.sort(x)
array([1, 2, 3, 4, 5])
如果你期望直接改变数组的数据进行排序,你可以对数组对象使用它的sort
方法:
x.sort()
print(x)
[1 2 3 4 5]
相关的函数是argsort
,它将返回排好序后元素原始的序号序列:
x = np.array([2, 1, 4, 3, 5])
i = np.argsort(x)
print(i)
[1 0 3 2 4]
结果的第一个元素是数组中最小元素的序号,第二个元素是数组中第二小元素的序号,以此类推。这些序号可以通过高级索引的方式使用,从而获得一个排好序的数组:
更好的问题应该是,假如我们希望获得数组中第二、三小的元素,我们可以这样做:
x[i[1:3]]
x[i]
array([1, 2, 3, 4, 5])
NumPy 的排序算法可以沿着多维数组的某些轴axis
进行,如行或者列。例如:
rand = np.random.RandomState(42)
X = rand.randint(0, 10, (4, 6))
print(X)
[[6 3 7 4 6 9]
[2 6 7 4 3 7]
[7 2 5 4 1 7]
[5 1 4 0 9 5]]
# 沿着每列对数据进行排序
np.sort(X, axis=0)
array([[2, 1, 4, 0, 1, 5],
[5, 2, 5, 4, 3, 7],
[6, 3, 7, 4, 6, 7],
[7, 6, 7, 4, 9, 9]])
# 沿着每行对数据进行排序
np.sort(X, axis=1)
array([[3, 4, 6, 6, 7, 9],
[2, 3, 4, 6, 7, 7],
[1, 2, 4, 5, 7, 7],
[0, 1, 4, 5, 5, 9]])
必须注意的是,这样的排序会独立的对每一行或者每一列进行排序。因此结果中原来行或列之间的联系都会丢失。
有时候我们并不是需要对整个数组排序,而仅仅需要找到数组中的K个最小值。NumPy 提供了np.partition
函数来完成这个任务;结果会分为两部分,最小的K个值位于结果数组的左边,而其余的值位于数组的右边,顺序随机:
x = np.array([7, 2, 3, 1, 6, 5, 4])
np.partition(x, 3)
array([2, 1, 3, 4, 6, 5, 7])
你可以看到结果中最小的三个值在左边,其余 4 个值位于数组的右边,每个分区内部,元素的顺序是任意的。
和排序一样,我们可以按照任意维度对一个多维数组进行分区:
np.partition(X, 2, axis=1)
array([[3, 4, 6, 7, 6, 9],
[2, 3, 4, 7, 6, 7],
[1, 2, 4, 5, 7, 7],
[0, 1, 4, 5, 9, 5]])
结果中每行的前两个元素就是该行最小的两个值,该行其余的值会出现在后面。
最后,就像np.argsort
函数可以返回排好序的元素序号一样,np.argpartition
可以计算分区后元素的序号。后面的例子中我们会看到它的使用。
下面我们使用argsort
沿着多个维度来寻找每个点的最近邻。首先在一个二维平面上创建 10 个随机点数据。按照惯例,这将是一个
的数组:
X = rand.rand(10, 2)
我们先来观察一下这些点的分布情况,散点图很适合这种情形:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set() # 图表风格,seaborn
plt.scatter(X[:, 0], X[:, 1], s=100);
现在让我们来计算每两个点之间的距离。距离平方的定义是两点坐标差的平方和。应用广播([在数组上计算:广播]和聚合([聚合:Min, Max, 以及其他]函数,我们可以使用一行代码就能计算出所有点之间的距离平方:
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)
上面的这行代码包含很多的内容值得探讨,如果对于不是特别熟悉广播机制的读者来说,看起来可能会让人难以理解。当你读到这样的代码的时候,将它们打散成一步步的操作会有帮助:
# 计算每两个点之间的坐标距离
differences = X[:, np.newaxis, :] - X[np.newaxis, :, :]
differences.shape
(10, 10, 2)
# 计算距离的平方
sq_differences = differences ** 2
sq_differences.shape
(10, 10, 2)
# 按照最后一个维度求和
dist_sq = sq_differences.sum(-1)
dist_sq.shape
(10, 10)
你可以检查这个矩阵的对角线元素,对角线元素的值是点与其自身的距离平方,应该全部为 0:
dist_sq.diagonal()
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
确认正确。现在我们已经有了一个距离平方的矩阵,然后就可以使用np.argsort
函数来按照每行来排序。最左边的列就会给出每个点的最近邻:
nearest = np.argsort(dist_sq, axis=1)
print(nearest)
[[0 3 9 7 1 4 2 5 6 8]
[1 4 7 9 3 6 8 5 0 2]
[2 1 4 6 3 0 8 9 7 5]
[3 9 7 0 1 4 5 8 6 2]
[4 1 8 5 6 7 9 3 0 2]
[5 8 6 4 1 7 9 3 2 0]
[6 8 5 4 1 7 9 3 2 0]
[7 9 3 1 4 0 5 8 6 2]
[8 5 6 4 1 7 9 3 2 0]
[9 7 3 0 1 4 5 8 6 2]]
结果中的第一列是 0 到 9 的数字:这是因为距离每个点最近的是自己,正如我们预料的一样。
上面我们进行了完整的排序,事实上我们并不需要这么做。如果我们只是对最近的
个邻居感兴趣的话,我们可以使用分区来完成,只需要在距离平方矩阵中对每行进行
分区,只需要调用np.argpartition
函数即可:
K = 2
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)
为了展示最近邻的网络结构,我们在图中为每个点和它最近的两个点之间连上线:
plt.scatter(X[:, 0], X[:, 1], s=100)
# 为每个点和它最近的两个点之间连上线
K = 2
for i in range(X.shape[0]):
for j in nearest_partition[i, :K+1]:
# 从X[i]连线到X[j]
# 使用一些zip的魔术方法画线
plt.plot(*zip(X[j], X[i]), color='black')
图上的每个点都和与它最近的两个点相连。初看起来,你可能注意到有些点的连线可能超过 2 条,这很奇怪:实际原因是如果 A 是 B 的最近两个近邻之一,并不代表着 B 也必须是 A 的最近两个近邻之一。
虽然使用广播和逐行排序的方式完成任务可能没有使用循环来的直观,但是在 Python 中这是一种非常有效的方式。你可能忍不住使用循环的方式对每个点去计算它相应的最近邻,但是这种方式几乎肯定会比我们前面使用的向量化方案要慢很多。向量化的解法还有一个优点,那就是它不关心数据的尺寸:我们可以使用同样的代码和方法计算 100 个点或 1,000,000 个点以及任意维度数的数据的最近邻。
最后,需要说明的是,当对一个非常大的数据集进行最近邻搜索时,还有一种基于树或相似的算法能够将时间复杂度从
优化到
或更好。其中一个例子是 KD-Tree[1]。
大 O 复杂度是一种衡量随着输入数据的增加,需要执行的操作的数量的量级情况的指标。要正确使用它,需要深入了解计算机科学的理论知识,要和其他相关的概念如小 O 复杂度,大
复杂度,大
复杂度区分开来,更加不容易。虽然精确地描述出这些复杂度是属于算法的范畴,除了学院派计算机科学理论的测验和评分以外,你在其他应用领域很难看到这些严格的定义和划分。在数据科学领域中,我们不会使用这样死板的大 O 复杂度概念,虽然这和算法领域的概念在精确程度上有一定差距。带着对理论学者和学院派的歉意,本书将一直使用对大 O 复杂度的这种非精确概念解释。
大 O 复杂度,简单来说,会告诉你当你的数据增大时,你的算法运行需要的时间。例如你有一个
(英文读作"Order
")的算法,对于N=1000 的数据量,它需要运行 1 秒,那么对于N=5000 的数据量,算法需要执行的时间就为 5 秒。如果你的算法复杂度为
(英文读作"Order N squared"),对于N=1000 的数据量需要运行 1 秒,那么你可以预期当数据量增长为N=5000 时,运行时间为 25 秒。
对于我们的目标来说,N通常代表着数据集的大小(数据点的数量,维度数等)。当我们需要分析的数据样本量达到百万级或十亿级时,
和
之间的差距将会是巨大的。
请记住大 O 复杂度本身并不能告诉你实际上运算消耗的时间,它仅仅能够告诉你当N变化时,运行时间会怎样随之发生变化。通常来说,
复杂度的算法被认为肯定要比
复杂度的算法要好。但对于小的数据集来说,好的大 O 复杂度算法并不一定能带来更快的执行效率。例如,某个特定情况下,
复杂度的算法可能需要 0.01 秒的运行时间而
复杂度的算法可能需要 1 秒。但是如果将N增大 1000 倍,那么
复杂度的算法将会胜出。
我们这里使用的这种非严格定义的大 O 复杂度对于算法的性能也是有指示意义的,在本书的后续部分当我们讨论到算法范畴时都会应用到它。
虽然我们的数据很多情况下都能表示成同种类的数组,但是某些情况下,这是不适用的。本小节展示了如何使用 NumPy 的结构化数组和记录数组,它们能够提供对于复合的,不同种类的数组的有效存储方式。本小节的内容,包括场景和操作,通常都会在 Pandas 的Dataframe
中使用。
import numpy as np
考虑一下,我们有一些关于人的不同种类的数据(例如姓名、年龄和体重),现在我们想要将它们保存到 Python 程序中。当然它们可以被保存到三个独立的数组之中:
name = ['Alice', 'Bob', 'Cathy', 'Doug']
age = [25, 45, 37, 19]
weight = [55.0, 85.5, 68.0, 61.5]
显然这种做法有些原始。没有任何额外的信息让我们知道这三个数组是关联的;如果我们可以使用一个结构保存所有这些数据的话,会更加的自然。NumPy 使用结构化数组来处理这种情况,结构化数组可以用来存储复合的数据类型。
回忆前面我们创建一个简单数组的方法:
x = np.zeros(4, dtype=int)
我们也可以类似的创建一个复合类型的数组,只需要指定相应的 dtype 数据类型即可:
# 使用复合的dtype参数来创建结构化数组
data = np.zeros(4, dtype={'names':('name', 'age', 'weight'),
'formats':('U10', 'i4', 'f8')})
print(data.dtype)
[('name', '<U10'), ('age', '<i4'), ('weight', '<f8')]
这里的U10
代表着“Unicode 编码的字符串,最大长度 10”,i4
代表着“4 字节(32 比特)整数”,f8
代表着“8 字节(64 比特)浮点数”。本节后面我们会介绍其他的类型选项。
现在我们已经创建了一个空的结构化数组,我们可以使用上面的数据列表将数据填充到数组中:
data['name'] = name
data['age'] = age
data['weight'] = weight
print(data)
[('Alice', 25, 55. ) ('Bob', 45, 85.5) ('Cathy', 37, 68. )
('Doug', 19, 61.5)]
正如我们希望那样,数组的数据现在被存储在一整块的内存空间中。
使用结构化数组的方便的地方是你可以使用字段的名称而不是序号来访问元素值了:
# 获得所有的名字
data['name']
array(['Alice', 'Bob', 'Cathy', 'Doug'], dtype='<U10')
# 获得第一行
data[0]
('Alice', 25, 55.)
# 获得最后一行的名字
data[-1]['name']
'Doug'
使用布尔遮盖,我们能写出更加复杂但易懂的过滤条件,比如年龄的过滤:
# 获得所有年龄小于30的人的姓名
data[data['age'] < 30]['name']
array(['Alice', 'Doug'], dtype='<U10')
请注意,如果你想要完成的工作比上面的需求还要复杂的话,你应该考虑使用 Pandas 包,下一章的主要内容。我们将会看到,Pandas 提供了Dataframe
对象,它是一个在 NumPy 数组的基础上构建的结构,提供了很多有用的数据操作功能,包括上面结构化数组的功能。
结构化数组的数据类型可以采用集中方式指定。前面我们介绍了字典的方式:
np.dtype({'names':('name', 'age', 'weight'),
'formats':('U10', 'i4', 'f8')})
dtype([('name', '<U10'), ('age', '<i4'), ('weight', '<f8')])
需要说明的是,数字类型也可以通过 Python 类型或 NumPy 数据类型来指定:
np.dtype({'names':('name', 'age', 'weight'),
'formats':((np.str_, 10), int, np.float32)})
dtype([('name', '<U10'), ('age', '<i4'), ('weight', '<f4')])
一个复合类型也可以使用一个元组的列表来指定:
np.dtype([('name', 'S10'), ('age', 'i4'), ('weight', 'f8')])
dtype([('name', 'S10'), ('age', '<i4'), ('weight', '<f8')])
如果类型的名称并不重要,你可以省略它们,你甚至可以在一个以逗号分隔的字符串中指定所有类型:
np.dtype('S10,i4,f8')
dtype([('f0', 'S10'), ('f1', '<i4'), ('f2', '<f8')])
类型的字符串形式的缩写初看起来很困惑,但实际上它们都是依据简单原则得到的。第一个(可选的)字符是<
或>
,代表这类型是小尾
还是大尾
,用来指定存储的字节序。下一个字符指定数据类型:字符、字节、整数、浮点数或其他(见下表)。最后一个字符代表类型的长度。
字符 | 说明 | 举例 |
---|---|---|
'b' | 字节 | np.dtype('b') |
'i' | 带符号整数 | np.dtype('i4') == np.int32 |
'u' | 无符号整数 | np.dtype('u1') == np.uint8 |
'f' | 浮点数 | np.dtype('f8') == np.int64 |
'c' | 复数 | np.dtype('c16') == np.complex128 |
'S', 'a' | 字符串 | np.dtype('S5') |
'U' | Unicode 字符串 | np.dtype('U') == np.str_ |
'V' | 原始数据 | np.dtype('V') == np.void |
除此之外,还可以定义更加复杂的复合类型。例如,你可以创建一个类型,其中的每一个元素都是一个数组或矩阵。下面,创建一个数据类型内含一个mat
对象,是一个
的浮点数矩阵:
tp = np.dtype([('id', 'i8'), ('mat', 'f8', (3, 3))])
X = np.zeros(1, dtype=tp)
print(X[0])
print(X['mat'][0])
(0, [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
X
数组中的每个元素都有一个id
和一个
的矩阵。为什么需要这样用,为什么不用一个多维数组或者甚至是 Python 的字典呢?原因是 NumPy 的dtype
数据类型直接对应这一个 C 语言的结构体定义,因此存储这个数组的内容内容可以直接被 C 语言的程序访问到。如果你在写访问底层 C 语言或 Fortran 语言的 Python 接口的话,你会发现这种结构化数组很有用。
NumPy 还提供了np.recarray
对象,看起来基本和前面介绍的结构化数组相同,但是有一个额外的特性:字段不是使用字典关键字来访问,而是使用属性进行访问。前面我们使用关键字来访问数组的年龄字段:
data['age']
array([25, 45, 37, 19])
如果我们使用记录数组来展示数据化,我们可以使用对象属性方式访问年龄字段,少打几个字:
data_rec = data.view(np.recarray)
data_rec.age
array([25, 45, 37, 19])
这样做的缺点是,当按照对象属性来访问数组数据时,会有额外的性能损耗。下面的例子可以看到:
%timeit data['age']
%timeit data_rec['age']
%timeit data_rec.age
218 ns ± 23 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
3.76 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
5.19 µs ± 242 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
[1]
KD-Tree: http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html