首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >大模型KV缓存,形象理解

大模型KV缓存,形象理解

作者头像
Ai学习的老章
发布2025-11-20 11:22:30
发布2025-11-20 11:22:30
200
举报
https://medium.com/@prathamgrover777/kv-caching-attention-optimization-from-o-n%C2%B2-to-o-n-8b605f0d4072
https://medium.com/@prathamgrover777/kv-caching-attention-optimization-from-o-n%C2%B2-to-o-n-8b605f0d4072

https://medium.com/@prathamgrover777/kv-caching-attention-optimization-from-o-n%C2%B2-to-o-n-8b605f0d4072

我们见过 LLM 如何逐字逐句地敲出上千字的回答,仿佛“边想边说”。表面顺滑,背后却低效得惊人。

在生成第 t 步时,模型必须确保下一个词与之前所有内容保持一致:你的提示、已生成的部分、系统指令,乃至任何隐藏上下文。实际上,模型会重新通过所有 Transformer 层重建之前全部的隐藏状态,并再次计算 Query、Key 和 Value,即便前面的 token 丝毫未变。这种重复计算逐层、逐头进行,没有任何复用。如图 1 所示,由于之前步骤的结果未被复用,每个 token 的计算量随序列长度持续增加。

图 1. 解码阶段每个 token 的计算成本(朴素 vs KV 缓存)。

为什么这种方式扩展性这么差? 想象你在写一段文字,每添加一个新句子之前,都要从头把整篇文档重新读一遍。然后写下一个句子时,再从头读一遍。如此反复。这就是朴素解码循环在做的事情。

现在把它放进 transformer:

  • L = 层数
  • H = 每层注意力头数
  • n = 当前序列长度
  • 因此,每个 token 的计算量不仅与序列长度 n 成正比,内部还要再乘以 L × H。这就是成本飙升的原因。

现在,把这个问题放到现代 LLMs 的规模下:

  • 几十层网络层层堆叠
  • 每一层有多个注意力头,各自“角度”不同地回顾历史。
  • 长提示(上千条),还要生成一长段回答。

更致命的是,模型对过去的 Key 和 Value 毫无记忆;因此在第 1000 个 token 时,它得把 token 1 到 999 全部重新算一遍。这就导致解码的时间复杂度是 **O(n²)**。

结果就是巨大的冗余。

图 2. 无 KV 缓存的朴素解码

如图 2 所示,模型不会记住过去的 K/V 投影,而是对每个新 token 都重新计算它们。 t = 2 时 → 重新计算 token **1**当 t = 3 时 → 重新计算 token 1 **2**对于 t = 4 → 重新计算 tokens **1, 2, 3**对于 t = n → 重新计算 tokens 1 n−1

在每一层内部,对于每个 token,它执行:

  • K = X · Wₖ
  • V = X · Wᵥ
  • Q · Kᵀ(与所有先前 key 的点积)
  • 对所有先前 Value 的加权和。

这些结果从未被复用,因此到第 n 个 token 时,你已经重复了 (n–1) × L × H 次!这意味着计算量持续增长, 永远不会稳定

硬件瓶颈

在解决它之前,我们必须先理解瓶颈所在: 搬运数据代价高昂。GPU 做数学运算极快,但真正的成本往往在于把正确的数据在正确的时间送到正确的位置。

举个简单的例子,把它想象成厨房里的厨师:

  • GPU 核心就是那位厨师,切菜和烹饪的速度快得惊人。
  • VRAM(GPU 自有内存) 就像紧挨着厨师的小备餐台,空间有限,但伸手就能拿到上面的东西。
  • 系统 RAM 则是走廊尽头的大储藏室,空间充足,可每次都得停下、走过去再把食材搬回来。

用 ML 的话说:

  • GPU 核心每秒能执行数万亿次 FLOPs,真正的瓶颈并不是算力。
  • 内存带宽 (将 K/V 张量搬运到高带宽内存)是有限的。
  • 并且注意力机制需要反复读取这些张量 ,这会把带宽压垮。

这就是为什么序列变长时生成会变慢——不是因为计算变复杂,而是因为 GPU 等数据 的时间比真正计算的时间还多。

模型的权重和对话的历史 token 都必须放在显存里才能快速处理。但在朴素的解码中,每生成一个新 token,厨师(GPU)就得折回储藏室(内存)再拿一遍同样的食材(历史 K/V)。储藏室与灶台之间的路越来越挤,走路时间越来越多,真正炒菜的时间越来越少。

那种对旧数据持续、重复的抓取,就是我们所说的带宽之痛 。“带宽”指的是这条通路的容量,“痛”则源于把刚才已经存在的数据再次塞进去所造成的严重拥堵。带着这个概念,我们来看看到底在哪个环节,这一过程在规模扩大时会崩溃。更具体地说,注意力机制会变成受内存限制,而非计算限制 。你的 GPU 空转着,等待从显存取数,尽管它的算力足以瞬间完成运算。

朴素的注意力机制在何处因规模而崩溃?

  • 延迟悄然上升 :早期的 token 反应迅捷,但随着对话变长,每一个新 token 的生成时间都比前一个更久。你正拖着一段越来越长的历史前行。这正是注意力计算复杂度带来的直接后果——它是二次方 ,即 **O(n²)**,其中 n 为序列长度。
  • 带宽之痛,而非算力之痛: 现代 GPU 的数学运算飞快,但数据在内存之间搬来搬去才是“征税员”;反复把整个“过去”拖过总线,会把带宽压垮。你更多时间是在等数据,而不是在计算。
  • 推理崩溃: 在线上,你不是给一个人生成,而是同时给成千上万人生成。如果每条 token 流都重新处理自己的完整历史,系统立刻垮掉,成本飙升。

引入 KV 缓存

此刻,一个自然的问题浮现:**“等等……我们为什么每次都重复同样的投影?”****KV 缓存**就是你拒绝重读过去的瞬间。简单说,KV 缓存就是:一旦某 token 在某层里的 Key 和 Value 向量算完,我们把它存进 GPU 内存,而不是直接扔掉。

在注意力机制中,每个 token 被转换为两个紧凑向量:Key(K)Value(V),它们描述_该 token 应如何与后续 token 交互_ 。生成方式是将 token 的嵌入(x)通过该层的权重矩阵:

Key = X @ Wₖ Value = X @ Wᵥ

这并非只计算一次,而是对模型的每一层、每一个注意力头,都各自计算一套 K 和 V。 举例:若模型有 32 层、每层 32 个头,则每个 token 要计算 32 × 32 = 1024 组 K/V 投影。

诀窍在这里:

  • 在推理过程中,模型的权重(WₖWᵥ)不会改变。
  • 一旦计算完成,token 的嵌入向量 X 也不再改变。
  • 这意味着它的 KeyValue 向量是确定性的 。一旦算出,它们就像被刻在石头上一样,在整个序列中不会变化。

那为何每一步都要重新计算它们呢?

于是,不再 :每生成一个新 token → 为所有旧 token 重新计算 K/V。 而是: 一次性算出 K/V → 存起来 → 后续所有 token 直接复用。

这就是 KV 缓存。

底层到底发生了什么变化? 当模型生成第 t 个 token 时,常规的注意力操作会这样执行:

在朴素解码中,每当时间步 t 增加,模型会重新计算所有层、所有头之前的 K 和 V 向量,只为再次把它们代入这个方程。

启用 KV 缓存后,方程本身_并未_改变,但 K 和 V 的来源变了

  • 我们停止重新计算 K₁…Kₜ₋₁ 和 V₁…Vₜ₋₁
  • 取而代之的是,我们从 GPU 内存中一个名为 KV Cache 的张量里直接读取它们

因此,同样的公式变为:

模型不再在每一步重新计算 K₁…Kₜ₋₁ 和 V₁…Vₜ₋₁,而是直接从 GPU 内存中读取。在 token t 时唯一需要的新工作是计算 Kₜ 和 Vₜ,并将它们追加到缓存中。

KV 缓存在 GPU 内部到底是什么样子?

它既不是列表,也不是 Python 字典。在实际的 LLM 实现中(如 vLLM、TensorRT-LLM、Hugging Face),缓存以张量形式存储在 GPU 显存中,维度固定。

  • num_layers = 模型中 Transformer 块的总数(例如 LLaMA-7B 为 32)
  • num_heads = 每层注意力头数(例如 32)
  • seq_len = 当前已见的 token 数量
  • head_dim = 每个注意力头的维度(例如 64 或 128)

每当生成一个新 token: → 我们计算 KₜVₜ → 我们沿着 seq_len 维度将它们追加到末尾 → 其他所有内容保持不变

我们为什么不缓存 Query(Q)?

因为 Q(Query 向量)与 KV 在本质上不同。

  • Key 和 Value 代表 之前 token 的记忆 。 它们一旦计算完成就不会再改变。
  • 但 Query 只依赖于_当前正在生成的 token_,而非过去的。 Qₜ 用来提问:“鉴于我已看到的全部(所有已缓存的 K/V),下一个 token 应该是什么?”

所以:

  • K 和 V = 记忆 → 缓存一次,反复使用
  • Q = 按步骤、临时生成 → 无需存储

如果我们缓存了 Q,就永远不会再用到它,因为它只在当前时间步使用。缓存它只会白白浪费内存,毫无收益。

可视化差异

为了真正理解 KV 缓存的神奇之处,让我们跟随动画,看看我们的 LLM 如何生成短语“I Love cats”。我们将重点观察模型如何处理这些 token,以预测序列中的下一个词。(GIF 可能加载较慢——稍等片刻)

图 3. 有无 KV 缓存的对比:过去的 K/V 是重新计算 vs. 复用。

1. 无缓存

原始而浪费的做法。每一步都必须从头重新处理全部内容。

步骤 1:预测 “Love”(历史:“I”)

  • 模型接收第一个 token “I”。
  • 它计算其 Key(记为 K₁)及其 Value(记为 V₁)。
  • 它还计算其 Query(Q₁)。
  • 它执行注意力计算(Q₁ 关注 K₁),以预测下一个词:“Love”。
  • 然后,它会丢弃掉 K₁V₁,所有这些计算成果都被浪费。
  • 在张量层面,这意味着 GPU 刚刚计算出形状为 [num_heads, head_dim] 的矩阵 K₁V₁,却立即将它们丢弃。当模型处理下一个 token 时,会毫无必要地重新构建这些相同的矩阵。

步骤 2:预测“cats”(历史:“I Love”)

  • 模型现在需要处理新 token“Love”,但它对“I”没有任何记忆。
  • 它必须重新为“I”计算 Key 和 Value(生成 K’₁V’₁)。
  • 它还会为“Love”计算 Key 和 Value(生成 K₂V₂),并为新 token“Love”计算 Query(Q₂)。
  • 它执行注意力计算(Q₂ 同时查看 K’₁K₂)来预测“cats”。
  • 随后,它把 K’₁V’₁K₂V₂ 全部丢弃。
  • 你看出规律了吗?为了预测第三个词,我们不得不把第一个词的所有计算重新做一遍。

2. 带缓存

这样,模型只需计算每个 token 的 Key 和 Value 一次 ,并将其保存下来。

步骤 1:预测 “Love”(历史:“I”)

  • 模型接收第一个 token “I”。
  • 它计算自己的 Key(K₁)和 Value(V₁)。
  • 执行注意力计算以预测 “Love”。
  • 关键是,它将 K₁V₁ 存入一块特殊内存:KV 缓存
  • 缓存现在包含:{(K₁, V₁)}。
  • 这意味着 GPU 现在为每一层、每一个注意力头都保存着一个小张量,代表该词元对所有未来注意力查询的贡献。无需重新计算,只需查表即可。

步骤 2:预测“cats”(历史:“I Love”)

  • 模型忽略 “I” 词元。
  • 它知道它已经存在于缓存中。
  • 它只处理新的 token “Love”。
  • 它只为 “Love” 计算 Key(K₂)和 Value(V₂)。
  • 它将这对新的键值追加到缓存中。
  • 缓存现在包含:_{(K₁, V₁), (K₂, V₂)}_ 它为“Love”计算查询(Q₂),并通过查看整个缓存(K₁K₂)进行注意力计算。
  • 它预测“cats”。没有任何内容被丢弃。
  • 此时工作量恒定且最小。这就是每一步的工作量变为线性(**O(n)**)的方式。

模型依旧会“回顾”历史,只是它通过查询一张简单的查找表(缓存)来完成,而无需从头重新计算。注意力计算仍会将 Qₜ 与所有过去的键进行比较,但由于过去的 K/V 已从缓存中取出,重算成本降至 **O(1)**。每一步的工作现在仅仅是将一个新查询与已有键进行比较。

这种优化是现代推理服务器(如 vLLM、TensorRT-LLM 或 Hugging Face 的 transformers 库)能够实现实时文本生成的原因。没有它,每生成几百个 token,延迟就会翻倍,让聊天模型根本无法使用。

老实说,这并不像发明一种新算法,你只是不再做重复劳动。这就是 KV 缓存的作用:过去不再重新计算,而是被记住。

权衡:用内存换计算

图 4. KV 缓存作为时空权衡

KV 缓存并非“免费的午餐”,而是一种典型的时空权衡:我们消除了冗余计算(时间),但必须把缓存塞进 GPU 显存(空间),而且这块缓存会变得非常非常庞大!

KV 缓存的大小由以下几个因素决定:

对于一个上下文窗口长达 32,000 token 的大模型,这块缓存就能吃掉几十 GB 的宝贵 GPU 显存。如果服务器还要同时服务这么多并发用户(大批量),内存需求就会成为主要瓶颈,直接限制系统容量。

极简代码示例(PyTorch 朴素实现 vs KV 缓存)

到目前为止,我们只_讨论_了 KV 缓存。现在让我们通过代码来_亲眼看看它是如何工作的_ 。

下面的示例清楚地展示了朴素解码如何一次又一次地重复计算所有内容,以及 KV 缓存如何通过存储先前计算好的键(K)和值(V)来避免这种浪费。

代码语言:javascript
复制
import torch  
import torch.nn as nn  

# A single multi-head attention layer  
attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)  

# Dummy input sequence  
tokens = torch.randn(1, 5, 512)  # [batch, seq_len, embedding_dim]  

# -------------------------------  
# 1. Naive decoding (no caching)  
# Recomputes attention over full history at every step  
# -------------------------------  
for t in range(1, tokens.size(1)):  
    x = tokens[:, :t, :]         # tokens from 1 to t  
    out, _ = attn(x, x, x)       # recompute Q,K,V for all past tokens again  

# -------------------------------  
# 2. KV Caching (compute K/V once → reuse forever)  
# -------------------------------  
past_k, past_v = None, None

for t in range(tokens.size(1)):  
    x = tokens[:, t:t+1, :]      # only the new token  

    # Project to Q, K, V (like attention does internally)  
    q = attn.in_proj_q(x)  
    k = attn.in_proj_k(x)  
    v = attn.in_proj_v(x)  

    # Save (or append) K/V into cache  
    past_k = k if past_k isNoneelse torch.cat([past_k, k], dim=1)  
    past_v = v if past_v isNoneelse torch.cat([past_v, v], dim=1)  

    # Attention now only compares new query with cached keys  
    attn_scores = torch.matmul(q, past_k.transpose(-1, -2)) / (k.size(-1) ** 0.5)  
    attn_probs = attn_scores.softmax(dim=-1)  
    output = torch.matmul(attn_probs, past_v)

故事基本上就是这样。

KV 缓存并没有让注意力机制变得更智能,它只是让它不再愚蠢。与其在每个 token 上重新计算过去的内容,我们只需计算一次并记住它。这就是为什么即使上下文变长,生成速度依然很快。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-10-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习与统计学 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 硬件瓶颈
    • 朴素的注意力机制在何处因规模而崩溃?
  • 引入 KV 缓存
    • KV 缓存在 GPU 内部到底是什么样子?
    • 我们为什么不缓存 Query(Q)?
  • 可视化差异
    • 1. 无缓存
    • 2. 带缓存
  • 权衡:用内存换计算
    • 极简代码示例(PyTorch 朴素实现 vs KV 缓存)
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档