Deep Learning Optimizers: From SGD to AdamW


Introduction

Optimizers are a core component of deep learning training. From classic SGD to the modern AdamW, optimizer techniques have evolved rapidly. This article walks through the principles, implementations, and typical use cases of the major optimizers.

Classic SGD
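Vanilla SGD moves each parameter against the gradient of the loss, scaled by the learning rate $\eta$; with weight decay $\lambda$ the decay term is simply folded into the gradient, which is exactly what the implementation below does:

$$
\theta_{t+1} = \theta_t - \eta\,\bigl(\nabla_\theta \mathcal{L}(\theta_t) + \lambda\,\theta_t\bigr)
$$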

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SGD:
    """Stochastic gradient descent with optional momentum and weight decay."""

    def __init__(self, params, lr=0.01, momentum=0.0, dampening=0, weight_decay=0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.dampening = dampening
        self.weight_decay = weight_decay
        self.state = {}  # per-parameter momentum buffers

    def step(self):
        for p in self.params:
            if p.grad is None:
                continue

            grad = p.grad.data

            # Weight decay (L2 penalty folded into the gradient)
            if self.weight_decay != 0:
                grad = grad + self.weight_decay * p.data

            # Momentum
            if self.momentum != 0:
                if id(p) not in self.state:
                    buf = grad.clone()
                else:
                    buf = self.state[id(p)]
                    buf = self.momentum * buf + (1 - self.dampening) * grad
                self.state[id(p)] = buf
                grad = buf

            p.data = p.data - self.lr * grad

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

# Usage example (assumes a DataLoader named `dataloader` yielding (inputs, targets))
model = nn.Linear(10, 2)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = F.cross_entropy(outputs, targets)
    loss.backward()
    optimizer.step()
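As a quick sanity check, the hand-rolled class can be compared against torch.optim.SGD. The following is a minimal sketch (momentum is left at zero so both implementations share identical first-step semantics):

# Sketch: compare one update of the custom SGD against torch.optim.SGD.
torch.manual_seed(0)
m1 = nn.Linear(4, 1)
m2 = nn.Linear(4, 1)
m2.load_state_dict(m1.state_dict())  # identical initial weights

opt1 = SGD(m1.parameters(), lr=0.1, weight_decay=1e-2)
opt2 = torch.optim.SGD(m2.parameters(), lr=0.1, weight_decay=1e-2)

x = torch.randn(8, 4)
y = torch.randn(8, 1)

for net, opt in ((m1, opt1), (m2, opt2)):
    opt.zero_grad()
    F.mse_loss(net(x), y).backward()
    opt.step()

# The two parameter sets should agree up to floating-point error.
for a, b in zip(m1.parameters(), m2.parameters()):
    assert torch.allclose(a, b, atol=1e-6)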

SGD with Momentum

class SGDMomentum:
    """SGD with (optionally Nesterov) momentum."""

    def __init__(self, params, lr=0.01, momentum=0.9,
                 weight_decay=1e-4, nesterov=True):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        self.state = {}

    def step(self):
        for p in self.params:
            if p.grad is None:
                continue

            grad = p.grad.data

            # Weight decay (added to the gradient, independent of momentum)
            if self.weight_decay != 0:
                grad = grad.add(p.data, alpha=self.weight_decay)

            # Initialize the momentum buffer
            if id(p) not in self.state:
                self.state[id(p)] = torch.zeros_like(p.data)

            buf = self.state[id(p)]

            # Nesterov momentum: look ahead along the velocity direction
            if self.nesterov:
                buf.mul_(self.momentum).add_(grad)
                grad = grad.add(buf, alpha=self.momentum)
            else:
                buf.mul_(self.momentum).add_(grad)
                grad = buf

            p.data.add_(grad, alpha=-self.lr)
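The update implemented above follows the usual momentum recursion (velocity $v$, momentum coefficient $\mu$, learning rate $\eta$, weight-decayed gradient $g_t$):

$$
v_{t+1} = \mu\, v_t + g_t, \qquad
\theta_{t+1} =
\begin{cases}
\theta_t - \eta\,(g_t + \mu\, v_{t+1}) & \text{(Nesterov)} \\[4pt]
\theta_t - \eta\, v_{t+1} & \text{(standard momentum)}
\end{cases}
$$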

Adam
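Adam maintains exponential moving averages of the gradient and of its elementwise square, corrects both for initialization bias, and scales the step elementwise (with $g_t$ the gradient and $\eta$ the learning rate):

$$
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2
$$

$$
\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^t}, \qquad
\theta_{t+1} = \theta_t - \frac{\eta\,\hat m_t}{\sqrt{\hat v_t} + \epsilon}
$$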

class Adam:
    """Adam optimizer (optionally AMSGrad)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0, amsgrad=False):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.amsgrad = amsgrad

        self.state = {}
        self.step_count = 0

    def step(self):
        self.step_count += 1

        for p in self.params:
            if p.grad is None:
                continue

            grad = p.grad.data

            # Lazily initialize per-parameter state
            if id(p) not in self.state:
                self.state[id(p)] = {
                    'exp_avg': torch.zeros_like(p.data),
                    'exp_avg_sq': torch.zeros_like(p.data),
                    'max_exp_avg_sq': torch.zeros_like(p.data) if self.amsgrad else None
                }

            state = self.state[id(p)]
            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            # Weight decay as an L2 penalty folded into the gradient (classic Adam)
            if self.weight_decay != 0:
                grad = grad.add(p.data, alpha=self.weight_decay)

            # Update biased first-moment estimate
            exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)

            # Update biased second-moment estimate
            exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)

            # Bias correction
            bias_correction1 = 1 - self.beta1 ** self.step_count
            bias_correction2 = 1 - self.beta2 ** self.step_count

            # AMSGrad: use the running maximum of the second moment
            if self.amsgrad:
                max_exp_avg_sq = state['max_exp_avg_sq']
                torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

            step_size = self.lr / bias_correction1
            p.data.addcdiv_(exp_avg, denom, value=-step_size)

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

AdamW

AdamW is a refinement of Adam that decouples weight decay from the adaptive gradient update: instead of folding an L2 penalty into the gradient, the parameters are shrunk directly. The difference between the two update rules is sketched below, followed by an implementation:
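With Adam-style L2 regularization the decay term passes through the adaptive denominator, whereas AdamW applies it as a separate multiplicative shrinkage:

$$
\text{Adam + L2:}\quad g_t \leftarrow g_t + \lambda\,\theta_t, \qquad
\theta_{t+1} = \theta_t - \frac{\eta\,\hat m_t}{\sqrt{\hat v_t} + \epsilon}
$$

$$
\text{AdamW:}\qquad
\theta_{t+1} = \theta_t - \eta\,\lambda\,\theta_t - \frac{\eta\,\hat m_t}{\sqrt{\hat v_t} + \epsilon}
$$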

class AdamW:
    """AdamW optimizer (decoupled weight decay)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay

        self.state = {}
        self.step_count = 0

    def step(self):
        self.step_count += 1

        for p in self.params:
            if p.grad is None:
                continue

            grad = p.grad.data.clone()

            # Lazily initialize per-parameter state
            if id(p) not in self.state:
                self.state[id(p)] = {
                    'exp_avg': torch.zeros_like(p.data),
                    'exp_avg_sq': torch.zeros_like(p.data)
                }

            state = self.state[id(p)]
            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            # Update biased first-moment estimate
            exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)

            # Update biased second-moment estimate
            exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)

            # Bias correction
            bias_correction1 = 1 - self.beta1 ** self.step_count
            bias_correction2 = 1 - self.beta2 ** self.step_count

            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
            step_size = self.lr / bias_correction1

            # Decoupled weight decay: shrink the weights directly, then take the Adam step
            p.data.add_(p.data, alpha=-self.lr * self.weight_decay)
            p.data.addcdiv_(exp_avg, denom, value=-step_size)

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
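In practice the built-in torch.optim.AdamW is used directly, and it is common (though not required) to exclude biases and normalization parameters from weight decay via parameter groups. A minimal sketch, assuming a generic `model`; the name-based split is a heuristic, not part of the AdamW algorithm itself:

# Sketch: torch.optim.AdamW with weight decay applied only to "weight-like" parameters.
def build_adamw(model, lr=3e-4, weight_decay=0.01):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors (biases, LayerNorm/BatchNorm scales) usually skip weight decay.
        if param.ndim <= 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=lr,
        betas=(0.9, 0.999),
    )

optimizer = build_adamw(nn.Linear(10, 2))  # toy model for illustration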

Learning Rate Scheduling

class CosineAnnealingWarmRestarts:
    """Cosine annealing with warm restarts."""

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0):
        self.optimizer = optimizer
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = 0
        # Remember each group's initial learning rate; annealing is always
        # computed from these base values, not from the current (decayed) lr.
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]

    def step(self):
        for base_lr, param_group in zip(self.base_lrs, self.optimizer.param_groups):
            lr = self.eta_min + (base_lr - self.eta_min) * \
                 (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
            param_group['lr'] = lr

        self.T_cur += 1
        if self.T_cur >= self.T_i:
            # Restart: reset the cycle counter and stretch the next cycle
            self.T_cur = 0
            self.T_i *= self.T_mult

class OneCycleLR:
    """One-cycle learning rate schedule (linear warm-up, then linear anneal)."""

    def __init__(self, optimizer, max_lr, total_steps,
                 pct_start=0.3, div_factor=25, final_div_factor=1e4):
        self.optimizer = optimizer
        self.max_lr = max_lr
        self.total_steps = total_steps
        self.pct_start = pct_start
        self.div_factor = div_factor              # initial_lr = max_lr / div_factor
        self.final_div_factor = final_div_factor  # final_lr = max_lr / final_div_factor
        self.step_count = 0

    def step(self):
        step_pct = self.step_count / self.total_steps
        initial_lr = self.max_lr / self.div_factor
        final_lr = self.max_lr / self.final_div_factor

        if step_pct <= self.pct_start:
            # Warm-up phase: ramp linearly from initial_lr up to max_lr
            ramp = step_pct / self.pct_start
            lr = initial_lr + (self.max_lr - initial_lr) * ramp
        else:
            # Annealing phase: decay linearly from max_lr down to final_lr
            decay = (step_pct - self.pct_start) / (1 - self.pct_start)
            lr = self.max_lr - (self.max_lr - final_lr) * decay

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.step_count += 1
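PyTorch ships equivalent schedulers in torch.optim.lr_scheduler; a minimal sketch of how the warm-restart variant is typically wired into a training loop (the epoch count and the commented-out training step are placeholders):

# Sketch using the built-in scheduler rather than the hand-rolled class above.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Restart every 10 epochs, doubling the cycle length after each restart.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-5
)

for epoch in range(30):
    # ... one epoch of training with `optimizer` ...
    scheduler.step()  # per-epoch stepping; fractional epochs are also supported
    print(epoch, scheduler.get_last_lr())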

Optimizer Comparison and Selection

| Optimizer | Pros | Cons | Typical use case |
| --- | --- | --- | --- |
| SGD | Generalizes well, stable | Slow convergence | Image classification |
| SGD + Momentum | Faster convergence | Needs tuning | CNN training |
| Adam | Fast convergence | May generalize worse | NLP, small datasets |
| AdamW | Strong regularization | Needs tuning | Transformers |
| LAMB | Large-batch friendly | Complex to implement | Large-scale training |
# End-to-end training example (compute_loss is assumed to return a scalar loss for a batch)
def train_with_optimizer(model, train_loader, optimizer_type='adamw',
                         lr=1e-4, num_epochs=10):
    scheduler = None

    if optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader) * 90
        )
    elif optimizer_type == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, betas=(0.9, 0.999)
        )
    elif optimizer_type == 'adamw':
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=0.01
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader) * 10
        )
    elif optimizer_type == 'lamb':
        # LAMB is not part of torch.optim; it requires a third-party
        # implementation (e.g. the torch_optimizer package).
        raise NotImplementedError("LAMB requires an external implementation")
    else:
        raise ValueError(f"unknown optimizer_type: {optimizer_type}")

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            loss = compute_loss(model, batch)
            loss.backward()

            # Gradient clipping stabilizes training with large learning rates
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            if scheduler is not None:
                scheduler.step()

Summary

Choosing the right optimizer and learning rate schedule is critical for model training. SGD tends to generalize well on small datasets and vision tasks, the Adam family converges faster on NLP tasks, and AdamW's decoupled weight decay makes it the standard choice for Transformer training.

