当你在 ChatGPT、Claude 或本地部署的 Llama 上输入一段长文字,模型竟然能在几秒内流畅地逐字输出——这背后少不了一项关键优化:KV Cache(键值缓存)。
没有 KV Cache,每生成一个新 token,模型就要对整个上下文重新做一次完整的 Attention 计算,延迟随序列长度呈 O(n²) 增长。有了 KV Cache,每步只需增量计算一个 token,推理速度可提升 5–10 倍甚至更多。
本文从 Transformer 的 Attention 机制出发,用真实可运行的 PyTorch 代码,带你一步步实现 KV Cache,并讲解 PagedAttention、前缀缓存等生产级进阶方案。
一、KV Cache 的原理:从 Attention 说起
Transformer 的核心是 Multi-Head Self-Attention。给定输入序列的 token 嵌入矩阵 X,每个注意力头会将 X 分别投影为三组向量:
- Q(Query):当前 token 想要查找什么信息
- K(Key):每个历史 token 的"标签",用于被查询
- V(Value):每个历史 token 实际携带的内容
注意力得分公式:
# Attention 计算
Attention(Q, K, V) = softmax(Q @ K.T / sqrt(d_k)) @ V
LLM 推理分为两个阶段:
| 阶段 | 发生了什么 | 计算特点 |
|---|---|---|
| Prefill(预填充) | 一次性处理所有输入 token(prompt),并行计算 K/V | GPU 并行度高,速度快 |
| Decode(解码) | 逐 token 自回归生成,每步只有 1 个新 token | 串行,每步要对全部上下文做 Attention |
问题就在 Decode 阶段:生成第 t 个 token 时,需要对前面 t-1 个 token 做 Attention。如果不缓存,每一步都要重新计算所有历史 token 的 K 和 V,导致总计算量随序列长度平方增长。
KV Cache 的本质:在 Prefill 阶段计算完所有输入 token 的 K/V 矩阵之后,将它们存入内存(缓存)。Decode 每一步只为新 token 计算 K/V,然后将其 追加到缓存中,Attention 直接使用缓存中的所有历史 K/V,无需重新计算。
二、环境准备
安装 Python 依赖
本教程需要 Python 3.10+、PyTorch 2.x。推荐使用 conda 或 virtualenv 隔离环境。
# 创建虚拟环境
python -m venv kvcache-env
source kvcache-env/bin/activate # Windows: kvcache-env\Scripts\activate
# 安装依赖
pip install torch transformers numpy
# 验证 PyTorch 版本
python -c "import torch; print(torch.__version__)" # >= 2.0.0
验证 GPU 可用性(可选)
本教程的代码在 CPU 上也可运行,但 GPU 能更直观地体现加速效果。
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
if device == "cuda":
print(f"GPU 型号: {torch.cuda.get_device_name(0)}")
print(f"GPU 显存: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
三、理解"无缓存"时的开销
先实现一个不带缓存的简化版 Self-Attention,观察其在解码时的重复计算问题。
实现无缓存版 Self-Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class SelfAttentionNaive(nn.Module):
"""无 KV Cache 的朴素 Self-Attention"""
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x):
"""
x: (batch, seq_len, d_model)
每次调用都对整个序列重新计算 K 和 V!
"""
B, T, C = x.shape
# 计算所有 token 的 Q, K, V
Q = self.W_q(x) # (B, T, d_model)
K = self.W_k(x) # (B, T, d_model) ← 重复计算!
V = self.W_v(x) # (B, T, d_model) ← 重复计算!
# 多头分割
Q = Q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
K = K.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
V = V.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
# Attention 得分 (因果 mask)
scale = self.d_head ** -0.5
scores = (Q @ K.transpose(-2, -1)) * scale # (B, heads, T, T)
mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
scores = scores.masked_fill(~mask, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = attn @ V # (B, heads, T, d_head)
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.W_o(out)
def simulate_naive_decode(model, prompt_emb, n_new_tokens):
"""模拟无缓存自回归解码,每步都把整个序列传入"""
B, T, C = prompt_emb.shape
tokens = prompt_emb.clone()
times = []
for i in range(n_new_tokens):
t0 = time.perf_counter()
# 每步传入全部 token,包含所有历史
out = model(tokens) # (B, T+i, C)
new_token = out[:, -1:, :] # 只取最后一个 token 的输出
tokens = torch.cat([tokens, new_token], dim=1)
times.append(time.perf_counter() - t0)
return tokens, times
# 测试
d_model, n_heads, seq_len = 512, 8, 50
model_naive = SelfAttentionNaive(d_model, n_heads)
prompt = torch.randn(1, seq_len, d_model)
with torch.no_grad():
_, naive_times = simulate_naive_decode(model_naive, prompt, n_new_tokens=20)
print(f"无缓存 - 首 token: {naive_times[0]*1000:.1f}ms")
print(f"无缓存 - 末 token: {naive_times[-1]*1000:.1f}ms")
print(f"无缓存 - 总耗时: {sum(naive_times)*1000:.1f}ms")
注意观察延迟增长:你会发现随着序列变长,每步耗时在线性增加。在真实 LLM 中,这个问题被放大到极致——序列越长,用户等待越久。
四、从零实现 KV Cache
核心思路:为每个 Attention 层维护一个缓存对象,存储所有已处理 token 的 K/V 张量。每次解码只计算新 token 的 K/V,然后拼接到缓存末尾。
实现带 KV Cache 的 Self-Attention
class KVCache:
"""单层 Attention 的 KV 缓存"""
def __init__(self):
self.k_cache = None # (B, heads, T_cached, d_head)
self.v_cache = None # (B, heads, T_cached, d_head)
def update(self, k_new, v_new):
"""追加新 token 的 K/V 到缓存"""
if self.k_cache is None:
self.k_cache = k_new
self.v_cache = v_new
else:
self.k_cache = torch.cat([self.k_cache, k_new], dim=2)
self.v_cache = torch.cat([self.v_cache, v_new], dim=2)
return self.k_cache, self.v_cache
def get_seq_len(self):
return 0 if self.k_cache is None else self.k_cache.shape[2]
def clear(self):
self.k_cache = None
self.v_cache = None
class SelfAttentionWithCache(nn.Module):
"""带 KV Cache 的 Self-Attention"""
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.kv_cache = KVCache()
def forward(self, x, use_cache=True):
"""
x:
- Prefill 阶段: (B, T_prompt, d_model)
- Decode 阶段: (B, 1, d_model) ← 只传入 1 个新 token!
"""
B, T, C = x.shape
# 只为当前输入的 token 计算 Q、K、V
Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
if use_cache:
# 将新的 K/V 追加到缓存,返回完整的历史 K/V
K_full, V_full = self.kv_cache.update(K, V)
else:
K_full, V_full = K, V
# Q 只是当前 token 的,但 K_full/V_full 包含所有历史
T_kv = K_full.shape[2] # 完整的 KV 长度
scale = self.d_head ** -0.5
scores = (Q @ K_full.transpose(-2, -1)) * scale # (B, heads, T, T_kv)
# 因果 mask:Q 的位置 i 只能 attend 到 KV 中 <= i 的位置
T_q = Q.shape[2]
offset = T_kv - T_q # 解码时 offset = 已缓存的长度
mask = torch.ones(T_q, T_kv, device=x.device).tril(diagonal=offset).bool()
scores = scores.masked_fill(~mask, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = attn @ V_full # (B, heads, T_q, d_head)
out = out.transpose(1, 2).contiguous().view(B, T_q, C)
return self.W_o(out)
def simulate_cached_decode(model, prompt_emb, n_new_tokens):
"""带 KV Cache 的自回归解码"""
# Step 1: Prefill — 一次性处理完整 prompt,填充缓存
model.kv_cache.clear()
t_prefill = time.perf_counter()
with torch.no_grad():
out = model(prompt_emb, use_cache=True)
last_emb = out[:, -1:, :] # 取最后一个 token 的输出作为第一个新 token 输入
prefill_time = time.perf_counter() - t_prefill
# Step 2: Decode — 每步只传入 1 个 token
times = []
current_emb = last_emb
for i in range(n_new_tokens):
t0 = time.perf_counter()
with torch.no_grad():
out = model(current_emb, use_cache=True) # (B, 1, C)
current_emb = out # 下一步的输入
times.append(time.perf_counter() - t0)
return times, prefill_time
# 对比测试
d_model, n_heads, seq_len = 512, 8, 50
model_cached = SelfAttentionWithCache(d_model, n_heads)
# 共享权重(保证公平对比)
model_cached.W_q.weight.data = model_naive.W_q.weight.data.clone()
model_cached.W_k.weight.data = model_naive.W_k.weight.data.clone()
model_cached.W_v.weight.data = model_naive.W_v.weight.data.clone()
model_cached.W_o.weight.data = model_naive.W_o.weight.data.clone()
prompt = torch.randn(1, seq_len, d_model)
cached_times, prefill_time = simulate_cached_decode(model_cached, prompt, n_new_tokens=20)
print(f"\n===== 对比结果 =====")
print(f"无缓存 - 总耗时: {sum(naive_times)*1000:.1f}ms")
print(f"有缓存 - Prefill: {prefill_time*1000:.1f}ms | Decode 总计: {sum(cached_times)*1000:.1f}ms")
print(f"解码加速比: {sum(naive_times)/sum(cached_times):.1f}x")
预期输出:在小模型(124M 参数)短序列(200 token)测试中,Mac M4 CPU 上可观察到约 5× 加速。随着序列长度增加,加速效果更显著。
五、生产级进阶:PagedAttention 与前缀缓存
上面的朴素实现在真实生产环境中会遇到两个核心问题:
- 内存碎片化:每个请求预分配连续内存块,但实际序列长度不可预知,导致大量内存浪费
- 多请求无法共享公共前缀:例如系统 prompt 相同的多个用户请求,其 KV Cache 被重复计算和存储
PagedAttention:KV Cache 的分页管理
vLLM 引入的 PagedAttention 借鉴操作系统的虚拟内存分页机制:
| 概念 | OS 虚拟内存 | PagedAttention |
|---|---|---|
| 存储单元 | 内存页(4KB) | KV Block(固定 token 数,如 16) |
| 地址映射 | 虚拟地址→物理地址 | 逻辑 Block ID → 物理 GPU 内存位置 |
| 按需分配 | 缺页中断时分配 | 序列增长时按需追加 Block |
| 内存共享 | Copy-on-Write 共享 | 前缀 Block 引用计数共享 |
# 使用 vLLM 的简单示例(生产推荐)
from vllm import LLM, SamplingParams
# vLLM 内部自动使用 PagedAttention
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
max_model_len=4096,
gpu_memory_utilization=0.85, # 预留 85% GPU 显存给 KV Cache
enable_prefix_caching=True, # 启用前缀缓存
)
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
# 共享系统 prompt 的多个请求——前缀 KV Cache 只计算一次!
system_prompt = "你是一位专业的 Python 工程师,擅长写出简洁高效的代码。\n\n"
prompts = [
system_prompt + "请解释 Python 的 GIL 是什么?",
system_prompt + "如何用 asyncio 实现并发?",
system_prompt + "解释 Python 装饰器的工作原理。",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.outputs[0].text[:100], "...")
Multi-Query Attention (MQA) 与 Grouped-Query Attention (GQA)
除了缓存管理,还可以从根源上减少 KV Cache 的大小:
class GroupedQueryAttention(nn.Module):
"""
GQA: 多个 Query Head 共享一组 K/V Head
KV Cache 大小降低为原来的 n_kv_heads/n_heads 倍
例: n_heads=32, n_kv_heads=8 → KV Cache 减少 4×
Llama-3 采用此方案
"""
def __init__(self, d_model, n_heads, n_kv_heads):
super().__init__()
assert n_heads % n_kv_heads == 0
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads # 每组 KV 被几个 Q 共享
self.d_head = d_model // n_heads
self.W_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
# KV 投影维度更小!
self.W_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
self.W_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, kv_cache=None):
B, T, _ = x.shape
Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
# 将 KV 重复扩展以匹配 Q 的头数
K = K.repeat_interleave(self.n_rep, dim=1) # (B, n_heads, T, d_head)
V = V.repeat_interleave(self.n_rep, dim=1)
scale = self.d_head ** -0.5
scores = (Q @ K.transpose(-2, -1)) * scale
mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
scores = scores.masked_fill(~mask, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, -1)
return self.W_o(out)
# KV Cache 大小对比
n_layers, n_heads, n_kv_heads = 32, 32, 8
d_head, seq_len, batch = 128, 4096, 1
mha_kv_bytes = 2 * n_layers * n_heads * d_head * seq_len * batch * 2 # float16
gqa_kv_bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * batch * 2
print(f"MHA KV Cache: {mha_kv_bytes / 1e9:.2f} GB")
print(f"GQA KV Cache: {gqa_kv_bytes / 1e9:.2f} GB (节省 {(1-gqa_kv_bytes/mha_kv_bytes)*100:.0f}%)")
实际数据:Llama-3 70B(GQA, n_kv_heads=8)在 8K 上下文、batch=1、float16 下,KV Cache 约需 4GB 显存。如果用 MHA,同等配置需要 16GB,根本无法在单卡上运行。
六、用 Transformers 库验证实际效果
用 HuggingFace Transformers 对比有/无缓存的实际速度
HuggingFace 的 transformers 库默认开启 KV Cache(use_cache=True),你也可以手动关闭来做对比测试。
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
model_name = "gpt2" # 小模型,无需 GPU 也能快速测试
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
model.eval()
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
n_new_tokens = 50
# ─── 方案 A: 禁用 KV Cache ───
t0 = time.perf_counter()
with torch.no_grad():
out_no_cache = model.generate(
input_ids,
max_new_tokens=n_new_tokens,
use_cache=False, # ← 关键:禁用 KV Cache
do_sample=False,
)
time_no_cache = time.perf_counter() - t0
# ─── 方案 B: 启用 KV Cache(默认)───
t0 = time.perf_counter()
with torch.no_grad():
out_with_cache = model.generate(
input_ids,
max_new_tokens=n_new_tokens,
use_cache=True, # ← 默认开启
do_sample=False,
cache_implementation="static", # 可选: "dynamic"(默认), "static", "quantized_static"
)
time_with_cache = time.perf_counter() - t0
print(f"\n===== 实际模型对比 =====")
print(f"无 KV Cache — 生成 {n_new_tokens} tokens 耗时: {time_no_cache:.2f}s")
print(f"有 KV Cache — 生成 {n_new_tokens} tokens 耗时: {time_with_cache:.2f}s")
print(f"加速比: {time_no_cache / time_with_cache:.1f}x")
print(f"\n生成文本: {tokenizer.decode(out_with_cache[0], skip_special_tokens=True)}")
cache_implementation 参数:transformers 4.38+ 支持多种缓存策略:dynamic(默认,动态扩展)、static(预分配固定大小,速度更快但需指定 max_length)、quantized_static(量化 KV Cache,节省显存)。
常见问题解答
总结:KV Cache 的三个演进时代
| 时代 | 方案 | 核心特点 | 代表框架 |
|---|---|---|---|
| Era 1 | 朴素 KV Cache | 连续内存分配,无碎片管理 | 早期 HuggingFace |
| Era 2 | PagedAttention | 分页管理,前缀共享,高并发 | vLLM, SGLang |
| Era 3 | 分布式 KV Cache | 跨节点缓存池,P/D 分离,RAG 加速 | Mooncake, CacheBlend |
通过本教程,你已经理解了 KV Cache 的核心原理,并动手实现了一个完整的带缓存 Self-Attention。关键要点:
- KV Cache 将解码阶段的计算从 O(n²) 降至 O(n),是大模型推理加速的基石
- 实现上只需缓存 K/V 矩阵并在每步追加新 token,修改量极小
- 生产环境推荐使用 vLLM(PagedAttention)或 HuggingFace 的 static cache
- GQA/MQA 可以从根源减少 KV Cache 体积,与 PagedAttention 正交互补
下一步建议:理解了 KV Cache 后,可以进一步探索 Speculative Decoding(用小模型预测 token 草稿,大模型并行验证)和 Flash Attention(IO 感知的高效 Attention 实现),这两项技术与 KV Cache 配合可以实现更极致的推理加速。