import math

import torch
import torch.nn as nn
class QwenAttention(nn.Module):
    """Qwen attention mechanism with Flash Attention (SDPA) support."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        self.rotary_emb = QwenRotaryEmbedding(
            dim=self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
        )
        # Fused QKV projection and output projection, both without bias.
        self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(self, x, attention_mask=None, position_ids=None):
        B, T, C = x.shape

        # Project to Q, K, V in one matmul, then split into heads:
        # (B, T, 3*C) -> (B, T, 3, num_heads, head_dim) -> 3 x (B, num_heads, T, head_dim)
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply rotary position embeddings to queries and keys.
        q, k = self.rotary_emb(q, k, position_ids)

        if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
            # PyTorch >= 2.0: fused SDPA kernel (uses Flash Attention when available).
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=0.1 if self.training else 0.0,
            )
        else:
            attn_output = self._fallback_attention(q, k, v, attention_mask)

        # Merge heads back: (B, num_heads, T, head_dim) -> (B, T, C) before o_proj.
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        return self.o_proj(attn_output)

    def _fallback_attention(self, q, k, v, mask):
        # Naive attention for older PyTorch versions without SDPA.
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn = attn + mask  # additive mask: 0 for allowed, -inf for blocked
        attn = nn.functional.softmax(attn, dim=-1)
        return torch.matmul(attn, v)
class QwenRotaryEmbedding(nn.Module):
    """Qwen rotary position embedding (RoPE)."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        # Inverse frequencies for each pair of dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, q, k, position_ids=None):
        if position_ids is None:
            T = q.shape[2]
            position_ids = torch.arange(T, device=q.device).unsqueeze(0)
        position_ids = position_ids.float()

        # (T, dim/2) angles; assumes every sequence in the batch shares the
        # same positions, so only position_ids[0] is used.
        freqs = torch.einsum("i,j->ij", position_ids[0], self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(q.dtype)
        sin = emb.sin().to(q.dtype)

        q_embed = self._rotate_half(q, cos, sin)
        k_embed = self._rotate_half(k, cos, sin)
        return q_embed, k_embed

    def _rotate_half(self, x, cos, sin):
        # Standard RoPE application: x * cos + rotate_half(x) * sin,
        # where rotate_half swaps the two halves and negates the second.
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * cos + rotated * sin
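# Usage sketch (not part of the original listing): a minimal smoke test, assuming
# a stand-in config object that only carries the attributes QwenAttention reads
# (hidden_size, num_attention_heads, max_position_embeddings). Real Qwen configs
# contain more fields.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=512,
        num_attention_heads=8,
        max_position_embeddings=2048,
    )
    attn = QwenAttention(config)

    # Batch of 2 sequences, 16 tokens each, with an additive causal mask:
    # 0 on and below the diagonal (allowed), -inf strictly above (blocked).
    x = torch.randn(2, 16, config.hidden_size)
    causal_mask = torch.full((16, 16), float("-inf")).triu(diagonal=1)
    out = attn(x, attention_mask=causal_mask)
    print(out.shape)  # torch.Size([2, 16, 512])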