CLIP: Model Principles and Multimodal Learning
CLIP (Contrastive Language-Image Pre-training) is a multimodal model from OpenAI. By using contrastive learning to map vision and language into the same embedding space, it established a new paradigm for vision-language pre-training.
1. The Core Idea of CLIP
Traditional computer vision models require labeled data to be collected for every task. CLIP instead runs contrastive learning on 400 million image-text pairs collected from the internet, which enables zero-shot transfer:
```
Traditional approach: labeled data → train a model → one specific task
CLIP approach:        web-scale image-text pairs → contrastive pre-training → zero-shot transfer to any task
```
2. Model Architecture
CLIP consists of two encoders: an image encoder and a text encoder.
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CLIPModel(nn.Module):
    def __init__(self, image_encoder, text_encoder, embed_dim=512, temperature=0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        # Projection heads map each encoder's output into the shared embedding space
        self.image_projection = nn.Sequential(
            nn.Linear(image_encoder.embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.text_projection = nn.Sequential(
            nn.Linear(text_encoder.embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        # Learnable temperature, stored on a log scale so the scale stays positive
        self.logit_scale = nn.Parameter(
            torch.ones([]) * np.log(1 / temperature)
        )

    def forward(self, images, texts):
        # Encode each modality
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)

        # Project into the shared space and L2-normalize
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        image_embeddings = F.normalize(image_embeddings, dim=-1)
        text_embeddings = F.normalize(text_embeddings, dim=-1)

        # Scaled cosine-similarity matrix of shape [batch_size, batch_size]
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_embeddings @ text_embeddings.T
        return logits, image_embeddings, text_embeddings
```
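A minimal smoke test of the interface above, continuing from the same imports. The `DummyEncoder` stub and the tensor shapes are illustrative assumptions, not part of the original code:

```python
class DummyEncoder(nn.Module):
    """Stand-in encoder: flattens the input and projects it to a fixed feature size."""
    def __init__(self, in_dim, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.fc = nn.Linear(in_dim, embed_dim)

    def forward(self, x):
        return self.fc(x.flatten(1))


model = CLIPModel(DummyEncoder(3 * 224 * 224, 768), DummyEncoder(77, 512))
images = torch.randn(8, 3, 224, 224)   # a batch of 8 "images"
texts = torch.randn(8, 77)             # a batch of 8 "text" vectors (stand-in for token ids)
logits, img_emb, txt_emb = model(images, texts)
print(logits.shape)                    # torch.Size([8, 8]): image-to-text similarity matrix
```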
3. The Contrastive Loss Function
CLIP uses a symmetric contrastive loss that pushes the similarity of each matched image-text pair to be the highest within its batch:
```python
def clip_loss(logits):
    """
    Symmetric contrastive loss.

    logits: [batch_size, batch_size] similarity matrix; the diagonal holds the
    matched image-text pairs, everything else is a mismatched pair.
    """
    # The target for row i (and column i) is index i: each image should match its own text
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss_i2t = F.cross_entropy(logits, labels)      # image -> text direction
    loss_t2i = F.cross_entropy(logits.T, labels)    # text -> image direction
    return (loss_i2t + loss_t2i) / 2
```
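Written out, the code above is the symmetric InfoNCE objective. For a batch of $N$ pairs with normalized embeddings $I_i$, $T_j$ and temperature $\tau$ (so the logits are $I_i \cdot T_j / \tau$):

$$
\mathcal{L} = \tfrac{1}{2}\left(\mathcal{L}_{i \to t} + \mathcal{L}_{t \to i}\right),
\qquad
\mathcal{L}_{i \to t} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(I_i \cdot T_i / \tau)}{\sum_{j=1}^{N} \exp(I_i \cdot T_j / \tau)}
$$

where $\mathcal{L}_{t \to i}$ is defined symmetrically by swapping the roles of images and texts.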
4. Zero-Shot Classification
CLIP's most powerful capability is zero-shot classification: it can classify images without any task-specific training data:
```python
def zero_shot_classify(model, image, class_names, template):
    """
    Zero-shot image classification.

    Args:
        model: CLIP model
        image: input image
        class_names: list of class names, e.g. ["cat", "dog", "bird"]
        template: prompt template, e.g. "A photo of a {}"
    """
    # Turn each class name into a natural-language prompt
    text_inputs = [template.format(name) for name in class_names]
    text_tokens = tokenizer(text_inputs)   # tokenizer matching the text encoder, assumed defined elsewhere

    with torch.no_grad():
        # Embed the image
        image_features = model.image_encoder(image)
        image_embeddings = model.image_projection(image_features)
        image_embeddings = F.normalize(image_embeddings, dim=-1)

        # Embed one prompt per class
        text_features = model.text_encoder(text_tokens)
        text_embeddings = model.text_projection(text_features)
        text_embeddings = F.normalize(text_embeddings, dim=-1)

        # Cosine similarity between the image and every class prompt
        similarity = (image_embeddings @ text_embeddings.T).squeeze(0)
        probs = F.softmax(similarity * model.logit_scale.exp(), dim=-1)

    for name, prob in zip(class_names, probs):
        print(f"{name}: {prob:.2%}")
    return class_names[probs.argmax()]


result = zero_shot_classify(
    model, image,
    ["cat", "dog", "bird", "fish"],
    "A photo of a {}"
)
```
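In practice, zero-shot accuracy is usually improved by averaging the text embeddings over several prompt templates per class (prompt ensembling). A minimal sketch under the same assumptions as above; `build_classifier_weights` and the template list are illustrative names, not part of the original code:

```python
def build_classifier_weights(model, class_names, templates):
    """Average normalized text embeddings over several templates for each class."""
    weights = []
    with torch.no_grad():
        for name in class_names:
            prompts = [t.format(name) for t in templates]
            tokens = tokenizer(prompts)                       # same assumed tokenizer as above
            emb = model.text_projection(model.text_encoder(tokens))
            emb = F.normalize(emb, dim=-1).mean(dim=0)        # average over templates
            weights.append(F.normalize(emb, dim=-1))          # re-normalize the mean
    return torch.stack(weights)                               # [num_classes, embed_dim]


templates = ["A photo of a {}.", "A close-up photo of a {}.", "A drawing of a {}."]
class_weights = build_classifier_weights(model, ["cat", "dog", "bird", "fish"], templates)
# probs = F.softmax(image_embeddings @ class_weights.T * model.logit_scale.exp(), dim=-1)
```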
5. Practical Applications
5.1 Image Retrieval
```python
def image_text_retrieval(model, query_text, image_database, top_k=5):
    """Text-based image retrieval."""
    image_features = []
    with torch.no_grad():
        # Embed every image in the database
        for img in image_database:
            feat = model.image_encoder(img)
            emb = model.image_projection(feat)
            image_features.append(F.normalize(emb, dim=-1))
        image_features = torch.cat(image_features, dim=0)          # [num_images, embed_dim]

        # Embed the text query
        text_tokens = tokenizer([query_text])
        text_feat = model.text_encoder(text_tokens)
        text_emb = F.normalize(model.text_projection(text_feat), dim=-1)

        # Rank images by cosine similarity to the query
        similarities = (text_emb @ image_features.T).squeeze(0)    # [num_images]
        top_indices = similarities.argsort(descending=True)[:top_k]

    return [image_database[i] for i in top_indices]
```
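Since the image embeddings do not depend on the query, in a real system they are typically computed once offline and cached or stored in a vector index; only the text embedding needs to be computed per query.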
5.2 Image Caption Generation
```python
def image_captioning(model, image, candidate_captions):
    """Pick the best-matching caption from a candidate pool."""
    with torch.no_grad():
        # Embed the image once
        image_emb = F.normalize(model.image_projection(model.image_encoder(image)), dim=-1)

        best_score = -float('inf')
        best_caption = None
        for caption in candidate_captions:
            # Embed each candidate caption and score it against the image
            text_feat = model.text_encoder(tokenizer([caption]))
            text_emb = F.normalize(model.text_projection(text_feat), dim=-1)
            score = (image_emb @ text_emb.T).item()
            if score > best_score:
                best_score = score
                best_caption = caption
    return best_caption
```
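Note that this is retrieval-style captioning: it selects from a fixed candidate pool rather than generating new text, since CLIP itself has no text decoder.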
6. Limitations of CLIP and Directions for Improvement

| Limitation | Description | Improvement direction |
| --- | --- | --- |
| Weak fine-grained recognition | Limited ability to distinguish subcategories | Fine-grained contrastive learning |
| Poor OCR | Struggles to read text inside images | Incorporate OCR data |
| Weak counting | Struggles to count objects accurately | Add counting-oriented tasks |
| Resolution limits | Input image resolution is relatively low | High-resolution adapters |
7. Follow-up Developments of CLIP
- OpenCLIP: open-source reproduction
- EVA-CLIP: stronger training strategies
- SigLIP: replaces the softmax-based contrastive loss with a sigmoid loss (a minimal sketch follows this list)
- Chinese-CLIP: Chinese multimodal model
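To make the SigLIP item concrete, here is a minimal sketch of a pairwise sigmoid loss operating on the same [batch_size, batch_size] logits as above. This is a simplification: the actual SigLIP loss also adds a learnable bias to the logits.

```python
def sigmoid_pairwise_loss(logits):
    """Every image-text pair becomes an independent binary classification:
    diagonal pairs (matched) are positives, all other pairs are negatives."""
    n = logits.shape[0]
    labels = 2 * torch.eye(n, device=logits.device) - 1   # +1 on the diagonal, -1 elsewhere
    return -F.logsigmoid(labels * logits).mean()
```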
Summary
CLIP unifies vision and language in a single embedding space through contrastive learning, which gives it strong zero-shot transfer ability. It is not just a model but a new paradigm: using natural language as the supervision signal for training visual models, laying the groundwork for the development of multimodal AI.