通过梳理大模型推理过程,定位KV cache 发生的阶段,阐述:
大模型推理的解码阶段,生成单个token时,Q值需要使用全部序列token的KV值计算注意力,但除当前新token的KV值外,其他token的KV值已经在上一轮中计算过,为了避免冗余计算,使用KV cache,以内存占用,换计算效率的方式,提高整体服务的吞吐量。

假设当前输入序列中已经存在n个token,将要预测第n+1个token,参考[2] 计算流程如下:
继续预测第 n+2个token时,需要重复以上步骤。此迭代过程中,存在严重的重复计算问题。即生成第 t+2 个token时,模型需要重新处理 T1 到Tn 的所有历史token,而它们的K/V向量在之前的步骤中已经计算过。这导致大量冗余计算,尤其是生成长文本时(如 t=1000),计算复杂度呈平方级增长。

最直接的优化方式:将输入token的KV值缓存下来,等下次计算时直接使用 — KV cache。
问题:缓存和不缓存的结果是否相同?
答案是相同的,利用分块矩阵可以得到一致的结果。 将矩阵A按行拆分为A1和A2两部分,然后分别和矩阵B相乘,最终结果按行拼接,该结果与不分拆结果一致。
在self-attention和MLP中,都是类似的矩阵乘法操作,因此历史计算过的部分进行缓存,避免整体输入导致的重复计算。

假设用户输入,其序列长度为n,经编码后为,通过以下公式获得KV cache。
产生的三个矩阵 ,其中的K和V向量作为KV cache被缓存,用作decoding阶段的计算。
生成的,作为MLP的输入进行前向计算。
此时输入,其序列长度为1,经编码后为,加载Prefill 生成的 KV 缓存。进行以下计算:
其中 , 分别与先前缓存的KV值进行拼接,形成新的KV cache。
生成矩阵 作为前向MLP模块的输入,最终输出单词的概率,根据不同的策略,输出最终的单词。
假设输入序列的长度为 s ,输出序列的长度为 t ,以float16来保存KV cache,KV cache的峰值显存占用大小为 :
第一个2表示K和V 两个cache缓存,第二个2表示float16占2个bytes。
举例:
以GPT3为例,对比KV cache与模型参数占用显存的大小。GPT3模型占用显存大小为350GB。
假设批次大小 b=64 ,输入序列长度 s=512 ,输出序列长度 t=32 ,则KV cache占用显存为
4blh(s+t)=164,282,499,072bytes≈164GB ,大约是模型参数显存的0.5倍。
单层计算量对比,关键参数定义
h:隐藏层维度(如1024)t:当前生成的token位置(输出序列长度)l:模型层数(如24层)
大模型推理的解码阶段,每次生成新token时,发现KV值存在重复计算,所以将计算过的历史KV值进行缓存,达到节约计算量的目的。KV cache明显是以内存换效率的方式,使得解码阶段优化KV 的内存占用成为关键。比如KV cache 量化,Page Attention 等都是通过减少KV cache的显存占用,最大化一次性批处理数量,进而提高推理的吞吐量。
操作 | 不使用KV Cache | 使用KV Cache |
|---|---|---|
Key/Value计算 | 需计算前t个token的K和V | 仅计算当前token的K和V |
计算量 | (线性变换) | (线性变换) |
注意力计算(QK^T + V) | (与KV Cache无关,必须计算) | (与KV Cache无关,必须计算) |
总计算量 |
每生成一个token,总节省量 =
参考:
[1] Zhihang Yuan, Yuzhang Shang, Yang Zhou, Zhen Dong, Zhe Zhou, Chenhao Xue, Bingzhe Wu, Zhikai Li, Qingyi Gu, Yong Jae Lee, Yan Yan, Beidi Chen, Guangyu Sun, Kurt Keutzer. LLM Inference Unveiled: Survey and Roofline Model Insights. arXiv:2402.16363
[2] https://zhuanlan.zhihu.com/p/686183300