Model Compression and Knowledge Distillation in Practice

Introduction

Deep learning models keep growing in size, which makes model compression a key technology for deploying them on edge devices. Knowledge distillation is one of the most effective compression methods.

Knowledge Distillation Principles

Core Idea
Teacher model → knowledge → Student model
        ↓ soft labels
        ↓ the temperature parameter T controls the degree of softening
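To make the role of the temperature T concrete, here is a minimal sketch; the logit values are made up purely for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for a single 3-class example
logits = torch.tensor([[8.0, 2.0, 1.0]])

# T = 1 is the ordinary softmax: nearly all probability mass on the top class
print(F.softmax(logits / 1.0, dim=1))   # ≈ [0.997, 0.002, 0.001]

# T = 4 softens the distribution, exposing how the teacher ranks the other classes
print(F.softmax(logits / 4.0, dim=1))   # ≈ [0.72, 0.16, 0.12]
```

This softened distribution is the "knowledge" the student mimics in the KL term of the loss below.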
Distillation Loss Function
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.T = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        # Hard loss: ordinary cross-entropy against the ground-truth labels
        hard_loss = F.cross_entropy(student_logits, labels)

        # Soft loss: KL divergence between temperature-softened distributions
        soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft_student = F.log_softmax(student_logits / self.T, dim=1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

        # Multiply by T^2 so the soft-loss gradients keep a comparable scale
        soft_loss = soft_loss * (self.T ** 2)

        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
```
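A minimal training-step sketch using this loss; `teacher`, `student`, `dataloader`, and `optimizer` are placeholders for your own models, data, and optimizer:

```python
criterion = KnowledgeDistillationLoss(temperature=4.0, alpha=0.7)
teacher.eval()  # the teacher is frozen during distillation

for images, labels in dataloader:
    with torch.no_grad():
        teacher_logits = teacher(images)
    student_logits = student(images)

    loss = criterion(student_logits, teacher_logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```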
Advanced Distillation Strategies

Multi-Teacher Distillation
```python
class MultiTeacherDistillation:
    """Distill the averaged knowledge of several teachers into one student."""
    def __init__(self, teachers, student, temperature=4.0):
        self.teachers = teachers
        self.student = student
        self.T = temperature

    def distill(self, data):
        # Collect every teacher's logits without tracking gradients
        teacher_outputs = []
        for teacher in self.teachers:
            with torch.no_grad():
                out = teacher(data)
            teacher_outputs.append(out)

        # Combine the teachers' knowledge by simple averaging
        combined_knowledge = torch.stack(teacher_outputs).mean(dim=0)

        student_out = self.student(data)
        loss = self.compute_distillation_loss(student_out, combined_knowledge)
        return loss

    def compute_distillation_loss(self, student_logits, teacher_logits):
        # Temperature-softened KL divergence, as in KnowledgeDistillationLoss above
        soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft_student = F.log_softmax(student_logits / self.T, dim=1)
        return F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.T ** 2)
```
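A usage sketch, assuming `teacher_a`, `teacher_b`, `student`, `dataloader`, and `optimizer` already exist and using the constructor shown above. Simple averaging treats all teachers equally; a common variant instead weights each teacher by its validation accuracy:

```python
distiller = MultiTeacherDistillation(
    teachers=[teacher_a, teacher_b],
    student=student,
    temperature=4.0,
)

for images, _ in dataloader:
    loss = distiller.distill(images)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```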
Intermediate-Layer Distillation
```python
class IntermediateDistillation:
    """Intermediate-layer feature distillation."""
    def __init__(self, teacher, student):
        self.teacher = teacher
        self.student = student

    def forward(self, x):
        # Both models are assumed to expose extract_features(), returning
        # a list of intermediate feature maps at matching depths
        with torch.no_grad():
            teacher_features = self.teacher.extract_features(x)
        student_features = self.student.extract_features(x)

        # Match each student feature map against the corresponding teacher map
        feat_loss = 0.0
        for t_feat, s_feat in zip(teacher_features, student_features):
            feat_loss = feat_loss + F.mse_loss(s_feat, t_feat)
        return feat_loss
```
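In practice the student's feature maps rarely have the same channel dimensions as the teacher's, so a learnable adapter is usually inserted before the MSE. A minimal sketch; the channel sizes below are made-up examples, not values from any particular model:

```python
import torch.nn as nn

# 1x1 convolutions projecting student features into the teacher's channel space
# (channel sizes are hypothetical; use the shapes of your own models)
adapters = nn.ModuleList([
    nn.Conv2d(in_channels=64,  out_channels=256, kernel_size=1),
    nn.Conv2d(in_channels=128, out_channels=512, kernel_size=1),
])

def adapted_feature_loss(student_features, teacher_features):
    loss = 0.0
    for adapter, s_feat, t_feat in zip(adapters, student_features, teacher_features):
        loss = loss + F.mse_loss(adapter(s_feat), t_feat)
    return loss
```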
Model Pruning

Structured Pruning
```python
class StructuredPruner:
    """Structured (filter-level) pruner."""
    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity

    def prune(self):
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                # Importance of each output filter: mean absolute weight
                # (weight shape is [out_channels, in_channels, kH, kW])
                importance = module.weight.abs().mean(dim=(1, 2, 3))

                # Zero out the filters whose importance falls below the sparsity quantile
                threshold = torch.quantile(importance, self.sparsity)
                mask = (importance > threshold).to(module.weight.dtype)
                module.weight.data = module.weight.data * mask.view(-1, 1, 1, 1)
```
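A usage sketch, where `model` stands in for your own network. Note that masking weights to zero does not by itself shrink the tensors; PyTorch's built-in torch.nn.utils.prune utilities perform the same kind of masking and are a common alternative:

```python
pruner = StructuredPruner(model, sparsity=0.5)
pruner.prune()

# Alternative: built-in structured pruning of whole filters (dim=0) by L1 norm
import torch.nn.utils.prune as prune
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.ln_structured(module, name='weight', amount=0.5, n=1, dim=0)
```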
Quantization

Dynamic Quantization
```python
class DynamicQuantizer:
    """Post-training dynamic quantization to INT8."""
    def quantize(self, model):
        # Weights of Linear and LSTM layers are stored as INT8;
        # activations are quantized on the fly at inference time
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear, nn.LSTM},
            dtype=torch.qint8
        )
        return quantized_model
```
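A usage sketch; `MyModel` is a placeholder for your own FP32 network. Dynamic quantization only converts the listed layer types (here Linear and LSTM), so convolution-heavy models benefit less from it:

```python
model = MyModel()   # hypothetical FP32 model
model.eval()

quantizer = DynamicQuantizer()
quantized = quantizer.quantize(model)

# Rough size comparison on disk
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized.state_dict(), "model_int8.pt")
```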
Performance Comparison
| Method | Compression ratio | Accuracy retained | Speedup |
| --- | --- | --- | --- |
| Knowledge distillation | 4-10x | 95%+ | 2-5x |
| Pruning | 2-10x | 90%+ | 1.5-3x |
| Quantization (INT8) | 4x | 98%+ | 2-4x |
| Quantization (INT4) | 8x | 90%+ | 4-8x |
Summary

Model compression techniques make it possible to run deep learning models efficiently on resource-constrained devices.

Recommended reading: "Knowledge Distillation: A Survey"