1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
class MAE(nn.Module):
    """Complete MAE (masked autoencoder) model.

    Images are patch-embedded by the encoder's ``patch_embed`` layer, a random
    subset of the patch tokens is kept, the kept tokens are run through the
    encoder, and the decoder produces a prediction from the encoded tokens
    plus the indices needed to restore the original patch order.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4.0, mask_ratio=0.75):
        """Build the encoder/decoder pair.

        Args:
            img_size: Side length of the (square) input image; forwarded to
                ``MAEEncoder`` and used to derive the patch count.
            patch_size: Side length of one square patch.
            in_chans: Number of input image channels; forwarded to
                ``MAEEncoder``.
            encoder_embed_dim: Forwarded to ``MAEEncoder`` and ``MAEDecoder``.
            encoder_depth: Forwarded to ``MAEEncoder``.
            encoder_num_heads: Forwarded to ``MAEEncoder``.
            decoder_embed_dim: Forwarded to ``MAEDecoder``.
            decoder_depth: Forwarded to ``MAEDecoder``.
            decoder_num_heads: Forwarded to ``MAEDecoder``.
            mlp_ratio: Forwarded to ``MAEEncoder``.
            mask_ratio: Fraction of patch tokens dropped by
                ``random_masking``.
        """
        super().__init__()
        self.mask_ratio = mask_ratio
        self.encoder = MAEEncoder(
            img_size, patch_size, in_chans, encoder_embed_dim,
            encoder_depth, encoder_num_heads, mlp_ratio
        )
        # Number of patches in the full (unmasked) image grid.
        num_patches = (img_size // patch_size) ** 2
        self.decoder = MAEDecoder(
            num_patches, encoder_embed_dim, decoder_embed_dim,
            decoder_depth, decoder_num_heads, patch_size
        )
        self.patch_size = patch_size

    def random_masking(self, x: torch.Tensor):
        """Randomly keep a subset of tokens, independently per sample.

        Args:
            x: Token tensor of shape ``(N, L, D)``.

        Returns:
            A pair ``(x_masked, ids_restore)``: ``x_masked`` has shape
            ``(N, len_keep, D)`` with ``len_keep = int(L * (1 - mask_ratio))``,
            and ``ids_restore`` has shape ``(N, L)`` and is the inverse of the
            random permutation used to select the kept tokens.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - self.mask_ratio))

        # Sample on x's device so the index tensors passed to torch.gather
        # live on the same device as the data (gather rejects mixed devices).
        noise = torch.rand(N, L, device=x.device)

        # Sorting the noise yields a random permutation per sample; sorting
        # that permutation yields its inverse.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # The first len_keep entries of the permutation are the kept tokens.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D)
        )
        return x_masked, ids_restore

    def forward(self, imgs):
        """Run patch embedding, random masking, encoder, and decoder.

        Args:
            imgs: Input passed to ``self.encoder.patch_embed``.

        Returns:
            The decoder output for the encoded kept tokens.
        """
        x = self.encoder.patch_embed(imgs)
        x, ids_restore = self.random_masking(x)
        # NOTE(review): presumably MAEEncoder.forward accepts already-embedded
        # tokens and does not apply patch_embed again -- TODO confirm.
        x = self.encoder(x)
        pred = self.decoder(x, ids_restore)
        return pred

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Convert images into flattened patches.

        Args:
            imgs: Tensor of shape ``(N, C, H, W)`` with ``H == W`` and ``H``
                divisible by ``self.patch_size``.

        Returns:
            Tensor of shape ``(N, (H // p) ** 2, p * p * C)`` where each patch
            is flattened in (row, column, channel) order.

        Raises:
            AssertionError: If the image is not square or its side is not a
                multiple of the patch size.
        """
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        C = imgs.shape[1]
        num_patches_per_side = imgs.shape[2] // p
        # Split each spatial axis into (patch index, offset within patch).
        x = imgs.reshape(shape=(imgs.shape[0], C,
                                num_patches_per_side, p,
                                num_patches_per_side, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0],
                             num_patches_per_side * num_patches_per_side,
                             p * p * C))
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Convert flattened patches back into images (inverse of patchify).

        Args:
            x: Tensor of shape ``(N, L, p * p * C)`` where ``L`` is a perfect
                square and ``p`` is ``self.patch_size``.

        Returns:
            Tensor of shape ``(N, C, h * p, w * p)`` with ``h = w = sqrt(L)``.

        Raises:
            AssertionError: If ``L`` is not a perfect square or the last
                dimension is not a multiple of ``p * p``.
        """
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        # Infer the channel count from the patch vector length instead of
        # assuming RGB, so this stays the inverse of patchify for any C.
        C = x.shape[2] // (p * p)
        assert C * p * p == x.shape[2]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, C))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], C, h * p, w * p))
        return imgs
|