import torch
import torch.nn as nn
import math
class RMSNorm(nn.Module):
    """
    RMSNorm: a more efficient normalization method.
    LLaMA uses RMSNorm in place of LayerNorm to improve training stability.
    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the RMS statistic in float32 for numerical stability
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
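
# Sanity-check sketch (not part of the original listing; variable names are illustrative):
# RMSNorm only rescales each feature vector by its root-mean-square. Unlike nn.LayerNorm
# it performs no mean subtraction and has no bias parameter, which is slightly cheaper.
_x = torch.randn(2, 4, 8)
_rms = RMSNorm(hidden_size=8)
_out = _rms(_x)
assert _out.shape == _x.shape  # shape is preserved; with weight=1 the output has ~unit RMS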
class LLaMAAttention(nn.Module):
    """
    LLaMA self-attention using RoPE (rotary position embedding).
    """
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.hidden_size // config.n_heads
        self.max_position_embeddings = config.max_position_embeddings
        # LLaMA uses bias-free linear projections
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim)

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        # Project and reshape to (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Apply rotary position embedding to queries and keys
        q, k = self.rotary_emb(q, k)
        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Merge heads back to (B, T, C) and apply the output projection
        return self.o_proj(attn_output.transpose(1, 2).contiguous().view(B, T, C))
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE): encodes positions by rotating
    query/key vectors, handling positional information effectively.
    """
    def __init__(self, dim, max_position_embeddings=2048):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, q, k):
        B, n_heads, T, head_dim = q.shape
        t = torch.arange(T, device=q.device).type_as(self.inv_freq)
        # Angles for each (position, frequency) pair, shape (T, dim/2)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate to cover the full head dimension, shape (T, dim)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos(), emb.sin()
        # RoPE: x * cos + rotate_half(x) * sin, broadcast over (B, n_heads)
        q_embed = q * cos + self._rotate_half(q) * sin
        k_embed = k * cos + self._rotate_half(k) * sin
        return q_embed, k_embed

    @staticmethod
    def _rotate_half(x):
        # Swap the two halves of the last dimension, negating the second half
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
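
# Usage sketch (assumption, not in the original listing): LLaMAAttention expects a
# config object exposing hidden_size, n_heads, and max_position_embeddings. The
# LLaMAConfig dataclass below is a hypothetical stand-in; any object with these
# fields would work. The forward pass uses the pre-norm layout LLaMA applies
# inside each transformer block (RMSNorm before attention).
from dataclasses import dataclass

@dataclass
class LLaMAConfig:  # hypothetical helper for this sketch
    hidden_size: int = 512
    n_heads: int = 8
    max_position_embeddings: int = 2048

config = LLaMAConfig()
attn = LLaMAAttention(config)
norm = RMSNorm(config.hidden_size)

x = torch.randn(2, 16, config.hidden_size)  # (batch, seq_len, hidden_size)
out = attn(norm(x))                          # pre-norm: RMSNorm feeds the attention
print(out.shape)                             # torch.Size([2, 16, 512])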