MAE:掩码自编码器的突破

🎙️ 语音朗读 当前: 晓晓 (温柔女声)

前言

MAE(Masked AutoEncoder)是何恺明团队2021年提出的自监督学习算法,在图像表示学习领域取得了突破性成果。本文将深入解析MAE的核心原理和实现细节。

MAE的核心思想

MAE采用了类似NLP中BERT的掩码思想:

  • 掩码:随机遮挡输入图像的大部分区域(如75%)
  • 重建:让模型重建被遮挡的像素值
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them.

    A single Conv2d with kernel_size == stride == patch_size performs patch
    extraction and linear projection in one operation.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        """
        Args:
            img_size: input image height/width (square images assumed).
            patch_size: side length of each square patch.
            in_chans: number of input channels.
            embed_dim: dimension of each patch embedding.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Patchify + embed with one conv: stride == kernel_size means each
        # kernel application covers exactly one patch.
        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)       # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)       # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

class MAEEncoder(nn.Module):
    """MAE encoder: a ViT that consumes (optionally masked) patch tokens."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
        """
        Args:
            img_size: input image side length (square images assumed).
            patch_size: patch side length.
            in_chans: input channels.
            embed_dim: token width.
            depth: number of Transformer blocks.
            num_heads: attention heads per block.
            mlp_ratio: hidden/width ratio of each block's MLP.
        """
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # Learnable cls token and positional table (slot 0 is the cls slot).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)
        )

        # Stack of Transformer encoder blocks.
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: (B, C, H, W) images.
            mask: optional boolean mask over patches; True marks patches
                to drop before the Transformer runs.

        Returns:
            (B, 1 + n_kept, embed_dim) normalized token sequence.
        """
        # Patch embedding.
        x = self.patch_embed(x)

        # Positional embedding for patch tokens (index 0 is reserved for cls).
        x = x + self.pos_embed[:, 1:, :]

        if mask is not None:
            # NOTE(review): boolean indexing flattens the batch dimension, so
            # this path is only correct for batch size 1 — confirm callers.
            x = x[~mask].unsqueeze(0)

        # Prepend the cls token (with its own positional embedding).
        cls_tokens = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_tokens.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        for block in self.blocks:
            x = block(x)

        return self.norm(x)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention + MLP, each with a
    residual connection."""

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        """
        Args:
            embed_dim: token width.
            num_heads: number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of embed_dim.
            dropout: dropout rate for attention weights and the MLP output.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True so inputs are (B, seq, dim), matching how every
        # caller in this file lays out tokens.  Without it, MultiheadAttention
        # treats dim 0 as the sequence and silently attends across the batch.
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Attention residual; normalize once instead of three times.
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        # MLP residual.
        x = x + self.mlp(self.norm2(x))
        return x

MAE解码器

解码器负责重建被掩码的patch:

class MAEDecoder(nn.Module):
    """MAE decoder: restores the full token sequence (visible tokens plus a
    learnable mask token at every dropped position) and predicts raw pixel
    values for every patch."""

    def __init__(self, num_patches, embed_dim=512, decoder_embed_dim=256,
                 decoder_depth=8, decoder_num_heads=8, patch_size=16):
        """
        Args:
            num_patches: total number of patches in the full image.
            embed_dim: encoder output width.
            decoder_embed_dim: internal decoder width.
            decoder_depth: number of Transformer blocks.
            decoder_num_heads: attention heads per block.
            patch_size: patch side length; the head predicts
                patch_size**2 * 3 pixel values per patch (RGB assumed).
        """
        super().__init__()

        self.num_patches = num_patches
        self.patch_size = patch_size
        self.decoder_embed_dim = decoder_embed_dim

        # Project encoder tokens into the (narrower) decoder width.
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)

        # Shared, learnable placeholder inserted at every masked position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Positional embedding for the full sequence (+1 for the cls token).
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim)
        )

        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_embed_dim, decoder_num_heads)
            for _ in range(decoder_depth)
        ])

        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)

        # Per-patch pixel prediction head.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)

    def forward(self, x, ids_restore):
        """
        Args:
            x: (B, 1 + len_keep, embed_dim) encoder output, cls token first.
            ids_restore: (B, num_patches) permutation that undoes the shuffle
                performed by random masking.

        Returns:
            (B, num_patches, patch_size**2 * 3) predicted pixel values.
        """
        x = self.decoder_embed(x)

        # Append one mask token per dropped patch, then un-shuffle so every
        # token sits at its original spatial position.  (The original code
        # built mask tokens but never inserted them, and then added a
        # num_patches+1 positional table to a shorter sequence.)
        B, n_vis, D = x.shape
        n_masked = ids_restore.shape[1] + 1 - n_vis
        mask_tokens = self.mask_token.expand(B, n_masked, -1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # drop cls token
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, D)
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach cls token

        # Sequence length now matches the positional table exactly.
        x = x + self.decoder_pos_embed

        for block in self.decoder_blocks:
            x = block(x)

        x = self.decoder_norm(x)

        # Predict pixels for every patch; the cls token carries no pixels.
        pred = self.decoder_pred(x[:, 1:, :])

        return pred

完整MAE模型

class MAE(nn.Module):
    """Masked AutoEncoder: ViT encoder over visible patches plus a light
    decoder that reconstructs pixel values for every patch."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4.0, mask_ratio=0.75):
        super().__init__()

        # Fraction of patches to drop during pretraining (0.75 per the paper).
        self.mask_ratio = mask_ratio

        # Encoder (ViT over visible tokens only).
        self.encoder = MAEEncoder(
            img_size, patch_size, in_chans,
            encoder_embed_dim, encoder_depth, encoder_num_heads, mlp_ratio
        )

        # Decoder (lighter: narrower and shallower than the encoder).
        num_patches = (img_size // patch_size) ** 2
        self.decoder = MAEDecoder(
            num_patches, encoder_embed_dim, decoder_embed_dim,
            decoder_depth, decoder_num_heads, patch_size
        )

        self.patch_size = patch_size

    def random_masking(self, x):
        """Keep a random subset of patch tokens per sample.

        Args:
            x: (N, L, D) patch embeddings.

        Returns:
            x_masked: (N, len_keep, D) kept tokens.
            ids_restore: (N, L) indices that undo the shuffle (used by the
                decoder to put mask tokens back in place).
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - self.mask_ratio))

        # Noise must live on x's device, otherwise the gather below fails
        # for GPU inputs.
        noise = torch.rand(N, L, device=x.device)

        # Sorting the noise yields a per-sample random permutation.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep tokens of the shuffled order.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D)
        )

        return x_masked, ids_restore

    def forward(self, imgs):
        """Encode visible patches and reconstruct pixels for all patches.

        Args:
            imgs: (B, C, H, W) input images.

        Returns:
            (B, num_patches, patch_size**2 * 3) reconstructed patch pixels.
        """
        # Embed and add positional information BEFORE masking so each kept
        # token remembers its spatial location.
        x = self.encoder.patch_embed(imgs)
        x = x + self.encoder.pos_embed[:, 1:, :]

        x, ids_restore = self.random_masking(x)

        # Run the encoder stages manually: calling self.encoder(x) here would
        # apply patch_embed a second time to already-embedded tokens.
        cls_tokens = self.encoder.cls_token + self.encoder.pos_embed[:, :1, :]
        cls_tokens = cls_tokens.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        for block in self.encoder.blocks:
            x = block(x)
        x = self.encoder.norm(x)

        # Decode back to per-patch pixel predictions.
        pred = self.decoder(x, ids_restore)

        return pred

    def patchify(self, imgs):
        """(B, C, H, W) -> (B, num_patches, p*p*C); inverse of unpatchify."""
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        C = imgs.shape[1]
        num_patches_per_side = imgs.shape[2] // p

        x = imgs.reshape(imgs.shape[0], C, num_patches_per_side, p,
                         num_patches_per_side, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0],
                      num_patches_per_side * num_patches_per_side, p * p * C)

        return x

    def unpatchify(self, x):
        """(B, num_patches, p*p*3) -> (B, 3, H, W); inverse of patchify."""
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)  # square patch grid assumed
        C = 3  # the decoder head always predicts RGB

        assert h * w == x.shape[1]

        x = x.reshape(x.shape[0], h, w, p, p, C)
        x = torch.einsum('nhwpqc->nchpwq', x)

        imgs = x.reshape(x.shape[0], C, h * p, w * p)

        return imgs

MAE训练

def mae_pretrain(model, train_loader, optimizer, device, epochs=800):
    """Run MAE pixel-reconstruction pretraining.

    Args:
        model: an MAE instance (must expose forward(imgs) and patchify(imgs)).
        train_loader: iterable yielding (images, labels); labels are ignored.
        optimizer: optimizer over model parameters.
        device: target device for the model and batches.
        epochs: number of passes over train_loader.
    """
    model = model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch, _ in train_loader:
            batch = batch.to(device)

            # Reconstruct, then compare against the ground-truth patches.
            reconstruction = model(batch)
            target = model.patchify(batch)
            loss = F.mse_loss(reconstruction, target)

            # Standard optimizer step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}")

MAE的关键设计

  1. 高掩码率:75%的掩码率效果最好
  2. 非对称架构:解码器比编码器更轻量
  3. 直接预测像素:无需额外的patch projection
  4. ViT作为编码器:利用Transformer处理图像

实际应用

MAE可用于:

  • 图像特征学习:学习可迁移的视觉表示
  • 下游任务微调:分类、检测、分割
  • 视频理解:时空掩码建模
  • 医学图像:自监督预训练

总结

MAE证明了掩码自编码器在视觉领域的强大能力,其简单而有效的设计为自监督学习提供了新的思路,成为视觉表示学习的重要里程碑。

参考资源

© 2019-2026 ovo$^{mc^2}$ All Rights Reserved. | 站点总访问 28969 次 | 访客 19045
Theme by hiero