CLIP: Contrastive Language-Image Pre-training


Introduction

CLIP (Contrastive Language-Image Pre-training), released by OpenAI in 2021, learns a shared embedding space from large-scale image-text pairs with a contrastive objective. This enables zero-shot image classification and demonstrates the power of multimodal learning.

The Core Idea of CLIP

```mermaid
graph TB
    A[Image] --> B[Image Encoder]
    B --> C[Image Features]
    C --> F[Similarity Computation]
    D[Text] --> E[Text Encoder]
    E --> G[Text Features]
    G --> F
    F --> H[Contrastive Loss]
```
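
In a batch of $N$ image-text pairs, every image is scored against every caption (and vice versa), and the matching pairs on the diagonal are treated as positives. With normalized embeddings $I_k$, $T_k$ and temperature $\tau$, the symmetric contrastive objective implemented in the code below is

$$
\mathcal{L} = \tfrac{1}{2}\left(\mathcal{L}_{I \to T} + \mathcal{L}_{T \to I}\right), \qquad
\mathcal{L}_{I \to T} = -\frac{1}{N}\sum_{k=1}^{N}\log\frac{\exp(I_k \cdot T_k / \tau)}{\sum_{j=1}^{N}\exp(I_k \cdot T_j / \tau)}
$$

with $\mathcal{L}_{T \to I}$ defined symmetrically by swapping the roles of images and texts.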
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIP(nn.Module):
    """CLIP model"""

    def __init__(self, image_encoder, text_encoder, embed_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_proj = nn.Linear(image_encoder.embed_dim, embed_dim)
        self.text_proj = nn.Linear(text_encoder.embed_dim, embed_dim)

        # Learnable temperature parameter, initialized to 1/0.07 as in the paper
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, images):
        """Encode images into the shared embedding space"""
        features = self.image_encoder(images)
        features = self.image_proj(features)
        return F.normalize(features, dim=-1)

    def encode_text(self, texts):
        """Encode texts into the shared embedding space"""
        features = self.text_encoder(texts)
        features = self.text_proj(features)
        return F.normalize(features, dim=-1)

    def forward(self, images, texts):
        """Forward pass: return temperature-scaled image-text similarity logits"""
        image_features = self.encode_image(images)
        text_features = self.encode_text(texts)

        # Similarity matrices scaled by the temperature
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        return logits_per_image, logits_per_text
```
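
A minimal sketch of how the pieces fit together, assuming the ImageEncoder and TextEncoder classes defined in the sections below (the input shapes are illustrative):

```python
# Hypothetical wiring of the components defined later in this post (shape check only)
image_encoder = ImageEncoder(embed_dim=512)   # ResNet-50 encoder, defined below
text_encoder = TextEncoder(embed_dim=512)     # Transformer encoder, defined below
model = CLIP(image_encoder, text_encoder, embed_dim=512)

images = torch.randn(4, 3, 224, 224)          # dummy batch of 4 RGB images
texts = torch.randint(0, 49408, (4, 77))      # dummy token ids, context length 77
logits_per_image, logits_per_text = model(images, texts)
print(logits_per_image.shape)                 # torch.Size([4, 4]): one logit per image-text pair
```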

CLIP Loss Function

```python
class CLIPLoss(nn.Module):
    """CLIP contrastive loss computed from normalized features"""

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, image_features, text_features):
        """Symmetric contrastive loss over a batch of matched image-text pairs"""
        # Similarity matrix between all image-text pairs in the batch
        logits = image_features @ text_features.t()

        # Temperature scaling (fixed here; the model above uses a learnable logit_scale instead)
        logits = logits / self.temperature

        # Labels: the diagonal entries are the positive pairs
        batch_size = image_features.shape[0]
        labels = torch.arange(batch_size, device=image_features.device)

        # Image-to-text loss
        loss_i2t = F.cross_entropy(logits, labels)

        # Text-to-image loss
        loss_t2i = F.cross_entropy(logits.t(), labels)

        # Total loss
        loss = (loss_i2t + loss_t2i) / 2

        return loss


def clip_loss(logits_per_image, logits_per_text):
    """CLIP loss computed directly from the logits returned by the model"""
    batch_size = logits_per_image.shape[0]

    # Diagonal entries are the positive pairs
    labels = torch.arange(batch_size, device=logits_per_image.device)

    # Symmetric loss
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)

    return (loss_i + loss_t) / 2
```
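
A quick sketch showing the two interfaces side by side: `CLIPLoss` consumes normalized features directly, while `clip_loss` consumes the temperature-scaled logits returned by the model's forward pass (the temperatures differ here, 0.1 vs. 0.07, so the two values are not expected to match):

```python
# Random unit-norm features standing in for a batch of 8 matched image-text pairs
img_feat = F.normalize(torch.randn(8, 512), dim=-1)
txt_feat = F.normalize(torch.randn(8, 512), dim=-1)

criterion = CLIPLoss()
print(criterion(img_feat, txt_feat))                       # feature-level interface

# Logit-level interface: scale similarities the way CLIP.forward() does
logits_per_image = (1 / 0.07) * img_feat @ txt_feat.t()
print(clip_loss(logits_per_image, logits_per_image.t()))   # symmetric loss from logits
```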

Image Encoders (ResNet and ViT)

```python
class ImageEncoder(nn.Module):
    """Image encoder (ResNet-50 backbone)"""

    def __init__(self, embed_dim=512):
        super().__init__()

        # ResNet-50 backbone
        from torchvision.models import resnet50, ResNet50_Weights
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Drop the final classification layer, keep everything up to global pooling
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])

        # Project the 2048-dim ResNet features into the embedding space
        self.projection = nn.Linear(2048, embed_dim)
        self.embed_dim = embed_dim  # dimension of the features returned by forward()

    def forward(self, x):
        features = self.backbone(x)           # (B, 2048, 1, 1)
        features = features.flatten(1)        # (B, 2048)
        features = self.projection(features)  # (B, embed_dim)
        return features


class ImageEncoderViT(nn.Module):
    """ViT image encoder (ViT-B/16 backbone)"""

    def __init__(self, embed_dim=512):
        super().__init__()
        from torchvision.models import vit_b_16, ViT_B_16_Weights
        backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        # Replace the classification head so the backbone returns 768-dim features
        backbone.heads = nn.Identity()

        self.backbone = backbone
        self.projection = nn.Linear(768, embed_dim)
        self.embed_dim = embed_dim  # dimension of the features returned by forward()

    def forward(self, x):
        features = self.backbone(x)           # (B, 768)
        features = self.projection(features)  # (B, embed_dim)
        return features
```
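
A quick shape check with a random input batch (a sketch; the pretrained weights are downloaded by torchvision on first use):

```python
encoder = ImageEncoder(embed_dim=512)
vit_encoder = ImageEncoderViT(embed_dim=512)

dummy = torch.randn(2, 3, 224, 224)    # two 224x224 RGB images
print(encoder(dummy).shape)            # torch.Size([2, 512])
print(vit_encoder(dummy).shape)        # torch.Size([2, 512])
```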

Text Encoder (Transformer)

```python
class TextEncoder(nn.Module):
    """Text encoder (Transformer)"""

    def __init__(self, vocab_size=49408, embed_dim=512,
                 context_length=77, transformer_width=512):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(context_length, transformer_width)
        )
        nn.init.normal_(self.positional_embedding, std=0.01)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_width,
            nhead=8,
            dim_feedforward=2048,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=12)

        self.ln_final = nn.LayerNorm(transformer_width)
        self.projection = nn.Linear(transformer_width, embed_dim)

        self.context_length = context_length
        self.embed_dim = embed_dim  # dimension of the features returned by forward()

    def forward(self, text):
        """
        Args:
            text: tokenized text of shape (B, context_length)
        """
        x = self.token_embedding(text)                 # (B, L, D)
        x = x + self.positional_embedding[:x.size(1)]

        # Transformer encoder
        x = self.transformer(x)
        x = self.ln_final(x)

        # Use the representation at the last position
        # (a simplification of CLIP's [EOS]-token pooling)
        x = x[:, -1, :]

        # Project into the shared embedding space
        x = self.projection(x)

        return x
```
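
A quick shape check with a dummy tokenized batch (a sketch; real inputs would come from CLIP's BPE tokenizer, which is not shown in this post):

```python
text_encoder = TextEncoder()                      # defaults: vocab 49408, context length 77
dummy_tokens = torch.randint(0, 49408, (2, 77))   # two sequences of 77 token ids
print(text_encoder(dummy_tokens).shape)           # torch.Size([2, 512])
```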

Zero-Shot Classification

```python
class ZeroShotClassifier:
    """Zero-shot image classification with CLIP"""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def classify(self, image, class_names):
        """
        Zero-shot classification.

        Args:
            image: input image tensor of shape (1, 3, H, W)
            class_names: list of class names, e.g. ["cat", "dog", "bird"]
        Returns:
            predicted class index and confidence
        """
        # Encode the image
        with torch.no_grad():
            image_features = self.model.encode_image(image)

        # Build and encode the text prompts (assuming a Hugging Face-style tokenizer)
        text_descriptions = [f"a photo of a {name}" for name in class_names]
        text_tokens = self.tokenizer(text_descriptions, padding=True, return_tensors='pt')

        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens['input_ids'])

        # Normalize and compute similarities
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Pick the class with the highest probability
        values, indices = similarity[0].topk(1)

        return indices[0].item(), values[0].item()


def zero_shot_classification(model, image, class_names, tokenizer):
    """Convenience wrapper around ZeroShotClassifier"""
    classifier = ZeroShotClassifier(model, tokenizer)
    pred_idx, confidence = classifier.classify(image, class_names)
    return class_names[pred_idx], confidence


# Usage example (model, tokenizer and preprocess_image are assumed to be defined elsewhere)
class_names = ["cat", "dog", "bird", "fish", "horse"]
image = preprocess_image("test.jpg")

predicted_class, confidence = zero_shot_classification(model, image, class_names, tokenizer)
print(f"Predicted: {predicted_class}, Confidence: {confidence:.4f}")
```

CLIP Training

```python
def train_clip(model, train_loader, optimizer, epochs=32, device='cuda'):
    """Train CLIP with the symmetric contrastive loss"""
    model = model.to(device)
    criterion = CLIPLoss()

    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (images, texts) in enumerate(train_loader):
            images = images.to(device)
            texts = texts.to(device)

            # Encode both modalities into the shared embedding space (once per batch)
            image_features = model.encode_image(images)
            text_features = model.encode_text(texts)

            # Contrastive loss
            loss = criterion(image_features, text_features)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")
```

CLIP Applications: Image Search

```python
class CLIPImageSearch:
    """Text-to-image search with CLIP embeddings"""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.image_features = []
        self.image_paths = []

    def add_to_index(self, image, path):
        """Add an image to the search index"""
        with torch.no_grad():
            features = self.model.encode_image(image)
            features = F.normalize(features, dim=-1)

        self.image_features.append(features.cpu())
        self.image_paths.append(path)

    def search(self, query, top_k=5):
        """Return the top_k indexed images most similar to a text query"""
        # Encode the query text (assuming a Hugging Face-style tokenizer)
        text_tokens = self.tokenizer([query], return_tensors='pt')

        with torch.no_grad():
            query_features = self.model.encode_text(text_tokens['input_ids'])
            query_features = F.normalize(query_features, dim=-1)

        # Cosine similarity against every indexed image
        image_features = torch.cat(self.image_features)
        similarities = (query_features @ image_features.t()).squeeze(0)

        # Top-k most similar images
        values, indices = similarities.topk(min(top_k, len(self.image_paths)))

        results = [(self.image_paths[idx], values[i].item())
                   for i, idx in enumerate(indices)]

        return results
```
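
A usage sketch, reusing the same preprocess_image, model and tokenizer placeholders as in the zero-shot example above (the paths are illustrative):

```python
search_engine = CLIPImageSearch(model, tokenizer)

# Index a couple of preprocessed images
for path in ["img/cat.jpg", "img/beach.jpg"]:
    search_engine.add_to_index(preprocess_image(path), path)

# Text-to-image search
for path, score in search_engine.search("a cat sleeping on a sofa", top_k=2):
    print(f"{path}: {score:.3f}")
```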

Limitations of CLIP

  • May be inaccurate on fine-grained classification tasks
  • Requires a very large amount of image-text training data
  • Prompt engineering matters; see the prompt-ensembling sketch after this list
  • Limited generalization to out-of-distribution data
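
Prompt sensitivity can be reduced with prompt ensembling, which the CLIP paper uses for zero-shot evaluation: encode several templates per class and average the normalized text embeddings. A minimal sketch, assuming the same Hugging Face-style tokenizer as above and an illustrative template list:

```python
def build_ensembled_classifier(model, tokenizer, class_names,
                               templates=("a photo of a {}.",
                                          "a blurry photo of a {}.",
                                          "a drawing of a {}.")):
    """Average text embeddings over several prompt templates per class"""
    class_embeddings = []
    with torch.no_grad():
        for name in class_names:
            prompts = [t.format(name) for t in templates]
            tokens = tokenizer(prompts, padding=True, return_tensors='pt')
            emb = model.encode_text(tokens['input_ids'])   # (num_templates, D), unit norm
            emb = F.normalize(emb.mean(dim=0), dim=-1)     # average, then re-normalize
            class_embeddings.append(emb)
    return torch.stack(class_embeddings)                   # (num_classes, D)
```

The resulting matrix can replace the single-prompt text features used in ZeroShotClassifier.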

Summary

CLIP pioneered large-scale image-text contrastive pre-training. Its zero-shot classification ability and cross-modal understanding laid the groundwork for multimodal AI, and models such as ALIGN and BLIP built further on this foundation.

