首页 / AI原理图解 / 图解 Transformer 注意力机制:从 Self-Attention 到 Multi-Head 1 次阅读
图解 Transformer 注意力机制:从 Self-Attention 到 Multi-Head
AI 原理图解

图解 Transformer 注意力机制
从 Self-Attention 到 Multi-Head

通过可视化图解和 PyTorch 代码实战,深入理解 Transformer 注意力机制的核心原理。掌握 Q/K/V 矩阵计算、Multi-Head 设计动机,以及 2026 年最新高效注意力变体。

2026 年 3 月 1 日
Transformer · 注意力机制 · PyTorch
约 15 分钟阅读

当你阅读这段文字时,你的大脑并不会平等地处理每个字——你会自然地将注意力集中在关键信息上。这正是注意力机制的核心直觉。2017 年,《Attention Is All You Need》论文彻底改变了深度学习,而今天,注意力机制已成为大语言模型的基石。

本教程你将学到

  • Self-Attention 的 Q/K/V 矩阵计算与数学直觉
  • Multi-Head Attention 的设计动机与并行计算优势
  • 如何用 PyTorch 从零实现注意力机制
  • 注意力权重的可视化方法与调试技巧
  • 2026 年最新高效注意力变体(FlashAttention、GQA)

理解注意力的直觉:从人类认知到数学公式

想象你在阅读这句话:"The pizza came out of the oven and it tasted good."当你的大脑处理代词"it"时,会自动将它与"pizza"而非"oven"建立联系——这种动态关联上下文的能力就是注意力的本质。

在传统 RNN 中,信息必须按顺序传递,就像接力赛一样,这不仅效率低下,还容易导致长距离依赖的丢失。而注意力机制允许模型直接跨越任意距离建立连接,实现了真正的并行计算。

RNN 与 Transformer 处理序列的对比:RNN 是串行处理,Transformer 是并行全连接
图 1:RNN 的序列化处理(左)vs Transformer 的并行注意力(右)

注意力机制的数学表达出奇地优雅:给定一个查询(Query)和一组键值对(Key-Value pairs),注意力输出是值的加权和,权重由查询与对应键的相似度决定。

Q/K/V 三要素:注意力机制的核心组件

Self-Attention 的核心是将输入转换为三个向量:Query(查询)、Key(键)、Value(值)。这个设计灵感来源于信息检索系统——当你搜索信息时,你提供一个查询(Query),系统用它匹配数据库中的键(Key),找到匹配后返回对应的值(Value)。

向量符号维度物理含义
QueryQ(seq_len, d_k)查询向量,表示"我想找什么"
KeyK(seq_len, d_k)键向量,表示"我有什么信息"
ValueV(seq_len, d_v)值向量,表示"实际的内容"

注意力公式如下:

Attention(Q, K, V) = softmax(QK^T / √d_k) · V

其中 Q = XW_Q,K = XW_K,V = XW_V,W_Q、W_K、W_V 是可学习的投影矩阵。

Q/K/V 矩阵计算流程图:输入 X 经过三个线性变换得到 Q、K、V
图 2:Q/K/V 矩阵的生成与注意力计算流程

缩放点积注意力:为什么需要除以 √d_k?

当你第一次看到公式中的 / √d_k 时,可能会疑惑:这个缩放因子是做什么的?答案关乎数值稳定性

当 d_k 较大时,QK^T 的点积结果会变得很大。考虑极端情况:两个 512 维向量的点积可能在 [-512, 512] 范围内。当这个值输入 softmax 函数时,会导致梯度消失——softmax 在输入值很大或很小时会变得非常"平坦",梯度接近于零。

import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        # 1. 计算相似度分数 QK^T
        scores = torch.matmul(query, key.transpose(-2, -1))
        # 2. 缩放:防止点积过大导致 softmax 梯度消失
        scaled_scores = scores / math.sqrt(d_k)
        # 3. Mask(用于 decoder 或 padding)
        if mask is not None:
            scaled_scores = scaled_scores.masked_fill(mask == 0, -1e9)
        # 4. Softmax 归一化得到注意力权重
        attention_weights = torch.softmax(scaled_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # 5. 加权求和
        output = torch.matmul(attention_weights, value)
        return output, attention_weights
缩放因子对 softmax 梯度的影响:不同 d_k 下的梯度分布对比
图 3:缩放因子对 softmax 梯度的影响——d_k 越大,越需要缩放

Multi-Head Attention:多个视角同时学习

Single-Head 注意力已经能工作了,但为什么 Transformer 要使用 Multi-Head?答案是:让模型在多个不同的子空间中同时学习注意力

不同的"注意力头"可能关注不同类型的信息:头 1 可能关注语法结构(主谓关系),头 2 可能关注语义关联(代词指代),头 3 可能关注位置关系(相邻词)。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # 合并 QKV 投影以提高效率
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.output_proj = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        # 生成 QKV 并 reshape 为多头格式
        qkv = self.qkv_proj(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]
        # Self-Attention(每个头并行计算)
        attn_output, attn_weights = self.attention(query, key, value, mask)
        # 合并头并输出
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.output_proj(attn_output), attn_weights
Multi-Head Attention 结构图:多个头并行计算注意力然后拼接
图 4:Multi-Head Attention 结构——h 个头并行计算,然后拼接并通过线性层

注意力权重可视化:看见模型的"注意力"

注意力权重是一个矩阵,shape 为 (seq_len, seq_len),每个元素表示某个 token 对另一个 token 的关注程度。可视化这个矩阵,我们就能"看见"模型在关注什么。

import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention_heatmap(attention_weights, tokens, head_idx=0, layer_idx=0):
    # 提取特定层和头的注意力权重
    attn = attention_weights[layer_idx][0][head_idx].cpu().numpy()
    plt.figure(figsize=(10, 8))
    sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='YlOrRd')
    plt.title(f"Layer {layer_idx}, Head {head_idx}")
    plt.xlabel("Keys"); plt.ylabel("Queries")
    plt.savefig(f"attention_l{layer_idx}_h{head_idx}.png", dpi=150)
    plt.show()
注意力权重热力图:展示每个 token 对其他 token 的关注程度
图 5:注意力权重热力图——颜色越深表示注意力越强

常见的注意力模式:对角线模式(自注意力)、垂直线模式(某些 token 被全局关注)、块状模式(短语内部注意力)、长距离模式(跨句子依赖)。

2026 年高效注意力变体:FlashAttention 与 GQA

标准注意力(MHA)的 KV Cache 大小为 O(h·n·d),对于长序列和大头数会消耗大量内存。2026 年,两种优化成为主流:

变体全称KV Cache质量适用场景
MHAMulti-Head AttentionO(h·n·d)100%训练、高质量推理
MQAMulti-Query AttentionO(n·d)~85%高吞吐推理
GQAGrouped-Query AttentionO(g·n·d)~95%Llama 3、Mistral
FlashAttnIO-Aware Attention同 MHA100% + 更快训练和推理加速

GQA(Grouped-Query Attention) 是 Llama 3 的标准配置:32 个 Query 头共享 8 个 KV 头,在保持质量的同时减少 4 倍 KV Cache。

FlashAttention-3 针对 H100/H200 优化,通过 Tiling 分块、Kernel Fusion 内核融合、Recomputation 重计算,实现 7-10 倍速度提升。

# 使用 FlashAttention-2/3
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# PyTorch 2.0+ SDPA(自动选择最优后端)
import torch.nn.functional as F
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

常见问题 FAQ

Q: 为什么注意力权重需要除以 √d_k?
当 d_k 较大时,QK^T 的点积会很大,导致 softmax 进入梯度饱和区。除以 √d_k 可以保持数值稳定性,防止梯度消失。
Q: Multi-Head 的头数越多越好吗?
不是。头数增加会增加计算量和参数量。常见配置:小模型用 8 头,大模型用 16-32 头。Llama 3 8B 用 32 头,70B 用 64 头。
Q: 注意力坍塌是什么?如何解决?
注意力过度集中在单个 token 上,导致训练不稳定。解决方案:使用 warmup 学习率调度、监控注意力熵、使用 gradient clipping。

核心收获

  • Self-Attention 通过 Q/K/V 矩阵实现动态上下文关联,公式为 softmax(QK^T/√d_k)V
  • Multi-Head 让模型在多个子空间并行学习注意力,增强表达能力
  • 缩放因子 √d_k 防止点积过大导致 softmax 梯度消失
  • 注意力权重可视化可以"看见"模型的关注模式
  • GQA 和 FlashAttention 是 2026 年主流高效注意力变体

立即开始实践

打开 PyTorch,从零实现一个 Multi-Head Attention 层,用 BertViz 可视化工具观察注意力权重。动手写一遍代码,比读十篇教程都有效。

选择栏目
今日简报 播客电台 实战教程 AI挣钱计划 关于我
栏目
全球AI日报国内AI日报全球金融日报国内金融日报全球大新闻日报国内大新闻日报Claude Code 玩法日报OpenClaw 动态日报GitHub 热门项目日报AI工具实战AI应用开发编程实战工作流自动化AI原理图解AI Agent开发AI变现案例库AI工具创收AI内容变现AI接单提效变现前沿研究
我的收藏
播客版
0:00
--:--