
https://medium.com/@prathamgrover777/kv-caching-attention-optimization-from-o-n%C2%B2-to-o-n-8b605f0d4072
我们见过 LLM 如何逐字逐句地敲出上千字的回答,仿佛“边想边说”。表面顺滑,背后却低效得惊人。
在生成第 t 步时,模型必须确保下一个词与之前所有内容保持一致:你的提示、已生成的部分、系统指令,乃至任何隐藏上下文。实际上,模型会重新通过所有 Transformer 层重建之前全部的隐藏状态,并再次计算 Query、Key 和 Value,即便前面的 token 丝毫未变。这种重复计算逐层、逐头进行,没有任何复用。如图 1 所示,由于之前步骤的结果未被复用,每个 token 的计算量随序列长度持续增加。

图 1. 解码阶段每个 token 的计算成本(朴素 vs KV 缓存)。
为什么这种方式扩展性这么差? 想象你在写一段文字,每添加一个新句子之前,都要从头把整篇文档重新读一遍。然后写下一个句子时,再从头读一遍。如此反复。这就是朴素解码循环在做的事情。
现在把它放进 transformer:
现在,把这个问题放到现代 LLMs 的规模下:
更致命的是,模型对过去的 Key 和 Value 毫无记忆;因此在第 1000 个 token 时,它得把 token 1 到 999 全部重新算一遍。这就导致解码的时间复杂度是 **O(n²)**。
结果就是巨大的冗余。

图 2. 无 KV 缓存的朴素解码
如图 2 所示,模型不会记住过去的 K/V 投影,而是对每个新 token 都重新计算它们。 当 t = 2 时 → 重新计算 token **1**当 t = 3 时 → 重新计算 token 1 和 **2**对于 t = 4 → 重新计算 tokens **1, 2, 3**对于 t = n → 重新计算 tokens 1 到 n−1
在每一层内部,对于每个 token,它执行:
这些结果从未被复用,因此到第 n 个 token 时,你已经重复了 (n–1) × L × H 次!这意味着计算量持续增长, 永远不会稳定 。
在解决它之前,我们必须先理解瓶颈所在: 搬运数据代价高昂。GPU 做数学运算极快,但真正的成本往往在于把正确的数据在正确的时间送到正确的位置。
举个简单的例子,把它想象成厨房里的厨师:
用 ML 的话说:
这就是为什么序列变长时生成会变慢——不是因为计算变复杂,而是因为 GPU 等数据 的时间比真正计算的时间还多。
模型的权重和对话的历史 token 都必须放在显存里才能快速处理。但在朴素的解码中,每生成一个新 token,厨师(GPU)就得折回储藏室(内存)再拿一遍同样的食材(历史 K/V)。储藏室与灶台之间的路越来越挤,走路时间越来越多,真正炒菜的时间越来越少。
那种对旧数据持续、重复的抓取,就是我们所说的带宽之痛 。“带宽”指的是这条通路的容量,“痛”则源于把刚才已经存在的数据再次塞进去所造成的严重拥堵。带着这个概念,我们来看看到底在哪个环节,这一过程在规模扩大时会崩溃。更具体地说,注意力机制会变成受内存限制,而非计算限制 。你的 GPU 空转着,等待从显存取数,尽管它的算力足以瞬间完成运算。
此刻,一个自然的问题浮现:**“等等……我们为什么每次都重复同样的投影?”****KV 缓存**就是你拒绝重读过去的瞬间。简单说,KV 缓存就是:一旦某 token 在某层里的 Key 和 Value 向量算完,我们把它存进 GPU 内存,而不是直接扔掉。
在注意力机制中,每个 token 被转换为两个紧凑向量:Key(K) 和 Value(V),它们描述_该 token 应如何与后续 token 交互_ 。生成方式是将 token 的嵌入(x)通过该层的权重矩阵:
Key = X @ Wₖ Value = X @ Wᵥ
这并非只计算一次,而是对模型的每一层、每一个注意力头,都各自计算一套 K 和 V。 举例:若模型有 32 层、每层 32 个头,则每个 token 要计算 32 × 32 = 1024 组 K/V 投影。
诀窍在这里:
那为何每一步都要重新计算它们呢?
于是,不再 :每生成一个新 token → 为所有旧 token 重新计算 K/V。 而是: 一次性算出 K/V → 存起来 → 后续所有 token 直接复用。
这就是 KV 缓存。
底层到底发生了什么变化? 当模型生成第 t 个 token 时,常规的注意力操作会这样执行:

在朴素解码中,每当时间步 t 增加,模型会重新计算所有层、所有头之前的 K 和 V 向量,只为再次把它们代入这个方程。
启用 KV 缓存后,方程本身_并未_改变,但 K 和 V 的来源变了 。
因此,同样的公式变为:

模型不再在每一步重新计算 K₁…Kₜ₋₁ 和 V₁…Vₜ₋₁,而是直接从 GPU 内存中读取。在 token t 时唯一需要的新工作是计算 Kₜ 和 Vₜ,并将它们追加到缓存中。
它既不是列表,也不是 Python 字典。在实际的 LLM 实现中(如 vLLM、TensorRT-LLM、Hugging Face),缓存以张量形式存储在 GPU 显存中,维度固定。

每当生成一个新 token: → 我们计算 Kₜ 和 Vₜ → 我们沿着 seq_len 维度将它们追加到末尾 → 其他所有内容保持不变
因为 Q(Query 向量)与 K 和 V 在本质上不同。
所以:
如果我们缓存了 Q,就永远不会再用到它,因为它只在当前时间步使用。缓存它只会白白浪费内存,毫无收益。
为了真正理解 KV 缓存的神奇之处,让我们跟随动画,看看我们的 LLM 如何生成短语“I Love cats”。我们将重点观察模型如何处理这些 token,以预测序列中的下一个词。(GIF 可能加载较慢——稍等片刻)

图 3. 有无 KV 缓存的对比:过去的 K/V 是重新计算 vs. 复用。
原始而浪费的做法。每一步都必须从头重新处理全部内容。
步骤 1:预测 “Love”(历史:“I”)
[num_heads, head_dim] 的矩阵 K₁ 和 V₁,却立即将它们丢弃。当模型处理下一个 token 时,会毫无必要地重新构建这些相同的矩阵。步骤 2:预测“cats”(历史:“I Love”)
这样,模型只需计算每个 token 的 Key 和 Value 一次 ,并将其保存下来。
步骤 1:预测 “Love”(历史:“I”)
步骤 2:预测“cats”(历史:“I Love”)
模型依旧会“回顾”历史,只是它通过查询一张简单的查找表(缓存)来完成,而无需从头重新计算。注意力计算仍会将 Qₜ 与所有过去的键进行比较,但由于过去的 K/V 已从缓存中取出,重算成本降至 **O(1)**。每一步的工作现在仅仅是将一个新查询与已有键进行比较。
这种优化是现代推理服务器(如 vLLM、TensorRT-LLM 或 Hugging Face 的 transformers 库)能够实现实时文本生成的原因。没有它,每生成几百个 token,延迟就会翻倍,让聊天模型根本无法使用。
老实说,这并不像发明一种新算法,你只是不再做重复劳动。这就是 KV 缓存的作用:过去不再重新计算,而是被记住。

图 4. KV 缓存作为时空权衡
KV 缓存并非“免费的午餐”,而是一种典型的时空权衡:我们消除了冗余计算(时间),但必须把缓存塞进 GPU 显存(空间),而且这块缓存会变得非常非常庞大!
KV 缓存的大小由以下几个因素决定:

对于一个上下文窗口长达 32,000 token 的大模型,这块缓存就能吃掉几十 GB 的宝贵 GPU 显存。如果服务器还要同时服务这么多并发用户(大批量),内存需求就会成为主要瓶颈,直接限制系统容量。
到目前为止,我们只_讨论_了 KV 缓存。现在让我们通过代码来_亲眼看看它是如何工作的_ 。
下面的示例清楚地展示了朴素解码如何一次又一次地重复计算所有内容,以及 KV 缓存如何通过存储先前计算好的键(K)和值(V)来避免这种浪费。
import torch
import torch.nn as nn
# A single multi-head attention layer
attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
# Dummy input sequence
tokens = torch.randn(1, 5, 512) # [batch, seq_len, embedding_dim]
# -------------------------------
# 1. Naive decoding (no caching)
# Recomputes attention over full history at every step
# -------------------------------
for t in range(1, tokens.size(1)):
x = tokens[:, :t, :] # tokens from 1 to t
out, _ = attn(x, x, x) # recompute Q,K,V for all past tokens again
# -------------------------------
# 2. KV Caching (compute K/V once → reuse forever)
# -------------------------------
past_k, past_v = None, None
for t in range(tokens.size(1)):
x = tokens[:, t:t+1, :] # only the new token
# Project to Q, K, V (like attention does internally)
q = attn.in_proj_q(x)
k = attn.in_proj_k(x)
v = attn.in_proj_v(x)
# Save (or append) K/V into cache
past_k = k if past_k isNoneelse torch.cat([past_k, k], dim=1)
past_v = v if past_v isNoneelse torch.cat([past_v, v], dim=1)
# Attention now only compares new query with cached keys
attn_scores = torch.matmul(q, past_k.transpose(-1, -2)) / (k.size(-1) ** 0.5)
attn_probs = attn_scores.softmax(dim=-1)
output = torch.matmul(attn_probs, past_v)
故事基本上就是这样。
KV 缓存并没有让注意力机制变得更智能,它只是让它不再愚蠢。与其在每个 token 上重新计算过去的内容,我们只需计算一次并记住它。这就是为什么即使上下文变长,生成速度依然很快。