```python
class GPT3Config:
    """Configuration for the 175B-parameter GPT-3 model."""

    def __init__(self):
        self.n_vocab = 50257        # BPE vocabulary size (GPT-2 tokenizer)
        self.n_ctx = 2048           # context window
        self.n_positions = 2048     # maximum position embeddings
        self.n_embd = 12288         # model (embedding) dimension
        self.n_layer = 96           # transformer layers
        self.n_head = 96            # attention heads
        self.n_head_size = 128      # dimension per head (12288 / 96)
        self.learning_rate = 1e-5
        self.batch_size_tokens = 3_200_000    # 3.2M tokens per batch
        self.train_tokens = 300_000_000_000   # 300B training tokens
        self.afn = "gelu"           # activation function
        self.resid_dropout = 0.1
        self.embd_dropout = 0.1
        self.attn_dropout = 0.1
```
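As a sanity check on these numbers, the standard rough estimate of 12 · n_layer · n_embd² parameters per stack of transformer blocks, plus the token and position embedding matrices, lands at about 175B, matching the published model size. A minimal sketch (the helper name below is hypothetical, not part of the config class):

```python
def approx_param_count(cfg: GPT3Config) -> int:
    """Rough GPT-3 parameter estimate, not an exact count: each block has
    ~12 * n_embd^2 weights (attention QKV + output projections ~4 * n_embd^2,
    MLP with 4x expansion ~8 * n_embd^2), plus embedding tables."""
    blocks = 12 * cfg.n_layer * cfg.n_embd ** 2               # ~173.9B
    embeddings = (cfg.n_vocab + cfg.n_positions) * cfg.n_embd  # ~0.6B
    return blocks + embeddings

print(f"{approx_param_count(GPT3Config()) / 1e9:.1f}B")  # ~174.6B parameters
```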
```python
import math

import torch
import torch.nn as nn


class GPT3SparseAttention(nn.Module):
    """Sparse attention as used in GPT-3: a local causal band combined with a
    strided pattern of periodic "summary" positions. Both patterns are merged
    into a single mask so one softmax covers local and global connections."""

    def __init__(self, n_embd, n_head, n_ctx, dropout=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_ctx = n_ctx
        self.d_k = n_embd // n_head
        self.qkv = nn.Linear(n_embd, n_embd * 3)
        self.out_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Note: with window_size == n_ctx the local band covers the entire
        # causal context; smaller values (e.g. 128) make the pattern sparse.
        self.window_size = 2048
        self.stride = 64

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_head, T, d_k)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention under the combined sparse causal mask.
        sparse_mask = self._create_sparse_mask(T, x.device)  # (T, T) bool
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(~sparse_mask, float('-inf'))
        if attention_mask is not None:
            # Optional padding mask, broadcastable to (B, n_head, T, T); 0 = masked.
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.dropout(self.out_proj(out))

    def _create_sparse_mask(self, seq_len, device):
        """Build the sparse attention mask: each position attends to its local
        window and to every stride-th earlier position, never to the future."""
        idx = torch.arange(seq_len, device=device)
        rel = idx.unsqueeze(1) - idx.unsqueeze(0)  # rel[i, j] = i - j
        causal = rel >= 0
        local = causal & (rel <= self.window_size)
        strided = causal & (rel % self.stride == 0)
        return local | strided
```
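A quick smoke test of the module above, using deliberately tiny illustrative dimensions rather than GPT-3 scale, and a window smaller than the context so the pattern is genuinely sparse (all values here are assumptions for the demo):

```python
torch.manual_seed(0)

attn = GPT3SparseAttention(n_embd=256, n_head=8, n_ctx=512, dropout=0.1)
attn.window_size, attn.stride = 64, 16   # sparse at this small scale

x = torch.randn(2, 512, 256)             # (batch, seq_len, n_embd)
out = attn(x)
print(out.shape)                         # torch.Size([2, 512, 256])

mask = attn._create_sparse_mask(512, x.device)
print(f"mask density: {mask.float().mean():.2%}")  # well under the ~50% of dense causal
```

Merging the local and strided patterns into one mask before a single softmax is what keeps the result a proper probability distribution over attended positions; summing two independently softmaxed attention outputs would double the output magnitude.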