首页 / AI原理图解 / 从零理解KV Cache与推理加速 1 次阅读
从零理解KV Cache与推理加速 — AI原理图解
AI原理图解 ★★☆ 中级难度

从零理解
KV Cache 与推理加速

2026年2月28日 · 预计阅读 15 分钟 · 含完整可运行代码 · 5 个核心步骤

当你在 ChatGPT、Claude 或本地部署的 Llama 上输入一段长文字,模型竟然能在几秒内流畅地逐字输出——这背后少不了一项关键优化:KV Cache(键值缓存)。

没有 KV Cache,每生成一个新 token,模型就要对整个上下文重新做一次完整的 Attention 计算,延迟随序列长度呈 O(n²) 增长。有了 KV Cache,每步只需增量计算一个 token,推理速度可提升 5–10 倍甚至更多。

本文从 Transformer 的 Attention 机制出发,用真实可运行的 PyTorch 代码,带你一步步实现 KV Cache,并讲解 PagedAttention、前缀缓存等生产级进阶方案。

小模型典型加速比
O(n²)→O(n)
计算复杂度变化
70B
大模型 KV Cache 可达数 GB
3 代
KV Cache 演进阶段

一、KV Cache 的原理:从 Attention 说起

Transformer Attention 与 KV Cache 原理架构图

Transformer 的核心是 Multi-Head Self-Attention。给定输入序列的 token 嵌入矩阵 X,每个注意力头会将 X 分别投影为三组向量:

  • Q(Query):当前 token 想要查找什么信息
  • K(Key):每个历史 token 的"标签",用于被查询
  • V(Value):每个历史 token 实际携带的内容

注意力得分公式:

# Attention 计算
Attention(Q, K, V) = softmax(Q @ K.T / sqrt(d_k)) @ V

LLM 推理分为两个阶段:

阶段发生了什么计算特点
Prefill(预填充) 一次性处理所有输入 token(prompt),并行计算 K/V GPU 并行度高,速度快
Decode(解码) 逐 token 自回归生成,每步只有 1 个新 token 串行,每步要对全部上下文做 Attention

问题就在 Decode 阶段:生成第 t 个 token 时,需要对前面 t-1 个 token 做 Attention。如果不缓存,每一步都要重新计算所有历史 token 的 K 和 V,导致总计算量随序列长度平方增长。

KV Cache 的本质:在 Prefill 阶段计算完所有输入 token 的 K/V 矩阵之后,将它们存入内存(缓存)。Decode 每一步只为新 token 计算 K/V,然后将其 追加到缓存中,Attention 直接使用缓存中的所有历史 K/V,无需重新计算。

二、环境准备

环境准备与依赖安装流程图
1

安装 Python 依赖

本教程需要 Python 3.10+、PyTorch 2.x。推荐使用 conda 或 virtualenv 隔离环境。

# 创建虚拟环境
python -m venv kvcache-env
source kvcache-env/bin/activate  # Windows: kvcache-env\Scripts\activate

# 安装依赖
pip install torch transformers numpy

# 验证 PyTorch 版本
python -c "import torch; print(torch.__version__)"  # >= 2.0.0
2

验证 GPU 可用性(可选)

本教程的代码在 CPU 上也可运行,但 GPU 能更直观地体现加速效果。

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

if device == "cuda":
    print(f"GPU 型号: {torch.cuda.get_device_name(0)}")
    print(f"GPU 显存: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

三、理解"无缓存"时的开销

有无 KV Cache 的计算开销对比图

先实现一个不带缓存的简化版 Self-Attention,观察其在解码时的重复计算问题。

3

实现无缓存版 Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class SelfAttentionNaive(nn.Module):
    """无 KV Cache 的朴素 Self-Attention"""
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        每次调用都对整个序列重新计算 K 和 V!
        """
        B, T, C = x.shape

        # 计算所有 token 的 Q, K, V
        Q = self.W_q(x)  # (B, T, d_model)
        K = self.W_k(x)  # (B, T, d_model)  ← 重复计算!
        V = self.W_v(x)  # (B, T, d_model)  ← 重复计算!

        # 多头分割
        Q = Q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = K.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = V.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # Attention 得分 (因果 mask)
        scale = self.d_head ** -0.5
        scores = (Q @ K.transpose(-2, -1)) * scale  # (B, heads, T, T)
        mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)

        out = attn @ V  # (B, heads, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)


def simulate_naive_decode(model, prompt_emb, n_new_tokens):
    """模拟无缓存自回归解码,每步都把整个序列传入"""
    B, T, C = prompt_emb.shape
    tokens = prompt_emb.clone()

    times = []
    for i in range(n_new_tokens):
        t0 = time.perf_counter()
        # 每步传入全部 token,包含所有历史
        out = model(tokens)               # (B, T+i, C)
        new_token = out[:, -1:, :]       # 只取最后一个 token 的输出
        tokens = torch.cat([tokens, new_token], dim=1)
        times.append(time.perf_counter() - t0)

    return tokens, times


# 测试
d_model, n_heads, seq_len = 512, 8, 50
model_naive = SelfAttentionNaive(d_model, n_heads)
prompt = torch.randn(1, seq_len, d_model)

with torch.no_grad():
    _, naive_times = simulate_naive_decode(model_naive, prompt, n_new_tokens=20)

print(f"无缓存 - 首 token: {naive_times[0]*1000:.1f}ms")
print(f"无缓存 - 末 token: {naive_times[-1]*1000:.1f}ms")
print(f"无缓存 - 总耗时: {sum(naive_times)*1000:.1f}ms")
⚠️

注意观察延迟增长:你会发现随着序列变长,每步耗时在线性增加。在真实 LLM 中,这个问题被放大到极致——序列越长,用户等待越久。

四、从零实现 KV Cache

KV Cache 实现步骤流程图

核心思路:为每个 Attention 层维护一个缓存对象,存储所有已处理 token 的 K/V 张量。每次解码只计算新 token 的 K/V,然后拼接到缓存末尾。

4

实现带 KV Cache 的 Self-Attention

class KVCache:
    """单层 Attention 的 KV 缓存"""
    def __init__(self):
        self.k_cache = None  # (B, heads, T_cached, d_head)
        self.v_cache = None  # (B, heads, T_cached, d_head)

    def update(self, k_new, v_new):
        """追加新 token 的 K/V 到缓存"""
        if self.k_cache is None:
            self.k_cache = k_new
            self.v_cache = v_new
        else:
            self.k_cache = torch.cat([self.k_cache, k_new], dim=2)
            self.v_cache = torch.cat([self.v_cache, v_new], dim=2)
        return self.k_cache, self.v_cache

    def get_seq_len(self):
        return 0 if self.k_cache is None else self.k_cache.shape[2]

    def clear(self):
        self.k_cache = None
        self.v_cache = None


class SelfAttentionWithCache(nn.Module):
    """带 KV Cache 的 Self-Attention"""
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.kv_cache = KVCache()

    def forward(self, x, use_cache=True):
        """
        x:
          - Prefill 阶段: (B, T_prompt, d_model)
          - Decode  阶段: (B, 1, d_model) ← 只传入 1 个新 token!
        """
        B, T, C = x.shape

        # 只为当前输入的 token 计算 Q、K、V
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        if use_cache:
            # 将新的 K/V 追加到缓存,返回完整的历史 K/V
            K_full, V_full = self.kv_cache.update(K, V)
        else:
            K_full, V_full = K, V

        # Q 只是当前 token 的,但 K_full/V_full 包含所有历史
        T_kv = K_full.shape[2]  # 完整的 KV 长度
        scale = self.d_head ** -0.5
        scores = (Q @ K_full.transpose(-2, -1)) * scale  # (B, heads, T, T_kv)

        # 因果 mask:Q 的位置 i 只能 attend 到 KV 中 <= i 的位置
        T_q = Q.shape[2]
        offset = T_kv - T_q  # 解码时 offset = 已缓存的长度
        mask = torch.ones(T_q, T_kv, device=x.device).tril(diagonal=offset).bool()
        scores = scores.masked_fill(~mask, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = attn @ V_full  # (B, heads, T_q, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T_q, C)
        return self.W_o(out)


def simulate_cached_decode(model, prompt_emb, n_new_tokens):
    """带 KV Cache 的自回归解码"""
    # Step 1: Prefill — 一次性处理完整 prompt,填充缓存
    model.kv_cache.clear()
    t_prefill = time.perf_counter()
    with torch.no_grad():
        out = model(prompt_emb, use_cache=True)
    last_emb = out[:, -1:, :]  # 取最后一个 token 的输出作为第一个新 token 输入
    prefill_time = time.perf_counter() - t_prefill

    # Step 2: Decode — 每步只传入 1 个 token
    times = []
    current_emb = last_emb
    for i in range(n_new_tokens):
        t0 = time.perf_counter()
        with torch.no_grad():
            out = model(current_emb, use_cache=True)  # (B, 1, C)
        current_emb = out  # 下一步的输入
        times.append(time.perf_counter() - t0)

    return times, prefill_time


# 对比测试
d_model, n_heads, seq_len = 512, 8, 50
model_cached = SelfAttentionWithCache(d_model, n_heads)
# 共享权重(保证公平对比)
model_cached.W_q.weight.data = model_naive.W_q.weight.data.clone()
model_cached.W_k.weight.data = model_naive.W_k.weight.data.clone()
model_cached.W_v.weight.data = model_naive.W_v.weight.data.clone()
model_cached.W_o.weight.data = model_naive.W_o.weight.data.clone()

prompt = torch.randn(1, seq_len, d_model)
cached_times, prefill_time = simulate_cached_decode(model_cached, prompt, n_new_tokens=20)

print(f"\n===== 对比结果 =====")
print(f"无缓存 - 总耗时: {sum(naive_times)*1000:.1f}ms")
print(f"有缓存 - Prefill: {prefill_time*1000:.1f}ms | Decode 总计: {sum(cached_times)*1000:.1f}ms")
print(f"解码加速比: {sum(naive_times)/sum(cached_times):.1f}x")

预期输出:在小模型(124M 参数)短序列(200 token)测试中,Mac M4 CPU 上可观察到约 5× 加速。随着序列长度增加,加速效果更显著。

五、生产级进阶:PagedAttention 与前缀缓存

PagedAttention 与前缀缓存原理信息图

上面的朴素实现在真实生产环境中会遇到两个核心问题:

  1. 内存碎片化:每个请求预分配连续内存块,但实际序列长度不可预知,导致大量内存浪费
  2. 多请求无法共享公共前缀:例如系统 prompt 相同的多个用户请求,其 KV Cache 被重复计算和存储

PagedAttention:KV Cache 的分页管理

vLLM 引入的 PagedAttention 借鉴操作系统的虚拟内存分页机制:

概念OS 虚拟内存PagedAttention
存储单元内存页(4KB)KV Block(固定 token 数,如 16)
地址映射虚拟地址→物理地址逻辑 Block ID → 物理 GPU 内存位置
按需分配缺页中断时分配序列增长时按需追加 Block
内存共享Copy-on-Write 共享前缀 Block 引用计数共享
# 使用 vLLM 的简单示例(生产推荐)
from vllm import LLM, SamplingParams

# vLLM 内部自动使用 PagedAttention
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=4096,
    gpu_memory_utilization=0.85,  # 预留 85% GPU 显存给 KV Cache
    enable_prefix_caching=True,   # 启用前缀缓存
)

sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

# 共享系统 prompt 的多个请求——前缀 KV Cache 只计算一次!
system_prompt = "你是一位专业的 Python 工程师,擅长写出简洁高效的代码。\n\n"
prompts = [
    system_prompt + "请解释 Python 的 GIL 是什么?",
    system_prompt + "如何用 asyncio 实现并发?",
    system_prompt + "解释 Python 装饰器的工作原理。",
]

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text[:100], "...")

Multi-Query Attention (MQA) 与 Grouped-Query Attention (GQA)

除了缓存管理,还可以从根源上减少 KV Cache 的大小:

class GroupedQueryAttention(nn.Module):
    """
    GQA: 多个 Query Head 共享一组 K/V Head
    KV Cache 大小降低为原来的 n_kv_heads/n_heads 倍

    例: n_heads=32, n_kv_heads=8 → KV Cache 减少 4×
    Llama-3 采用此方案
    """
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # 每组 KV 被几个 Q 共享
        self.d_head = d_model // n_heads

        self.W_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        # KV 投影维度更小!
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape

        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)

        # 将 KV 重复扩展以匹配 Q 的头数
        K = K.repeat_interleave(self.n_rep, dim=1)  # (B, n_heads, T, d_head)
        V = V.repeat_interleave(self.n_rep, dim=1)

        scale = self.d_head ** -0.5
        scores = (Q @ K.transpose(-2, -1)) * scale
        mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)

        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)


# KV Cache 大小对比
n_layers, n_heads, n_kv_heads = 32, 32, 8
d_head, seq_len, batch = 128, 4096, 1

mha_kv_bytes = 2 * n_layers * n_heads * d_head * seq_len * batch * 2  # float16
gqa_kv_bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * batch * 2

print(f"MHA KV Cache: {mha_kv_bytes / 1e9:.2f} GB")
print(f"GQA KV Cache: {gqa_kv_bytes / 1e9:.2f} GB  (节省 {(1-gqa_kv_bytes/mha_kv_bytes)*100:.0f}%)")
💡

实际数据:Llama-3 70B(GQA, n_kv_heads=8)在 8K 上下文、batch=1、float16 下,KV Cache 约需 4GB 显存。如果用 MHA,同等配置需要 16GB,根本无法在单卡上运行。

六、用 Transformers 库验证实际效果

HuggingFace Transformers KV Cache 对比测试信息图
5

用 HuggingFace Transformers 对比有/无缓存的实际速度

HuggingFace 的 transformers 库默认开启 KV Cache(use_cache=True),你也可以手动关闭来做对比测试。

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

model_name = "gpt2"  # 小模型,无需 GPU 也能快速测试
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
model.eval()

prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]

n_new_tokens = 50

# ─── 方案 A: 禁用 KV Cache ───
t0 = time.perf_counter()
with torch.no_grad():
    out_no_cache = model.generate(
        input_ids,
        max_new_tokens=n_new_tokens,
        use_cache=False,        # ← 关键:禁用 KV Cache
        do_sample=False,
    )
time_no_cache = time.perf_counter() - t0

# ─── 方案 B: 启用 KV Cache(默认)───
t0 = time.perf_counter()
with torch.no_grad():
    out_with_cache = model.generate(
        input_ids,
        max_new_tokens=n_new_tokens,
        use_cache=True,         # ← 默认开启
        do_sample=False,
        cache_implementation="static",  # 可选: "dynamic"(默认), "static", "quantized_static"
    )
time_with_cache = time.perf_counter() - t0

print(f"\n===== 实际模型对比 =====")
print(f"无 KV Cache — 生成 {n_new_tokens} tokens 耗时: {time_no_cache:.2f}s")
print(f"有 KV Cache — 生成 {n_new_tokens} tokens 耗时: {time_with_cache:.2f}s")
print(f"加速比: {time_no_cache / time_with_cache:.1f}x")
print(f"\n生成文本: {tokenizer.decode(out_with_cache[0], skip_special_tokens=True)}")
ℹ️

cache_implementation 参数:transformers 4.38+ 支持多种缓存策略:dynamic(默认,动态扩展)、static(预分配固定大小,速度更快但需指定 max_length)、quantized_static(量化 KV Cache,节省显存)。

常见问题解答

KV Cache 会影响输出结果的正确性吗?
不会。KV Cache 是纯粹的计算优化,缓存的 K/V 值与不使用缓存时的计算结果完全相同,只是避免了重复计算。输出结果(token 概率)是完全一致的。
KV Cache 需要多少显存?有公式吗?
每 token 的 KV Cache 大小 = 2(K+V)× n_layers × n_kv_heads × d_head × bytes_per_element。以 Llama-3 8B(float16)为例:2 × 32层 × 8头 × 128维 × 2字节 ≈ 每 token 131KB,8K 上下文约 1GB。
为什么 Decode 阶段无法像 Prefill 一样并行?
因为自回归生成具有天然的数据依赖:第 t 个 token 的生成依赖于第 t-1 个 token,无法并行化。这是 LLM 推理效率的根本瓶颈,也是 Speculative Decoding、并行采样等技术的动机。
KV Cache 和 Beam Search 能一起用吗?
可以,但复杂度更高。Beam Search 需要为每条搜索路径维护独立的 KV Cache。当 beam width=4 时,显存消耗是 greedy search 的 4 倍。这也是大模型推理更倾向于 temperature sampling 而非 beam search 的原因之一。
如何在资源有限的环境中运行带 KV Cache 的大模型?
可以采用:① 量化(INT8/INT4)KV Cache 降低显存;② CPU offload(将 KV Cache 卸载到内存甚至磁盘,但增加延迟);③ 减小 max_length 限制最大缓存大小;④ 使用 GQA/MQA 减少 KV head 数量。llama.cpp 和 mlx-lm 都支持这些优化。

总结:KV Cache 的三个演进时代

时代方案核心特点代表框架
Era 1 朴素 KV Cache 连续内存分配,无碎片管理 早期 HuggingFace
Era 2 PagedAttention 分页管理,前缀共享,高并发 vLLM, SGLang
Era 3 分布式 KV Cache 跨节点缓存池,P/D 分离,RAG 加速 Mooncake, CacheBlend

通过本教程,你已经理解了 KV Cache 的核心原理,并动手实现了一个完整的带缓存 Self-Attention。关键要点:

  • KV Cache 将解码阶段的计算从 O(n²) 降至 O(n),是大模型推理加速的基石
  • 实现上只需缓存 K/V 矩阵并在每步追加新 token,修改量极小
  • 生产环境推荐使用 vLLM(PagedAttention)或 HuggingFace 的 static cache
  • GQA/MQA 可以从根源减少 KV Cache 体积,与 PagedAttention 正交互补

下一步建议:理解了 KV Cache 后,可以进一步探索 Speculative Decoding(用小模型预测 token 草稿,大模型并行验证)和 Flash Attention(IO 感知的高效 Attention 实现),这两项技术与 KV Cache 配合可以实现更极致的推理加速。

选择栏目
今日简报 播客电台 AI 实战教程 关于我
栏目
全球AI日报国内AI日报全球金融日报国内金融日报全球大新闻日报国内大新闻日报Claude Code 玩法日报OpenClaw 动态日报GitHub 热门项目日报AI工具实战AI应用开发编程实战工作流自动化AI原理图解AI Agent开发
我的收藏
播客版
0:00
--:--