🎙️ 语音朗读
当前: 晓晓 (温柔女声)
前言
知识蒸馏(Knowledge Distillation)是一种模型压缩技术,通过让小模型(学生网络)学习大模型(教师网络)的”知识”来提升性能。本文将深入解析知识蒸馏的原理、实现和多种变体。
知识蒸馏的核心思想
知识蒸馏的核心是用教师模型的软输出(soft predictions)来指导学生模型的学习:
1 | Loss = α × Soft_Loss + (1-α) × Hard_Loss |
其中:
- Soft_Loss:学生与教师软标签的KL散度
- Hard_Loss:学生与真实标签的交叉熵
- α:平衡参数(通常设为0.7-0.9)
1 | import torch |
温度参数的作用
温度T控制softmax的平滑程度:
- T=1:标准softmax
- T>1:更平滑的概率分布
- T→∞:接近均匀分布
- T→0:接近one-hot分布
1 | def visualize_temperature_effect(): |
特征蒸馏
除了logits蒸馏,还可以蒸馏中间层特征:
1 | class FeatureDistillationLoss(nn.Module): |
FitNet:深度特征蒸馏
FitNet让学生学习教师的中间表示:
1 | class FitNetLoss(nn.Module): |
标签平滑蒸馏
结合标签平滑和知识蒸馏:
1 | class LabelSmoothingDistillation(nn.Module): |
自蒸馏
自蒸馏使用同一模型的不同版本作为教师:
1 | class SelfDistillation(nn.Module): |
蒸馏策略对比
| 方法 | 教师输出 | 学生结构 | 适用场景 |
|---|---|---|---|
| Logit Distillation | Soft labels | 任意 | 通用 |
| Feature Distillation | 中间特征 | 相似 | 深度不同 |
| Self-Distillation | 深层特征 | 同一模型 | 无教师 |
| Multi-Teacher | 多个教师 | 任意 | 知识融合 |
实际应用
知识蒸馏在以下场景广泛应用:
- 移动端部署:压缩大模型用于手机
- 边缘计算:减少推理延迟和能耗
- 模型加速:加速推理过程
- 集成学习:融合多个模型知识
总结
知识蒸馏是模型压缩的重要技术,通过迁移教师模型的”暗知识”到学生模型,实现了在保持性能的同时大幅降低模型复杂度。