```python
import torch
import torch.nn as nn
from transformers import PretrainedConfig
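
# NOTE: RotaryEmbedding is referenced below but was not defined in the
# original listing. This is a minimal sketch of standard rotary position
# embeddings (RoPE), added so the listing runs end to end; it is not
# guaranteed to match GLM's exact RoPE variant.
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # Inverse frequencies, one per pair of channels
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, q, k, position_ids):
        # position_ids: (B, L) -> per-position angles: (B, L, dim/2)
        freqs = position_ids[..., None].float() * self.inv_freq
        # Duplicate to (B, L, 1, dim) so it broadcasts over the head axis
        emb = torch.cat((freqs, freqs), dim=-1).unsqueeze(2)
        cos, sin = emb.cos(), emb.sin()
        return self._rotate(q, cos, sin), self._rotate(k, cos, sin)

    @staticmethod
    def _rotate(x, cos, sin):
        # Split channels into two halves and apply the rotation
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * cos + rotated * sin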

class GLMAttention(nn.Module):
    """GLM self-attention, supporting the mixed long/short attention pattern."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.rotary_emb = RotaryEmbedding(self.head_dim)
        # Fused projection producing Q, K and V in a single matmul
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, position_ids, attention_mask=None):
        B, L, _ = hidden_states.shape
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.reshape(B, L, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)                      # each: (B, L, H, D)
        q, k = self.rotary_emb(q, k, position_ids)   # rotary position encoding
        # Move the head axis in front of the sequence axis so the
        # attention matmuls contract over sequence length, not heads
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # (B, H, L, D)
        attn_output = self._attn(q, k, v, attention_mask)  # (B, H, L, D)
        # Merge the heads back into the hidden dimension
        attn_output = attn_output.transpose(1, 2).reshape(B, L, self.hidden_size)
        return self.dense(attn_output)

    def _attn(self, q, k, v, attention_mask):
        # Scaled dot-product attention
        scale = 1.0 / (self.head_dim ** 0.5)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B, H, L, L)
        if attention_mask is not None:
            # Additive mask: masked positions hold large negative values
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        return torch.matmul(attn_weights, v)

class GLMBlock(nn.Module):
    """A single GLM Transformer block (pre-norm attention + MLP)."""

    def __init__(self, config):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.attention = GLMAttention(config)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        # Feed-forward network; intermediate_size is already the expanded
        # width (commonly 4 * hidden_size), so no extra factor is applied
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(self, hidden_states, position_ids, attention_mask=None):
        # Attention sub-layer with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.attention(hidden_states, position_ids, attention_mask)
        hidden_states = residual + hidden_states
        # MLP sub-layer with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
```
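
As a quick sanity check, the sketch below pushes a random batch through one GLMBlock. The `DemoConfig` here is a made-up stand-in for the real model config; only the three field names the classes above actually read (`hidden_size`, `num_attention_heads`, `intermediate_size`) are assumed.

```python
# Minimal smoke test; DemoConfig is a hypothetical stand-in for the real
# GLM config object, not part of the original code.
from dataclasses import dataclass

@dataclass
class DemoConfig:
    hidden_size: int = 256
    num_attention_heads: int = 8
    intermediate_size: int = 1024  # the usual 4 * hidden_size expansion

config = DemoConfig()
block = GLMBlock(config)

B, L = 2, 16
hidden_states = torch.randn(B, L, config.hidden_size)
position_ids = torch.arange(L).unsqueeze(0).expand(B, L)

out = block(hidden_states, position_ids)
print(out.shape)  # torch.Size([2, 16, 256]) -- same shape in, same shape out
```

The residual structure guarantees the block is shape-preserving, which is what lets GLM stack dozens of identical blocks into a deep model.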