import torch
import torch.nn as nn
import math
class MistralConfig:
    """Mistral 7B configuration."""

    def __init__(self):
        self.vocab_size = 32000
        self.hidden_size = 4096
        self.intermediate_size = 14336
        self.num_hidden_layers = 32
        self.num_attention_heads = 32
        self.num_key_value_heads = 8
        self.hidden_act = "silu"
        self.max_position_embeddings = 32768
        self.rope_theta = 10000.0
        self.sliding_window = 4096
        self.rope_scaling = {"type": "linear", "factor": 2.0}
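Because num_key_value_heads (8) is smaller than num_attention_heads (32), each K/V projection maps 4096 to 8 × 128 = 1024 dimensions instead of 4096 to 4096, and the per-layer KV cache shrinks by the same factor of 4. A minimal sketch of this arithmetic, using only the config fields above (run it after the listing):

cfg = MistralConfig()
head_dim = cfg.hidden_size // cfg.num_attention_heads   # 128
q_out = cfg.num_attention_heads * head_dim               # 4096
kv_out = cfg.num_key_value_heads * head_dim               # 1024
print(q_out, kv_out, q_out // kv_out)                     # 4096 1024 4
# Under GQA the K/V projections and the KV cache are 4x smaller than with full MHA.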
class MistralAttention(nn.Module):
    """Mistral attention with sliding-window masking and Grouped Query Attention (GQA)."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.sliding_window = config.sliding_window
        # Mistral uses bias-free projections; K/V project to fewer heads than Q (GQA)
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        self.rotary_emb = MistralRotaryEmbedding(
            self.head_dim, config.max_position_embeddings, base=config.rope_theta
        )

    def forward(self, x, attention_mask=None, position_ids=None):
        B, T, C = x.shape
        if position_ids is None:
            position_ids = torch.arange(T, device=x.device).unsqueeze(0)

        q = self.q_proj(x).reshape(B, T, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(B, T, self.num_key_value_heads, self.head_dim)
        v = self.v_proj(x).reshape(B, T, self.num_key_value_heads, self.head_dim)

        # Rotary position embeddings are applied to Q and K before attention
        q, k = self.rotary_emb(q, k, position_ids)

        # (B, T, heads, head_dim) -> (B, heads, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # GQA: replicate each K/V head so the head count matches Q
        k = self._repeat_kv(k, self.num_heads // self.num_key_value_heads)
        v = self._repeat_kv(v, self.num_heads // self.num_key_value_heads)

        # Boolean mask for scaled_dot_product_attention: True means "may attend".
        # Query i attends key j only if j <= i (causal) and i - j < sliding_window.
        pos = torch.arange(T, device=x.device)
        mask = pos[None, :] <= pos[:, None]
        if self.sliding_window is not None:
            mask = mask & (pos[:, None] - pos[None, :] < self.sliding_window)
        if attention_mask is not None:
            # Optional extra mask (e.g. padding), expected broadcastable to (T, T)
            mask = mask & attention_mask

        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return self.o_proj(attn.transpose(1, 2).reshape(B, T, C))

    def _repeat_kv(self, x, n_rep):
        """Repeat K/V heads along the head dimension to match the number of Q heads."""
        B, n_kv_heads, T, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, None, :, :]
            .expand(B, n_kv_heads, n_rep, T, head_dim)
            .reshape(B, n_kv_heads * n_rep, T, head_dim)
        )
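To make the sliding-window pattern concrete, here is a standalone sketch of the boolean mask built in forward, with toy values T = 6 and a window of 3 (these sizes are chosen only for illustration):

import torch

T, window = 6, 3
pos = torch.arange(T)
mask = (pos[None, :] <= pos[:, None]) & (pos[:, None] - pos[None, :] < window)
print(mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
# Row i (the query) may attend only to keys j with i - window < j <= i,
# so each token sees at most `window` past positions.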
class MistralRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE)."""

    def __init__(self, dim, max_position=32768, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, q, k, position_ids):
        # (T, dim/2) rotation angles, duplicated to cover the full head dimension
        freqs = torch.einsum("i,j->ij", position_ids[0].float(), self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Shape (1, T, 1, dim) so cos/sin broadcast over batch and head dimensions
        cos = emb.cos()[None, :, None, :]
        sin = emb.sin()[None, :, None, :]
        return self._rotate(q, cos, sin), self._rotate(k, cos, sin)

    def _rotate(self, x, cos, sin):
        # Standard RoPE rotation: x * cos + rotate_half(x) * sin, where rotate_half
        # swaps the two halves of the last dimension and negates the first half.
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * cos + rotated * sin
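To see the pieces fit together, a small smoke test can exercise the full forward pass. This is a sketch with arbitrary toy dimensions (not the real 7B sizes) and assumes the classes above are defined in the same file:

# Shrink the config so the test runs instantly on CPU
cfg = MistralConfig()
cfg.hidden_size = 256
cfg.num_attention_heads = 8
cfg.num_key_value_heads = 2
cfg.sliding_window = 16

attn = MistralAttention(cfg)
x = torch.randn(2, 32, cfg.hidden_size)   # (batch, seq_len, hidden)
out = attn(x)
print(out.shape)                           # torch.Size([2, 32, 256])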