前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >离散与提炼——一些关于向量召回算法优化方法的思考

离散与提炼——一些关于向量召回算法优化方法的思考

作者头像
Zilliz RDS
发布2021-01-18 11:18:50
1.3K0
发布2021-01-18 11:18:50
举报

✏️ 作者介绍:

周语馨,高级云智能工程师

最近做的很多向量召回的相关工作,主要集中在优化 Faiss 里面常用的几个算法,包括 IVFFlat 和 IVFPQ,并且针对这两个算法都做出了专门的优化。

前一阵子灵光乍现,想到了一种与具体算法无关的(或者更严格地说,与具体算法相关性较小的)优化方法,可以优化诸如 Flat、IVFFlat 或者 HNSW 等算法。我称之为“方法”而不是“算法”,是因为它遵从原有算法的逻辑,只是在计算过程中大幅降低了内存带宽,从而提升性能。因此经过优化的算法在召回率等方面完全不变。我把这种方法称为“离散与提炼(Discretize and Refine)”,以类比于 Google 的“Map and Reduce”,强调其方法论的地位。

开门见山,先来说说“离散与提炼”最核心的思想:

  • 离散:使用 int8 和 bfp16 来表达原始向量,计算距离过程中使用这些压缩后的向量,从而降低内存带宽;
  • 提炼:使用压缩后的向量计算得到的距离存在误差,但是足以剔除距离太远的点,之后在很小的一个候选集中使用原始向量计算精确距离。

接下来开始让我帮你一步步理解。

1. 背景

向量召回其实就是经典的 KNN(k-NearestNeighbor)问题。KNN 问题的表述为:假设现在有 N 个向量 yi,当给定一个查询向量 x 时,找出距离 x 最近的 K 个 yi。为了方便理解,我以最简单的暴力搜索算法(Flat)为例。那么当给定一个查询向量 x 时,需要把 x 到每一个 yi 的距离都计算出来。

在 Faiss 以及其他的 ANN(近似近邻)搜索库中,向量都是使用 32 位浮点数(下文中简称为“fp32”)表达的,因此每个维度占用 4 个字节。比如一个 128 维度的向量就占用 512 字节。把向量维度记作 d。那么一次查询需要访问的内存容量有 d*4*N,而且是顺序访问。不管是通过实验还是 Faiss wiki 都可以得知,Flat 类(纯 Flat 或者 IVFFlat 等)算法的性能瓶颈就在于内存带宽,大多数时候 CPU 的计算单元都在等待 yi 被送到 cache 中。而且通过统计发现,当压力较大时,内存读的时间占到了整个算法的 90% 以上。

因此,如果能够使用 int8 来表达向量的每一维,那么一次查询的内存访问量会降低到 d*1*N,即本来的四分之一,理论上性能会提升到接近原来的 4 倍。这就是 “Discretize and Refine” 的动机。

2. 最简单的离散化

先来考虑最容易处理的情况,即所有 yi 的每一维都落在区间 [-128.0, 127.0) 中。

此时,把 yi 的每一维都四舍五入到最接近的整数上,即 -128 到 127 这 256 个整数,那么就可以使用 int8 存储。把 yi 经过四舍五入(下称为“离散化”)后使用 int8 表达的向量记作 zi,并且计算 yi 到 zi 的距离,记作 ei。下图是一个简单的例子来理解 zi 和 ei 的几何含义。

很直观,yi 就是实数(暂且把 fp32 看作实数)空间中的点,而 zi 就是其最接近的格点(坐标均为整数的点),而 ei 就是两者的距离。

那么经过这样的预处理后,我们有了 N 个 yi,N 个 zi 和 N 个 ei,分别使用 fp32、int8 和 fp32 表达,占用 d*4*N 字节、d*1*N 字节和 4*N 字节。 当给定一个 x 时,先计算 x 到每一个 zi 的距离,记作 bi。这里需要强调两点:

  1. x 是用户传入的、未经离散化的、使用 fp32 表达的原始向量;
  2. zi 的每一维 int8 被依次读入寄存器后,再通过 CPU 指令在寄存器中转成 fp32,与 x 的每一维 fp32 计算差的平方,进行累加,最后开方。

因此,从内存到 CPU 的总线上传输的都是 int8,而计算的结果是 fp32,即 x 到 zi 的精确距离。把 bi 减去 ei,得到 Li。至此,搜索的第一步可以用数学语言表达:

Li=|x−zi|−ei

把 x 到yi 的精确距离记作 di,那么 Li 是 di 的下界。为什么呢?看下图:

图中的绿线即为 bi = |x − zi|,是 x 到 zi 的精确距离。红线为 ei = |yi − zi| ,是 yi 到 zi 的精确距离。根据三角不等式,|x − yi| ≥ |x − zi| − |yi − zi| ,因此有 di ≥ bi − ei = Li 。至此,我们使用离散化后的数据得到了 N 个 Li。

3. 最简单的提炼

如何根据 Li 来得到最小的 K 个 di 呢?这就是整个方法的第二步——提炼(Refine)。

先来看一个生活中的例子。有 100 个运动员赛跑,选出前三名。终点处的裁判手里有个秒表,每一个运动员到达终点线,裁判都会掐表来记录时间。由于裁判是人,存在几毫秒的反应时间,因此第 i 个运动员过线的客观时间 Ti,与秒表记录下的测量时间 ti 之间存在误差。但是,ti 与 Ti 之间存在非常大的正相关性,如果 ti 远比 tj 小,那么几乎可以断定 Ti 比 Tj 小。只有在 ti 和 tj 都比较小且很接近时(比如第三名、第四名几乎同时过线),才需要调取录像来获取 Ti 和 Tj 再进行比较。这个例子很直观地表达了提炼的思路——通过估算距离就能排除掉绝大部分不可能入选 topK 的点,而在剩余的少数可能入选 topK 的点中使用精确距离进行最后的筛选。

这里,我们选用 Li 作为估算距离进行提炼。为什么不选 bi 呢?后续算法会展现 Li 的妙用。

既然是最简单情况,这里的提炼算法也是力求最容易理解的,只为讲清数学原理,在性能上并不是最优的,后续会给出更优化的实现方法。

首先,将 N 个 i 与 Li 打包成(i,Li)这样的元组,并且按照 Li 从小到大排序,把排好序的元组数组记作 S。比如 S = [ (14,21.5),(31,25.2),(16,33,7), ...],含义即为 到 x 的距离至少为 21.5,到 x 的距离至少为25.2,到 x 的距离至少为 33.7……

然后,初始化一个最大堆 topK。topK 的最大容量为 K,往里面丢入若干个(label, distance),只会保留最多 K 个 distance 最小的(label, distance)。并且把 topK 的门槛值记为 T(当 topK 中的项数小于 K 时,T 为+∞,否则 T 为 topK 中最大的 distance)。T 的数学含义为:如果试图丢入(label, distance),而 distance>=T,那么(label, distance)不可能成为 topK 其中一项。

代码语言:javascript
复制
for (label, lower_bound) in S:    
    if lower_bound >=topK.T:        
        break    
    distance = |x -y[label]|    
    topK += (label,distance)

最后,对 S 和 topK 执行如下循环:

代码依次遍历 S 中的每一对(label, lower_bound),如果 lower_bound 不小于 topK 的门槛值,那么算法终止,此时的 topK 即为所求。否则,计算 x 与 yi 的精确距离 distance,并且把(label, distance)丢入 topK 中。

代码非常简单。但是,为什么代码第 2 行的判断可以确保当前的 topK 就是最后的结果呢?

我们知道,S 是根据(i, Li)中的 Li 从小到大排序的,如果其中第 i 项满足Si. lower_bound ≥ topK. T ,那么第 i 项后续的任一项都大于 T,即:

Sj.lower_bound≥Si.lower_bound≥topK.T,∀j≥i

又因为 L 作为 d 的下界,有:

dSj.label ≥Sj.lower_bound

联立两式,有:

dSj.label ≥topK.T,∀j≥i

所以,一旦代码第 2 行的判断成立,那么 S 中后续的所有 yi 都不可能入选 topK。

计算 N 个 Li 时,顺序访问了 N 个 zi 和 ei,共 d*1*N+4*N 字节。而提炼过程中,只需要访问 Li 最小的 M 个 yi,共 d*4*M 字节。一次查询的内存访问量,从 d*4*N 字节变成了 d*N+4*N+d*4*M 字节。经验上,M 大约在 K 到 4*K 之间(因不同数据集而异)。以 SIFT1M 数据集(d=128)、K=100 为例,内存访问量从 128*4*1M=512MB 降到 128*1M+4*1M+128*4*4*100≈132MB。事实上,计算 Li 的过程中,因为需要在寄存器中把 int8 转回 fp32,所以计算指令反而变多了。但是由于 Flat 类算法瓶颈几乎完全在内存访问上,因此以增加 CPU 指令为代价来换内存带宽的降低,是非常值得的。至此,离散与提炼的加速原理已经明了了。

4. 基于 heap 的提炼

提炼步骤中有一个明显可以优化的地方,即对 N 个 Li 排序。上文提到,S 中只有前 M 项会被访问,而 M 通常远小于 N。算法只要求前 M 项按照 Li 排序即可,而之后的 N-M 项既然都不会被访问,那么花费大部分 CPU 时间排序后 N-M 项就很不划算。

改进方法是使用堆(优先队列同理)。如果把 N 个(i, Li)放入最小堆中,每次从中取出一个当前 L 最小的(i, Li),那么就可以避免完全排序。用伪代码表述就是:

代码语言:javascript
复制
min_heap = build_min_heap ((1, L[1]), (2, L[2]), ... , (N, L[N]))
while True:    
    (label, lower_bound) = min_heap.pop()    
    if lower_bound >= topK.T:        
        break    
    distance = |x - y[label]|    
    topK += (label, distance)

建堆过程复杂度是 O(N),进行 M 次取出操作复杂度是 O(M*logN),加起来就是 O(N+MlogN),比原本的完全排序的 O(NlogN) 的复杂度要低很多。

5. 基于 nth_element 的提炼

仔细观察,会发现基于最小堆的提炼仍然不是最优的。在提炼过程中,后 N-M 完全不需要排序,然而堆算法会建堆,尽管不是完全排序,但是会调整至小根二叉树的结构,可以看作一种“浅排序“,也就是仍然浪费了不必要的 CPU 时间。

既然经验上 M 大约为 K 到 4K,那么是否可以使用 nth_element 算法把 N 个项中最小的 4K 个项取到数组最开头处,然后只对这 4K 个项做完全排序呢?如果很不巧,遍历完这 4K 个仍然无法结束算法,那么就继续从剩下的 N-4K 个项中抽取出 4K 个……如此循环。算法用伪代码表示如下:

6. 使用线性变化的离散化

上文中,我们假定 yi 的每一维都落在 int8 的表达访问之内,因此离散化操作只是将浮点数近似到最接近的整数。但是,如果 yi 的各个维度的取值区间很大呢,比如[-10000, 10000)?或者不是关于原点对称的,比如[0, 10000)?这个时候,我们需要利用欧几里得距离的两个特性:

  1. 坐标系的平移不会改变两点之间的距离;
  2. 所有的维度同时做相同比例的线性缩放,则两点之间的距离也被按该比例缩放。

这两点都很直观,当然,如果需要数学证明也非常容易。

因此,对于一个超出 int8 表达范围的空间,先对每一个维度分别使用一个恰当的偏置,把取值区间平移到关于原点对称。之后选取一个恰当的比例,把每个维度做一个缩放,使得每一维都落入 int8 的表达范围。该操作用数学语言表达就是:

y′ = ky + b

实际操作中,对于每一个加入的 yi,都先通过线性变换得到 yi',然后对 yi' 做离散化得到 zi 和 ei。

搜索时,对于用户给定的 x,也要先执行 x′ = kx + b ,将 x 映射到 int8 空间中。之后计算 Li 的过程中使用 x'、zi 和 ei 作为参数。需要注意的时候,最后 Li 需要除以 k。搜索的第一步用数学语言表达为:

由于后续的提炼过程中,计算的精确距离都是基于原始空间的,所以需要将 Li 除以 k,映射回原始空间中去。

7. 使用残差的离散化

真实的数据集往往有一定的聚类特征,很可能表现为下图所示的样子:

这样的聚类特征会导致上文中的离散化方法产生较大的信息损失。比如,上图中的所有绿色点可能都被离散化到了(100, 80)这个格点上,或者附近少数几个格点上。那么搜索时,Li 就过于“粗糙”,点与点之间的区分度不大,以至于不能高效地提炼。

在 ANN 算法中,IVF 类算法特别适用于这种具有聚类特征的数据集。IVF 算法在构建索引时,将原始数据聚类成 nlist 个类(每个类的聚类中心记作 Ci),每个点属于其中一个类。当给定一个待搜索的 x 时,找到距离 x 最近的 nprobe 个 Ci,在这 nprobe 个聚类中的点中执行暴力搜索。IVF 算法利用了聚类特征大幅减少了候选集的大小(原本的 nprobe/nlist),从而牺牲一定精度来换取性能的大幅提升。

基于 IVF 算法,离散化操作可以进一步优化。对于每一个 yi,假设其所在聚类的中心点为 Cy 。使用 ri = yi − Cy ,即 yi 的残差作为“基于线性变换的离散化”的输入,得到 zi 与 ei。该过程的数学本质是,以每个聚类中心为原点建立一个坐标系,在该“局部坐标系”中对属于该聚类的点做离散化。如此即可解决信息损失的问题。

当给定 x 时,按照 IVF 算法找出最近的 nprobe 个聚类。对于每一个聚类,计算 x 的残差 q = x − Cx ,然后在该聚类中计算 Li。该操作本质上就是把 x 分别映射到每一个“局部坐标系”中计算距离下界。

8. 使用 bfp16 应对出界点

上文中所有的讨论,都是基于一个假设:向量的取值范围都是有界的。在有界的情况下,可以通过线性变换的方法将向量映射到 int8 的表达范围中。但是,如果:

  1. 向量的取值范围是无界的;
  2. 绝大多数向量都落在某个有限范围内,可是极少数向量远远地出界。

第一种情况显然无法利用 int8 来离散化。

第二种情况中,如果选用绝大多数向量所在的那个有限范围,就不得不剔除那些出界的向量。而如果为了这极少数的向量而扩大取值范围,就会使得离散化的信息损失变大,从而降低提炼效率。

这两种情况,都可以利用 bfp16 来离散化。fp32 使用最高位表达符号,之后的 8 位表达指数,最低的 23 位表达小数部分。而 bfp16 的符号位、指数位与 fp32 相同,唯一的区别是只使用 7 位来表达小数。

bfp16 拥有与 fp32 一样的非常宽阔的表达范围,只是精度会更低。把 fp32变成 bfp16 的过程,其实就是把一个 23 位小数保留到一个 7 位小数(当然,这里的小数是指二进制中的概念)。使用 bfp16之后,由于带宽可以降低到原本的一半,因此性能仍然有明显提升(接近原本的两倍)。

在上述的第一种情况中,可以全部使用 bfp16 做离散化,虽然性能不及使用 int8,但是性能提升到接近两倍仍然非常有诱惑力。

而在上述的第二种情况中,可以混合使用 int8 和 bfp16,即绝大多数能够落入有限范围的向量,使用 int8 做离散化,而少部分“出界”的向量,使用bfp16 做离散化。这样,即能够享受到 int8 的性能,又能保证"出界点"也被正确处理。

9. 实测性能

测试平台为 Intel(R) Xeon(R) Platinum 8268 CPU @ 2.90GHz,单个 socket 上 6 个内存通道全部插满了内存条。选用的算法是 IVF1024,Flat (nprobe=20),测试数据集分别为 SIFT1M、GIST1M。测试结果如下:

横坐标为执行搜索的线程数。由于 8268 单 socket 有 24 个物理核心,48 个超线程,因此线程数从 1 逐渐增加到 48。图中共用 6 条线,深蓝色线、灰色现和浅蓝色线分别是原始的 IVF1024, Flat 算法在使用 SSE4、AVX2 和 AVX-512 时的总吞吐量(QPS,即 query per second),而橙色线、黄色线和绿色线分别是优化后 IVF1024, Flat算法在使用 SSE4、AVX2 和 AVX-512 时的总吞吐量。观察这六条线的走势,可以得出如下结论:

  1. 对于原始算法,随着线程数的增加,QPS 很快就触顶,不再增加;
  2. 对于原始算法,更先进的计算指令集几乎发挥不了优势;
  3. 对于优化后的算法,随着线程数增加,总吞吐量逐步上升,最多能达到 3~3.5 倍的性能;
  4. 对于优化后的算法,更先进的计算指令集可以充分发挥优势。

很显然,优化后的算法缓解了内存带宽瓶颈问题,使得多核平台与先进的向量化指令集可以充分发挥性能优势。事实上,对于 SIFT1M 数据集,原始算法需要的内存带宽高达 91GB/s,而优化后的算法仅需要 20GB/s,这使得使用 DCPMM 等廉价存储器来降低内存采购成本成为可能。

欢迎加入 Milvus 社区

github.com/milvus-io/milvus | 源码

milvus.io | 官网

milvusio.slack.com | Slack 社区

zhihu.com/org/zilliz-11| 知乎

zilliz.blog.csdn.net | CSDN 博客

space.bilibili.com/478166626 | Bilibili

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-01-15,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 ZILLIZ 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 背景
  • 2. 最简单的离散化
  • 3. 最简单的提炼
  • 4. 基于 heap 的提炼
  • 5. 基于 nth_element 的提炼
  • 6. 使用线性变化的离散化
  • 7. 使用残差的离散化
  • 8. 使用 bfp16 应对出界点
  • 9. 实测性能
  • 欢迎加入 Milvus 社区
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档