首页 / AI原理图解 / 图解推测解码:如何让大模型推理速度提升 2-5 倍 6 次阅读
图解推测解码:如何让大模型推理速度提升 2-5 倍
AI 原理图解

图解推测解码:如何让大模型推理速度提升 2-5 倍

使用小型草稿模型提前预测多个 token,用大型目标模型并行验证,无需牺牲输出质量即可实现显著加速

2026 年 3 月 21 日 · 8 分钟阅读

大模型推理速度慢是实际应用中的核心瓶颈。传统的自回归解码需要逐个 token 串行生成,每次只能输出一个 token,导致高延迟和高成本。

推测解码(Speculative Decoding)通过引入一个小型、快速的草稿模型提前预测多个 token,然后用大型目标模型并行验证这些预测。如果验证通过,一次就能输出多个 token,从而实现 2-5 倍的速度提升。

推测解码与传统自回归解码的速度对比图

2026 年的最新进展 Speculative Speculative Decoding (SSD) 进一步将起草和验证过程并行化,实现了比优化版 SD 再快 2 倍的性能。

核心原理

推测解码的关键在于两个模型的协作:

  • 草稿模型(Draft Model):小型、快速的模型,负责提前预测 K 个 token
  • 目标模型(Target Model):大型、慢速的模型,负责一次性验证所有草稿 token
草稿模型与目标模型协作流程图

验证过程使用拒绝采样(Rejection Sampling):如果目标模型对某个 token 的预测概率大于等于草稿模型,则接受该 token;否则以一定概率接受,并从该点重新开始起草。

准备工作

🐍
Python 3.10+
运行环境
🤗
Transformers
模型加载与推理
🔥
PyTorch
深度学习框架
📦
GPU(可选)
CUDA 加速推理

安装依赖:

pip install torch transformers accelerate

实战步骤

1

加载目标模型和草稿模型

选择一个大型目标模型和一个小型草稿模型。常见的组合包括:

  • 目标模型:Llama-3-70B,草稿模型:Llama-3-8B
  • 目标模型:CodeLlama-34B,草稿模型:CodeLlama-7B
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# 加载目标模型(大型)
target_model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-3-70b-Instruct',
    torch_dtype=torch.float16,
    device_map='auto'
)

# 加载草稿模型(小型)
draft_model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-3-8b-Instruct',
    torch_dtype=torch.float16,
    device_map='auto'
)

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3-70b-Instruct')
模型加载流程图
2

实现草稿生成函数

草稿模型自回归地生成 K 个候选 token:

def generate_draft_tokens(draft_model, input_ids, k=5):
    """使用草稿模型生成 K 个候选 token"""
    draft_ids = input_ids.clone()
    generated_tokens = []

    with torch.no_grad():
        for _ in range(k):
            outputs = draft_model(draft_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            draft_ids = torch.cat([draft_ids, next_token], dim=1)
            generated_tokens.append(next_token.item())

    return generated_tokens, draft_ids
3

实现并行验证逻辑

目标模型一次性验证所有草稿 token,使用拒绝采样决定接受哪些 token:

def verify_tokens(target_model, input_ids, draft_tokens):
    """使用目标模型验证草稿 token"""
    # 拼接草稿 token
    draft_ids = torch.cat([
        input_ids,
        torch.tensor([draft_tokens], device=input_ids.device)
    ], dim=1)

    with torch.no_grad():
        # 目标模型一次性前向传播
        outputs = target_model(draft_ids)
        logits = outputs.logits[0, input_ids.shape[1]-1:-1]
        probs = torch.softmax(logits, dim=-1)

    # 获取草稿模型的概率分布
    draft_probs = []
    for i, token in enumerate(draft_tokens):
        draft_outputs = draft_model(draft_ids[:, :input_ids.shape[1]+i])
        draft_logits = draft_outputs.logits[0, -1, :]
        draft_probs.append(torch.softmax(draft_logits, dim=-1))

    # 拒绝采样
    accepted_tokens = []
    for i, (draft_token, draft_prob) in enumerate(zip(draft_tokens, draft_probs)):
        p_target = probs[i, draft_token].item()
        p_draft = draft_prob[draft_token].item()

        if p_target >= p_draft:
            # 直接接受
            accepted_tokens.append(draft_token)
        else:
            # 以概率 p_target/p_draft 接受
            if torch.rand(1).item() < p_target / p_draft:
                accepted_tokens.append(draft_token)
            else:
                # 拒绝,从目标模型采样新 token
                new_token = torch.multinomial(probs[i], 1).item()
                accepted_tokens.append(new_token)
                break

    return accepted_tokens
拒绝采样验证流程图
4

整合推测解码主循环

将草稿生成和验证整合到完整的解码循环中:

def speculative_decoding(target_model, draft_model, input_ids, max_length=100, k=5):
    """推测解码主函数"""
    generated = input_ids.clone()

    while generated.shape[1] < max_length:
        # 步骤 1: 草稿模型生成 K 个 token
        draft_tokens, _ = generate_draft_tokens(draft_model, generated, k)

        # 步骤 2: 目标模型验证
        accepted = verify_tokens(target_model, generated, draft_tokens)

        # 步骤 3: 追加接受的 token
        generated = torch.cat([
            generated,
            torch.tensor([accepted], device=generated.device)
        ], dim=1)

        # 如果全部被拒绝,直接从目标模型生成
        if len(accepted) == 0:
            with torch.no_grad():
                outputs = target_model(generated)
                next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
                generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)

    return generated
推测解码完整流程图
5

测试与性能对比

运行推测解码并与传统自回归解码对比:

import time

prompt = "解释一下 Transformer 模型中的注意力机制是如何工作的"
input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda')

# 推测解码
start = time.time()
output_spec = speculative_decoding(target_model, draft_model, input_ids, max_length=100)
time_spec = time.time() - start
tokens_spec = output_spec.shape[1] - input_ids.shape[1]

# 传统自回归解码
start = time.time()
output_ar = target_model.generate(input_ids, max_length=100)
time_ar = time.time() - start
tokens_ar = output_ar.shape[1] - input_ids.shape[1]

print(f"推测解码:{tokens_spec} tokens, {time_spec:.2f}s, {tokens_spec/time_spec:.1f} tokens/s")
print(f"自回归解码:{tokens_ar} tokens, {time_ar:.2f}s, {tokens_ar/time_ar:.1f} tokens/s")
print(f"加速比:{time_ar/time_spec:.2f}x")
性能对比图表
6

使用 Hugging Face 内置支持

Transformers 库已内置推测解码支持,可以直接使用:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3-70b-Instruct')
assistant_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3-8b-Instruct')

input_text = "解释量子纠缠"
inputs = tokenizer(input_text, return_tensors='pt').to('cuda')

# 使用 assisted_decoding 参数启用推测解码
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_length=200,
    num_assistant_tokens=5  # 每次起草 5 个 token
)

print(tokenizer.decode(outputs[0]))
7

进阶:SSD 并行化优化

2026 年最新的 SSD(Speculative Speculative Decoding)进一步并行化起草和验证:

class SpeculativeSpeculativeDecoder:
    """SSD: 并行化起草和验证"""

    def __init__(self, target_model, draft_model):
        self.target = target_model
        self.draft = draft_model
        # 预测验证结果的模型
        self.predictor = self._build_predictor()

    def _build_predictor(self):
        """构建轻量级预测器,预测哪些草稿 token 会被接受"""
        # 可以使用小型 MLP 或查找表
        pass

    def decode_step(self, input_ids):
        # 异步并行执行
        # 1. 草稿模型生成当前 token
        # 2. 预测器提前预测验证结果
        # 3. 目标模型验证历史 token
        # 4. 根据预测结果预计算下一步草稿
        pass
SSD 并行化架构图

常见问题

接受率低怎么办?

接受率低说明草稿模型与目标模型差异太大。解决方案:

  • 选择同系列的模型(如 Llama-3-8B + Llama-3-70B)
  • 减少草稿 token 数量 K
  • 对草稿模型进行蒸馏微调,使其输出分布更接近目标模型
推测解码适合哪些场景?

推测解码特别适合:

  • 高延迟场景:实时对话、交互式应用
  • 大批量推理:需要服务多个用户
  • 长文本生成:文章写作、代码生成

不适合:对输出质量要求极端严格、不能有任何概率性损失的场景

显存占用会增加吗?

需要同时加载两个模型,显存占用会增加。但可以通过以下方式优化:

  • 使用量化:草稿模型用 INT8,目标模型用 FP16
  • 使用 offload:草稿模型放在 CPU,需要时再加载到 GPU
  • 选择显存占用小的草稿模型(如 1-3B 参数)

最佳实践总结

  • ✓ 选择同系列、同词汇表的模型组合,确保输出分布一致
  • ✓ 草稿 token 数量 K 通常设置为 4-8,根据接受率动态调整
  • ✓ 使用 GPU 时确保两个模型在同一设备上,避免数据传输开销
  • ✓ 对于代码生成任务,使用专门的代码模型作为草稿模型
  • ✓ 监控接受率指标,理想值为 60-80%
  • ✓ 考虑使用 SSD 等 2026 年最新变体,进一步去除序列化瓶颈

关键洞察:推测解码的核心价值在于用计算资源换时间——草稿模型的额外计算成本远低于目标模型,而一次验证多个 token 带来的并行度提升可以显著降低总体延迟。2026 年的 SSD 方案通过预测验证结果并预计算草稿,进一步将加速比提升到 5 倍以上。

选择栏目
今日简报 播客电台 实战教程 AI挣钱计划 关于我
栏目
全球AI日报国内AI日报全球金融日报国内金融日报全球大新闻日报国内大新闻日报Claude Code 玩法日报OpenClaw 动态日报GitHub 热门项目日报AI工具实战AI应用开发编程实战工作流自动化AI原理图解AI Agent开发AI变现案例库AI工具创收AI内容变现AI接单提效变现前沿研究
我的收藏