import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PatchEmbedding(nn.Module):
    """Image-to-patch embedding."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # A patch_size x patch_size convolution with stride patch_size splits the image
        # into non-overlapping patches and projects each patch to embed_dim in one step,
        # so no additional linear projection is needed afterwards.
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W)
        Returns:
            (B, num_patches, embed_dim)
        """
        x = self.proj(x)        # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)        # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (B, num_patches, embed_dim)
        return x
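
# Example usage (a minimal sketch; the batch size and input resolution below are
# illustrative assumptions): with the default settings, a 224x224 RGB image yields
# (224 / 16) ** 2 = 196 patch tokens, each of dimension 768.
#
#   patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
#   tokens = patch_embed(torch.randn(2, 3, 224, 224))
#   print(tokens.shape)   # torch.Size([2, 196, 768])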
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, embed_dim=768, num_heads=8, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads
        self.scale = self.d_k ** -0.5

        # One linear layer produces Q, K and V in a single pass.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)                 # (3, B, num_heads, N, d_k)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attn = (q @ k.transpose(-2, -1)) * self.scale    # (B, num_heads, N, N)
        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # concatenate the heads
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
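
# Example usage (a minimal sketch, assuming the 196 patch tokens from the
# PatchEmbedding example above): the module returns both the attended tokens and
# the per-head attention maps.
#
#   msa = MultiHeadSelfAttention(embed_dim=768, num_heads=8)
#   out, weights = msa(torch.randn(2, 196, 768))
#   print(out.shape)       # torch.Size([2, 196, 768])
#   print(weights.shape)   # torch.Size([2, 8, 196, 196]) -- one attention map per head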
class TransformerBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(self, embed_dim=768, num_heads=8, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_dropout)

        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, return_attention=False):
        # Pre-norm residual connections: x + Attn(LN(x)), then x + MLP(LN(x))
        attn_out, attn_weights = self.attn(self.norm1(x))
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        if return_attention:
            return x, attn_weights
        return x
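

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the batch size, image size and hyper-parameters
    # here are illustrative assumptions): embed a batch of images into patch tokens
    # and pass them through a single encoder block, printing the resulting shapes.
    imgs = torch.randn(2, 3, 224, 224)
    patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
    block = TransformerBlock(embed_dim=768, num_heads=8)

    tokens = patch_embed(imgs)                            # (2, 196, 768)
    tokens, attn = block(tokens, return_attention=True)
    print(tokens.shape)   # torch.Size([2, 196, 768])
    print(attn.shape)     # torch.Size([2, 8, 196, 196])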