自监督学习:SimCLR与对比学习原理

🎙️ 语音朗读 当前: 晓晓 (温柔女声)

前言

自监督学习是深度学习领域的重要研究方向,其中对比学习(Contrastive Learning)通过让模型学习相似样本的相似表示、不同样本的不同表示来学习特征表示。SimCLR是Google提出的经典对比学习框架。

对比学习的核心思想

对比学习的目标是:

  • 拉近相似样本(正样本对)的表示
  • 拉远不相似样本(负样本对)的表示
$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

其中τ是温度参数,sim是余弦相似度。

SimCLR框架

SimCLR的核心流程:

  1. 对图像进行随机数据增强
  2. 使用编码器网络提取特征
  3. 使用非线性投影头映射到表示空间
  4. 最大化正样本对的相似度
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

class SimCLR(nn.Module):
    """SimCLR network: a backbone encoder followed by a non-linear projection head.

    The encoder output ``h`` is the representation used for downstream tasks;
    the projection ``z`` is only used by the contrastive loss during
    pre-training.
    """

    def __init__(self, base_encoder, projection_dim=128, feature_dim=2048):
        super().__init__()

        # Backbone f(.) — e.g. a ResNet-50 trunk producing feature_dim features.
        self.encoder = base_encoder

        # Projection head g(.): two-layer MLP mapping features to z-space.
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, projection_dim),
        )

    def forward(self, x):
        """Return ``(h, z)``: encoder features and their projection."""
        features = self.encoder(x)                  # (batch, feature_dim)
        projections = self.projection_head(features)  # (batch, projection_dim)
        return features, projections

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Args:
        z_i: (batch, dim) projections of the first augmented views.
        z_j: (batch, dim) projections of the second augmented views,
            row-aligned with ``z_i`` (row k of each forms a positive pair).
        temperature: softmax temperature τ; lower values sharpen the
            similarity distribution.

    Returns:
        Scalar tensor: mean loss over all 2*batch anchors.

    Note: callers are expected to L2-normalize the inputs first so the dot
    products below are cosine similarities (the training loop does this).
    """
    batch_size = z_i.shape[0]

    # Stack both views: rows [0, B) are z_i, rows [B, 2B) are z_j.
    z = torch.cat([z_i, z_j], dim=0)

    # Pairwise similarity matrix scaled by temperature: (2B, 2B).
    sim = torch.mm(z, z.T) / temperature

    # Positive-pair similarities sit on the ±batch_size diagonals.
    sim_i_j = torch.diag(sim, batch_size)
    sim_j_i = torch.diag(sim, -batch_size)
    positives = torch.cat([sim_i_j, sim_j_i], dim=0)

    # Mask out self-similarity. The mask must be created on the same device
    # as ``sim`` — the original built it on the CPU, which fails on GPU.
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=sim.device)
    sim = sim.masked_fill(mask, float('-inf'))

    # -log(exp(pos) / sum_k exp(sim)) == logsumexp(sim) - pos; logsumexp is
    # numerically stable where exp() of large similarities could overflow
    # at low temperature.
    loss = torch.logsumexp(sim, dim=1) - positives

    return loss.mean()

数据增强策略

SimCLR的数据增强是关键组件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class SimCLRTransform:
    """Paired stochastic augmentation used to build SimCLR positive pairs.

    Calling the transform on a single image returns two independently
    augmented versions of it — the two "views" of the positive pair.
    """

    def __init__(self, size=224):
        # ImageNet channel statistics.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        # Color jitter is applied with probability 0.8, per the SimCLR recipe.
        color_jitter = transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
        )
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                size, scale=(0.08, 1.0), ratio=(0.75, 1.33)
            ),
            transforms.RandomHorizontalFlip(),
            color_jitter,
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            normalize,
        ])

    def __call__(self, x):
        """Return two independently augmented views of ``x``."""
        return self.train_transform(x), self.train_transform(x)

完整的训练流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def train_simclr(model, train_loader, optimizer, device, epochs=100, temperature=0.5):
    """Run the SimCLR pre-training loop.

    Args:
        model: SimCLR-style module returning ``(features, projections)``.
        train_loader: yields ``(images, labels)`` batches; labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        device: device to move the model and batches onto.
        epochs: number of full passes over ``train_loader``.
        temperature: NT-Xent temperature forwarded to the loss.
    """
    model = model.to(device)

    for epoch in range(epochs):
        epoch_loss = 0.0

        for images, _ in train_loader:
            # Simplified: the same batch stands in for both augmented views.
            # A real pipeline would draw two independent augmentations here.
            view_a = images.to(device)
            view_b = images.to(device)

            _, proj_a = model(view_a)
            _, proj_b = model(view_b)

            # Project onto the unit hypersphere before the contrastive loss.
            proj_a = F.normalize(proj_a, dim=1)
            proj_b = F.normalize(proj_b, dim=1)

            loss = nt_xent_loss(proj_a, proj_b, temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        mean_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1}: Loss = {mean_loss:.4f}")

# Model and optimizer.
from torchvision.models import resnet50

# Strip the ImageNet classification head so the encoder exposes the 2048-d
# pooled features that SimCLR's projection head expects (feature_dim=2048);
# a stock resnet50() forward would return 1000-d class logits instead.
encoder = resnet50(pretrained=False)
encoder.fc = nn.Identity()

model = SimCLR(base_encoder=encoder)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

MoCo:动量对比学习

MoCo使用队列维护负样本库,通过动量编码器保持一致性:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class MoCo(nn.Module):
    """MoCo v1: Momentum Contrast for unsupervised representation learning.

    A query encoder is trained by backprop while a key encoder is an
    exponential-moving-average copy of it; a FIFO queue of past keys serves
    as the negative set for an InfoNCE loss.

    Args:
        base_encoder: backbone instance; its class must be constructible as
            ``base_encoder.__class__(pretrained=False)`` (e.g. torchvision
            ResNet) to build the key encoder.
        projection_dim: output dimension of the projection heads.
        queue_size: number of negative keys kept in the queue. Should be a
            multiple of the training batch size (original MoCo constraint).
        momentum: EMA coefficient for the key branch.
        feature_dim: backbone output dimension (generalizes the previously
            hard-coded 2048).
        temperature: InfoNCE temperature (previously hard-coded 0.07).
    """

    def __init__(self, base_encoder, projection_dim=128, queue_size=65536,
                 momentum=0.999, feature_dim=2048, temperature=0.07):
        super().__init__()

        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature

        # Query branch (updated by backprop).
        self.encoder_q = base_encoder
        self.projection_q = nn.Linear(feature_dim, projection_dim)

        # Key branch (updated only via EMA).
        self.encoder_k = base_encoder.__class__(pretrained=False)
        self.projection_k = nn.Linear(feature_dim, projection_dim)

        # Initialize the key branch as an exact copy of the query branch and
        # freeze it — gradients must never flow into the key networks. The
        # original code neither copied projection_q into projection_k nor
        # disabled grads on the key parameters.
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        for param_q, param_k in zip(self.projection_q.parameters(),
                                    self.projection_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Negative-key queue, stored row-major as (queue_size, projection_dim).
        # Entries are L2-normalized like real keys so initial logits are sane.
        self.register_buffer(
            'queue', F.normalize(torch.randn(queue_size, projection_dim), dim=1)
        )
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """EMA-update the key encoder AND key projection head.

        The original version only updated the encoder, letting the key
        projection head drift away from the query projection head.
        """
        query_params = list(self.encoder_q.parameters()) + \
            list(self.projection_q.parameters())
        key_params = list(self.encoder_k.parameters()) + \
            list(self.projection_k.parameters())
        for param_q, param_k in zip(query_params, key_params):
            param_k.data = self.momentum * param_k.data + \
                (1 - self.momentum) * param_q.data

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue entries with the new keys (ring buffer).

        Assumes queue_size is divisible by the batch size, as in the original
        MoCo implementation; otherwise the slice at the wrap point fails.
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        self.queue[ptr:ptr + batch_size] = keys
        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size

    def forward(self, im_q, im_k):
        """Return the InfoNCE loss for a batch of query/key image pairs."""
        # Query features, L2-normalized.
        q = F.normalize(self.projection_q(self.encoder_q(im_q)), dim=1)

        # Key features — no gradient; EMA-update the key branch first.
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = F.normalize(self.projection_k(self.encoder_k(im_k)), dim=1)

        # One positive logit (q·k) and queue_size negative logits per query.
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Queue is (K, C), so contract the feature axis with 'nc,kc->nk'.
        # The original 'nc,ck->nk' mismatched the queue layout and raised a
        # shape error.
        l_neg = torch.einsum('nc,kc->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature

        # The positive is always column 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long,
                             device=logits.device)
        loss = F.cross_entropy(logits, labels)

        # Enqueue the current keys as future negatives.
        self._dequeue_and_enqueue(k)

        return loss

BYOL和SimSiam:无负样本对比学习

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class BYOL(nn.Module):
    """BYOL (Bootstrap Your Own Latent): negative-free self-supervised learning.

    An online network (encoder + projection + prediction head) learns to
    predict the output of a slowly-moving target network (encoder +
    projection head) under a stop-gradient.

    Args:
        base_encoder: backbone instance; its class must be constructible as
            ``base_encoder.__class__(pretrained=False)`` to build the target
            encoder.
        projection_dim: width of the projection MLP output.
        prediction_dim: hidden width of the prediction MLP.
        feature_dim: backbone output dimension (generalizes the previously
            hard-coded 2048).
    """

    def __init__(self, base_encoder, projection_dim=4096, prediction_dim=256,
                 feature_dim=2048):
        super().__init__()

        self.encoder = base_encoder

        # Online projection head g(.)
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(inplace=True),
            nn.Linear(projection_dim, projection_dim)
        )

        # Online prediction head q(.) — exists only on the online branch.
        self.prediction_head = nn.Sequential(
            nn.Linear(projection_dim, prediction_dim),
            nn.BatchNorm1d(prediction_dim),
            nn.ReLU(inplace=True),
            nn.Linear(prediction_dim, projection_dim)
        )

        # Target network: momentum copies of BOTH the encoder and the
        # projection head. The original code had no target projection head
        # and reused the online one on target features, which breaks the
        # stop-gradient target branch that BYOL relies on.
        self.target_encoder = base_encoder.__class__(pretrained=False)
        self.target_encoder.load_state_dict(base_encoder.state_dict())
        self.target_projection_head = nn.Sequential(
            nn.Linear(feature_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(inplace=True),
            nn.Linear(projection_dim, projection_dim)
        )
        self.target_projection_head.load_state_dict(
            self.projection_head.state_dict()
        )

        # The target branch is never trained by backprop.
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projection_head.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def _update_target(self, momentum=0.996):
        """EMA-update the target encoder AND target projection head."""
        online_params = list(self.encoder.parameters()) + \
            list(self.projection_head.parameters())
        target_params = list(self.target_encoder.parameters()) + \
            list(self.target_projection_head.parameters())
        for param, target_param in zip(online_params, target_params):
            target_param.data = momentum * target_param.data + \
                (1 - momentum) * param.data

    def forward(self, x1, x2):
        """Return the symmetric BYOL loss for two views of the same batch."""
        # Online branch: encode, project, then predict the target projection.
        online_pred1 = self.prediction_head(self.projection_head(self.encoder(x1)))
        online_pred2 = self.prediction_head(self.projection_head(self.encoder(x2)))

        # Target branch under stop-gradient.
        with torch.no_grad():
            target_z1 = self.target_projection_head(self.target_encoder(x1))
            target_z2 = self.target_projection_head(self.target_encoder(x2))

        # Symmetric loss: average of (2 - 2*cos) over both view orderings,
        # which simplifies to 2 - (cos_a + cos_b).
        loss = 2 - 2 * (
            F.cosine_similarity(online_pred1, target_z2).mean() +
            F.cosine_similarity(online_pred2, target_z1).mean()
        ) / 2

        self._update_target()

        return loss

实际应用场景

对比学习在以下场景应用广泛:

  • 图像特征学习:学习可迁移的视觉表示
  • 表示学习:预训练模型的特征提取器
  • 聚类任务:无监督聚类的特征学习
  • 跨模态学习:图文匹配、声音-图像对应

总结

SimCLR开创了现代对比学习的框架,其核心思想——通过数据增强创建正样本对、使用对比损失学习表示——已被广泛应用于各类自监督学习方法中。

参考资源

© 2019-2026 ovo$^{mc^2}$ All Rights Reserved. | 站点总访问 28969 次 | 访客 19045
Theme by hiero