Knowledge Distillation: Model Compression and Knowledge Transfer


Introduction

Knowledge distillation is a model compression technique that improves a small model (the student network) by having it learn the "knowledge" of a large model (the teacher network). This post walks through the principle behind knowledge distillation, its implementation, and several of its variants.

The Core Idea of Knowledge Distillation

At its core, knowledge distillation uses the teacher model's soft predictions to guide the student model's training:

Loss = α × Soft_Loss + (1-α) × Hard_Loss

where:

  • Soft_Loss: the KL divergence between the student's and the teacher's softened outputs
  • Hard_Loss: the cross-entropy between the student's outputs and the ground-truth labels
  • α: a balancing weight (typically set to 0.7-0.9)
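
For example, with α = 0.8, a soft loss of 2.0 and a hard loss of 0.5 combine to 0.8 × 2.0 + 0.2 × 0.5 = 1.7, so the teacher's soft targets dominate the training signal.
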
import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillationLoss(nn.Module):
    """Knowledge distillation loss."""

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        """
        Args:
            student_logits: raw outputs of the student model
            teacher_logits: raw outputs of the teacher model
            labels: ground-truth labels
        """
        # Soft loss: KL divergence between temperature-softened distributions
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_loss = F.kl_div(
            soft_student, soft_teacher, reduction='batchmean'
        ) * (self.temperature ** 2)

        # Hard loss: cross-entropy against the true labels
        hard_loss = F.cross_entropy(student_logits, labels)

        # Weighted combination
        total_loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss

        return total_loss

def distill_knowledge(teacher, student, train_loader, optimizer, epochs=100):
    """Knowledge distillation training loop."""
    distillation_loss = KnowledgeDistillationLoss(temperature=4.0, alpha=0.7)

    teacher.eval()   # the teacher is frozen and never updated
    student.train()

    for epoch in range(epochs):
        total_loss = 0

        for inputs, labels in train_loader:
            with torch.no_grad():
                teacher_outputs = teacher(inputs)

            student_outputs = student(inputs)

            loss = distillation_loss(student_outputs, teacher_outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}")

The Role of the Temperature Parameter

The temperature T controls how smooth the softmax distribution is (see the formula after this list):

  • T = 1: the standard softmax
  • T > 1: a smoother probability distribution
  • T → ∞: approaches a uniform distribution
  • T → 0: approaches a one-hot distribution
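
Concretely, the temperature-scaled softmax is

p_i = exp(z_i / T) / Σ_j exp(z_j / T)

which is also why the soft loss above is multiplied by T²: the gradients coming from the softened KL term shrink roughly as 1/T², and the factor keeps the soft and hard losses on a comparable scale.
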
def visualize_temperature_effect():
    """Show the effect of different temperatures on the softmax output."""
    logits = torch.tensor([[2.0, 1.0, 0.5], [4.0, 3.0, 2.0]])

    temperatures = [0.5, 1.0, 2.0, 4.0, 10.0]

    for T in temperatures:
        probs = F.softmax(logits / T, dim=1)
        print(f"Temperature={T}:")
        print(f"  Sample 1: {probs[0].tolist()}")
        print(f"  Sample 2: {probs[1].tolist()}")

Feature Distillation

Besides distilling logits, intermediate-layer features can be distilled as well:

class FeatureDistillationLoss(nn.Module):
    """Feature distillation loss."""

    def __init__(self, loss_type='mse'):
        super().__init__()
        self.loss_type = loss_type

        if loss_type == 'mse':
            self.criterion = nn.MSELoss()
        elif loss_type == 'l1':
            self.criterion = nn.L1Loss()
        elif loss_type == 'cosine':
            self.criterion = nn.CosineEmbeddingLoss()

    def forward(self, student_features, teacher_features):
        if self.loss_type == 'mse':
            # L2 loss
            return self.criterion(student_features, teacher_features)

        elif self.loss_type == 'l1':
            # L1 loss
            return self.criterion(student_features, teacher_features)

        elif self.loss_type == 'cosine':
            # Cosine embedding loss (target 1 = "make the features similar")
            target = torch.ones(student_features.size(0)).to(student_features.device)
            return self.criterion(student_features, teacher_features, target)

class AttentionTransfer(nn.Module):
    """Attention transfer."""

    def forward(self, student_maps, teacher_maps):
        """Activation-based attention transfer."""
        def attention_map(x):
            # Collapse the channel dimension into a spatial attention map
            return torch.sum(x ** 2, dim=1)

        student_attn = attention_map(student_maps)
        teacher_attn = attention_map(teacher_maps)

        return F.mse_loss(student_attn, teacher_attn)
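
The class above compares raw attention maps directly; the commonly used attention-transfer formulation additionally flattens and L2-normalizes each map before taking the distance, which makes the loss insensitive to activation scale. A sketch of that variant (the normalization step is the only addition to the post's code):

def attention_transfer_loss(student_maps, teacher_maps):
    """Attention transfer with per-sample L2 normalization (common variant)."""
    def normalized_attention(x):
        attn = torch.sum(x ** 2, dim=1)        # (N, H, W) spatial attention
        attn = attn.view(attn.size(0), -1)     # flatten to (N, H*W)
        return F.normalize(attn, p=2, dim=1)   # unit L2 norm per sample

    return F.mse_loss(normalized_attention(student_maps),
                      normalized_attention(teacher_maps))
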

FitNet: Deep Feature Distillation

FitNet has the student mimic the teacher's intermediate representations:

class FitNetLoss(nn.Module):
    """FitNet: hint-based (deeply supervised) distillation."""

    def __init__(self, hint_layer=None):
        super().__init__()
        # Optional adapter that maps student features to the teacher's shape
        self.hint_layer = hint_layer

    def forward(self, student_features, teacher_features,
                student_logits, teacher_logits, labels):
        # Adapt student features if their shape differs from the teacher's
        if self.hint_layer is not None and student_features.shape != teacher_features.shape:
            student_features = self.hint_layer(student_features)

        # Feature (hint) distillation loss
        feat_loss = F.mse_loss(student_features, teacher_features)

        # Hard-label classification loss on the student's logits
        cls_loss = F.cross_entropy(student_logits, labels)

        return feat_loss + 0.01 * cls_loss

class HintLayer(nn.Module):
    """Feature adapter: a 1x1 convolution that matches channel dimensions."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)
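
A minimal wiring sketch, assuming the student's hint layer emits 256-channel feature maps while the teacher's guided layer emits 512 channels (the channel counts and the way features are collected are illustrative assumptions):

# Hypothetical shapes: student hint features (N, 256, H, W), teacher guided features (N, 512, H, W)
adapter = HintLayer(in_channels=256, out_channels=512)
fitnet_loss = FitNetLoss(hint_layer=adapter)

# Inside the training loop, with intermediate features collected via forward hooks
# or models that return them:
# loss = fitnet_loss(student_feat, teacher_feat, student_logits, teacher_logits, labels)

Note that the adapter has trainable parameters of its own, so remember to pass them to the optimizer along with the student's parameters.
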

Label Smoothing Distillation

Combining label smoothing with knowledge distillation:

class LabelSmoothingDistillation(nn.Module):
    """Knowledge distillation combined with label smoothing."""

    def __init__(self, temperature=4.0, alpha=0.7, label_smoothing=0.1):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.label_smoothing = label_smoothing

    def forward(self, student_logits, teacher_logits, labels, num_classes):
        # Soft loss: KL divergence between temperature-softened distributions
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_loss = F.kl_div(
            soft_student, soft_teacher, reduction='batchmean'
        ) * (self.temperature ** 2)

        # Smoothed target distribution
        with torch.no_grad():
            smooth_labels = torch.zeros_like(soft_teacher)
            smooth_labels.fill_(self.label_smoothing / (num_classes - 1))
            smooth_labels.scatter_(1, labels.unsqueeze(1), 1 - self.label_smoothing)

        # Hard loss: cross-entropy against the smoothed labels (no temperature scaling)
        log_probs = F.log_softmax(student_logits, dim=1)
        hard_loss = (-smooth_labels * log_probs).sum(dim=1).mean()

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
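
As a quick sanity check of the scatter logic: with num_classes = 5 and label_smoothing = 0.1, each incorrect class receives 0.1 / 4 = 0.025 and the true class receives 0.9, so each target row still sums to 1.
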

Self-Distillation

Self-distillation uses a different part or version of the same model as the teacher:

class SelfDistillation(nn.Module):
    """Self-distillation: shallow layers learn from the model's own deep classifier."""

    def __init__(self, model, num_classes=1000):
        super().__init__()
        self.model = model
        self.classifier = nn.Linear(model.feature_dim, num_classes)

        # Auxiliary classifiers attached to intermediate layers
        self.aux_classifier1 = nn.Linear(512, num_classes)
        self.aux_classifier2 = nn.Linear(1024, num_classes)

    def forward(self, x, layer1_features, layer2_features):
        # Main classifier on top of the backbone features
        main_output = self.classifier(self.model(x))

        # Auxiliary heads are supervised by the main head's predictions
        if self.training:
            pseudo_labels = main_output.argmax(dim=1)
            aux_loss1 = F.cross_entropy(self.aux_classifier1(layer1_features), pseudo_labels)
            aux_loss2 = F.cross_entropy(self.aux_classifier2(layer2_features), pseudo_labels)
            return main_output, aux_loss1 + aux_loss2

        return main_output
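
The auxiliary heads above are supervised with the main head's hard argmax predictions. A common variant, closer in spirit to the logit distillation at the start of this post, instead distills the main head's softened distribution into each auxiliary head; a sketch of that alternative loss (not the post's original code):

def aux_soft_distillation_loss(aux_logits, main_logits, temperature=4.0):
    """KL distillation from the main head (acting as teacher) to an auxiliary head."""
    soft_main = F.softmax(main_logits.detach() / temperature, dim=1)   # stop gradient on the teacher side
    log_soft_aux = F.log_softmax(aux_logits / temperature, dim=1)
    return F.kl_div(log_soft_aux, soft_main, reduction='batchmean') * temperature ** 2
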

Comparison of Distillation Strategies

Method                 Teacher output           Student architecture   Typical scenario
Logit Distillation     Soft labels              Any                    General purpose
Feature Distillation   Intermediate features    Similar architecture   Teacher/student depths differ
Self-Distillation      Deep-layer features      Same model             No separate teacher
Multi-Teacher          Several teachers         Any                    Knowledge fusion
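
The Multi-Teacher row is not covered by the code above; one simple approach (an assumed sketch, not the only option) is to average the teachers' temperature-softened probabilities and distill against that ensemble distribution:

def multi_teacher_soft_targets(teacher_logits_list, temperature=4.0):
    """Average the softened distributions of several teachers into one target."""
    probs = [F.softmax(logits / temperature, dim=1) for logits in teacher_logits_list]
    return torch.stack(probs, dim=0).mean(dim=0)

def multi_teacher_kd_loss(student_logits, teacher_logits_list, temperature=4.0):
    soft_target = multi_teacher_soft_targets(teacher_logits_list, temperature)
    log_soft_student = F.log_softmax(student_logits / temperature, dim=1)
    return F.kl_div(log_soft_student, soft_target, reduction='batchmean') * temperature ** 2
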

Practical Applications

Knowledge distillation is widely used in the following scenarios:

  • Mobile deployment: compressing large models to run on phones
  • Edge computing: reducing inference latency and energy consumption
  • Model acceleration: speeding up the inference process
  • Ensemble learning: fusing the knowledge of multiple models

Summary

Knowledge distillation is a key model compression technique: by transferring the teacher model's "dark knowledge" to the student, it greatly reduces model complexity while largely preserving performance.
