```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RelativePositionalEncoding(nn.Module):
    """Relative positional encoding (proposed by Shaw et al., 2018)."""

    def __init__(self, d_model, max_relative_position=64):
        super().__init__()
        self.d_model = d_model
        self.max_relative_position = max_relative_position
        # One learned embedding per clipped relative distance in
        # [-max_relative_position, +max_relative_position].
        self.relative_embeddings = nn.Parameter(
            torch.randn(2 * max_relative_position + 1, d_model)
        )

    def forward(self, length):
        """Generate a (length, length, d_model) table of relative position embeddings."""
        range_vec = torch.arange(length, device=self.relative_embeddings.device)
        # relative_indices[i, j] = j - i, clipped to +/- max_relative_position
        relative_indices = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
        relative_indices = relative_indices.clamp(
            -self.max_relative_position, self.max_relative_position
        )
        relative_indices = relative_indices + self.max_relative_position
        return self.relative_embeddings[relative_indices]


class ShawAttention(nn.Module):
    """Multi-head attention with relative position representations (Shaw et al.)."""

    def __init__(self, d_model, num_heads, max_relative_position=64):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.max_relative_position = max_relative_position

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # The relative embeddings live in the per-head space (d_k) and are
        # shared across heads, as in Shaw et al.
        self.relative_embeddings = RelativePositionalEncoding(
            self.d_k, max_relative_position
        )

    def forward(self, query, key, value, mask=None):
        # Assumes self-attention: query and key have the same sequence length.
        batch_size = query.size(0)

        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Content-content term: (batch, heads, length, length)
        scores = torch.matmul(Q, K.transpose(-2, -1))

        # Content-position term: dot product of each query q_i with the
        # embedding of the clipped relative distance j - i.
        length = query.size(1)
        relative_encoding = self.relative_embeddings(length)  # (length, length, d_k)
        relative_scores = torch.einsum('bhid,ijd->bhij', Q, relative_encoding)

        # Scale both terms together, as in the paper.
        scores = (scores + relative_scores) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)
```
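A quick smoke test, assuming the two classes above are defined in the same module; the model size, head count, and tensor shapes below are arbitrary illustrative values, not anything prescribed by the text:

```python
# Illustrative only: hyperparameters and shapes are arbitrary.
attn = ShawAttention(d_model=512, num_heads=8, max_relative_position=64)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out = attn(x, x, x)           # self-attention, no mask
print(out.shape)              # expected: torch.Size([2, 10, 512])
```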