Self-Attention Explained: From Mathematical Principles to a PyTorch Implementation


Overview

Self-attention is the core component of the Transformer architecture: it allows every position in a sequence to attend to all other positions in the same sequence. This article walks through the mathematics behind self-attention and reinforces it with a PyTorch implementation.

How Self-Attention Works

Core Idea

The core idea of self-attention is to project the input into three vectors, Query, Key, and Value, and use them to measure how strongly the elements of the sequence relate to one another.

flowchart LR
    subgraph Input
        X[Input sequence]
    end
    subgraph QKV[QKV projection]
        X --> WQ[Weight matrix WQ]
        X --> WK[Weight matrix WK]
        X --> WV[Weight matrix WV]
        WQ --> Q[Query vectors]
        WK --> K[Key vectors]
        WV --> V[Value vectors]
    end
    subgraph Attention[Attention computation]
        Q --> Dot[Dot product]
        K --> Dot
        Dot --> Soft[Softmax]
        Soft --> Attn[Attention weights]
        V --> Mul[Weighted sum]
        Attn --> Mul
        Mul --> Out[Output]
    end

Mathematical Formulation

The self-attention computation is:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

where $d_k$ is the dimensionality of the key vectors and $\sqrt{d_k}$ is a scaling factor that keeps the dot products from growing too large; without it, the softmax saturates and the gradients flowing through it become vanishingly small.
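
To make the role of the scaling factor concrete, here is a small numerical check (an illustration, not part of the original derivation): for random vectors with unit-variance entries, the raw dot product has standard deviation roughly $\sqrt{d_k}$, while the scaled score stays near 1 regardless of $d_k$.

import torch

# Illustrative check of why the 1/sqrt(d_k) scale matters: the variance of a
# dot product of two unit-variance random vectors grows linearly with d_k,
# so dividing by sqrt(d_k) keeps the pre-softmax scores at unit scale.
torch.manual_seed(0)
for d_k in (16, 64, 256):
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(dim=-1)          # un-scaled dot products
    scaled = raw / d_k ** 0.5          # scaled dot products
    print(f"d_k={d_k:4d}  std(raw)={raw.std().item():6.2f}  std(scaled)={scaled.std().item():5.2f}")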

Scaled Dot-Product Attention in Detail

sequenceDiagram
    participant Q as Query
    participant K as Key
    participant V as Value
    participant Out as Output
    
    Note over Q,V: 1. Compute dot products
    Q->>K: Q · K^T
    Note over Q,K: Scale by 1/√d_k
    
    Note over K: 2. Softmax normalization
    K->>K: softmax(QK^T/√d_k)
    
    Note over K,V: 3. Weighted sum
    V->>Out: attention_weights × V
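
The three steps in the diagram map directly onto a few lines of PyTorch. Here is a minimal functional sketch (the helper name scaled_dot_product is mine, not a library function):

import math
import torch
import torch.nn.functional as F

def scaled_dot_product(q, k, v, mask=None):
    """Minimal functional sketch of the three steps above."""
    d_k = q.size(-1)
    # 1. Scaled dot products between queries and keys
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # 2. Softmax over the key dimension
    weights = F.softmax(scores, dim=-1)
    # 3. Weighted sum of the values
    return torch.matmul(weights, v), weights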

PyTorch Implementation

A Basic Self-Attention Layer

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    """Scaled dot-product self-attention."""

    def __init__(self, embed_dim, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = nn.Dropout(dropout)

        # Q, K, V projection matrices
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
            mask: [batch_size, seq_len, seq_len] optional mask (0 = blocked)
        Returns:
            output: [batch_size, seq_len, embed_dim]
            attention_weights: [batch_size, seq_len, seq_len]
        """
        batch_size, seq_len, _ = x.shape

        # Project input to Q, K, V
        Q = self.q_proj(x)  # [B, L, D]
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.embed_dim)

        # Apply mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax normalization
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)
        output = self.out_proj(output)

        return output, attention_weights
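
A quick sanity check of the layer (the shapes below are chosen purely for illustration): every row of the returned attention weights should sum to 1, since softmax normalizes over the key dimension.

# Sanity check of SelfAttention; eval() disables dropout so the row sums hold exactly
layer = SelfAttention(embed_dim=64)
layer.eval()
x = torch.randn(2, 10, 64)       # [batch, seq_len, embed_dim]
out, weights = layer(x)
print(out.shape)                 # torch.Size([2, 10, 64])
print(weights.shape)             # torch.Size([2, 10, 10])
print(weights.sum(dim=-1))       # every row sums to 1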

Multi-Head Attention

flowchart TB
    subgraph MHA[Multi-head attention]
        subgraph H1[Head 1]
            X1[Input] --> Q1[Q₁]
            X1 --> K1[K₁]
            X1 --> V1[V₁]
            Q1 --> Attn1[Attention]
            K1 --> Attn1
            V1 --> Attn1
            Attn1 --> O1[O₁]
        end
        subgraph H2[Head 2]
            X2[Input] --> Q2[Q₂]
            X2 --> K2[K₂]
            X2 --> V2[V₂]
            Q2 --> Attn2[Attention]
            K2 --> Attn2
            V2 --> Attn2
            Attn2 --> O2[O₂]
        end
        subgraph Hh[Head h]
            Xh[Input] --> Qh[Qₕ]
            Xh --> Kh[Kₕ]
            Xh --> Vh[Vₕ]
            Qh --> Attnh[Attention]
            Kh --> Attnh
            Vh --> Attnh
            Attnh --> Oh[Oₕ]
        end
    end
    
    O1 --> Concat[Concatenate]
    O2 --> Concat
    Oh --> Concat
    Concat --> W[Weight matrix WO]
    W --> Output[Final output]

Multi-Head Attention Implementation

class MultiHeadAttention(nn.Module):
    """Multi-head attention."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Project and split into heads: [B, L, D] -> [B, H, L, D/H]
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum, then merge the heads back together
        context = torch.matmul(attention_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        output = self.out_proj(context)

        return output, attention_weights
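
The mask argument is accepted above but never constructed. A common case is a causal (autoregressive) mask that blocks attention to future positions. The sketch below follows the masked_fill(mask == 0, ...) convention used in the class, so 0 means "blocked"; the helper causal_mask is illustrative and not part of the original code.

# Building a causal mask for MultiHeadAttention: shape [1, 1, L, L]
# broadcasts over the batch and head dimensions of the score tensor.
def causal_mask(seq_len):
    # Lower-triangular matrix: 1 = may attend, 0 = blocked (future position)
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask.view(1, 1, seq_len, seq_len)

mha = MultiHeadAttention(embed_dim=64, num_heads=4)
x = torch.randn(2, 10, 64)
out, w = mha(x, mask=causal_mask(10))
print(w[0, 0])   # upper triangle is 0: no attention to future tokens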

Performance Comparison

| Implementation | Parameters | Compute complexity | Sequence length |
|---|---|---|---|
| Standard attention | O(d²) | O(n²·d) | limited for long sequences |
| Linear attention | O(d) | O(n·d²) | handles very long sequences |
| Flash Attention | O(d²) | O(n²·d) | memory-optimized variant |
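
Regarding the memory-optimized row: PyTorch 2.0+ exposes a fused kernel, torch.nn.functional.scaled_dot_product_attention, which can dispatch to a Flash-Attention-style backend on supported GPUs. A brief sketch (backend selection depends on your hardware and PyTorch build):

import torch
import torch.nn.functional as F

# Fused scaled dot-product attention; may use a Flash-Attention-style kernel
# when the hardware and PyTorch build support it.
q = torch.randn(16, 8, 100, 64)   # [batch, heads, seq_len, head_dim]
k = torch.randn(16, 8, 100, 64)
v = torch.randn(16, 8, 100, 64)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                  # torch.Size([16, 8, 100, 64])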

Usage Example

# Usage example
embed_dim = 512
num_heads = 8
batch_size = 16
seq_len = 100

model = MultiHeadAttention(embed_dim, num_heads)
x = torch.randn(batch_size, seq_len, embed_dim)

output, attn_weights = model(x)
print(f"Output shape: {output.shape}")                    # [16, 100, 512]
print(f"Attention weights shape: {attn_weights.shape}")   # [16, 8, 100, 100]

Summary

mindmap
  root((Self-Attention))
    Core principles
      Query-Key-Value
      Dot-product attention
      Scaling factor
    Multi-head attention
      Parallel computation
      Attention heads
      Concatenated output
    Applications
      Transformer
      BERT
      GPT
      Vision Transformer
    Optimizations
      Masking
      Sparse attention
      Flash Attention

Self-attention is one of the cornerstones of modern deep learning, and understanding how it works is essential for mastering the Transformer architecture.
