ResNet进阶:深度残差网络的技术演进


前言

ResNet(Residual Network)由何恺明等人于2015年提出,通过残差连接解决了深层网络的梯度消失问题,成为计算机视觉领域最具影响力的架构之一。

ResNet的核心创新

ResNet的核心思想是残差学习。传统网络学习的是底层到高层的映射 H(x),而ResNet学习残差 F(x) = H(x) - x:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Basic two-conv residual block used in ResNet-18/34.

    Computes F(x) = BN(conv3x3(ReLU(BN(conv3x3(x))))) and returns
    ReLU(F(x) + shortcut(x)), where the shortcut is the identity or
    the optional ``downsample`` projection.
    """

    expansion = 1  # output channels == out_channels * expansion

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # First 3x3 conv may reduce resolution via `stride`.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second 3x3 conv keeps resolution and channel count.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Shortcut branch: identity, or a projection matching shape changes.
        shortcut = x if self.downsample is None else self.downsample(x)
        # Main branch.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual connection, then the final activation.
        return self.relu(out + shortcut)

class Bottleneck(nn.Module):
    """Bottleneck residual block for ResNet-50/101/152.

    1x1 reduce -> 3x3 spatial (optionally strided) -> 1x1 expand by 4x,
    with the shortcut added before the final ReLU.
    """

    expansion = 4  # channel expansion applied by the last 1x1 conv

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv compresses channels.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 3x3 conv does the spatial work (and any downsampling).
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, 1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 conv expands channels by `expansion`.
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, 1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)

ResNet架构实现

class ResNet(nn.Module):
    """Configurable ResNet classifier (He et al., 2015).

    `block` is the residual block class (e.g. BasicBlock or Bottleneck)
    and `layers` gives the number of blocks in each of the four stages.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Channel count fed into the next stage; mutated by _make_layer.
        self.in_channels = 64

        # Stem: 7x7/2 conv + 3x3/2 max-pool -> 4x spatial reduction.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; stages 2-4 halve the resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Global average pooling followed by a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._init_weights()

    def _init_weights(self):
        """He-normal init for convs; BN affine params start as identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of `blocks` residual blocks.

        The first block may change resolution (stride) and/or channel
        count, in which case the shortcut gets a 1x1 BN-projection.
        """
        expanded = out_channels * block.expansion
        downsample = None
        if stride != 1 or self.in_channels != expanded:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, expanded,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )

        stage = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = expanded
        stage.extend(block(self.in_channels, out_channels)
                     for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

# Factory functions for the canonical ResNet depths
# (block type and blocks-per-stage per He et al., 2015, Table 1).

def resnet18(num_classes=1000):
    """ResNet-18: BasicBlock x [2, 2, 2, 2]."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)


def resnet34(num_classes=1000):
    """ResNet-34: BasicBlock x [3, 4, 6, 3]."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)


def resnet50(num_classes=1000):
    """ResNet-50: Bottleneck x [3, 4, 6, 3]."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)


def resnet101(num_classes=1000):
    """ResNet-101: Bottleneck x [3, 4, 23, 3]."""
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)

ResNet的训练

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data augmentation for ImageNet-style training: random crop, horizontal
# flip and colour jitter, then normalisation with the standard ImageNet
# per-channel mean/std statistics.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Training configuration
def train_resnet(model, train_loader, num_epochs=90):
    """Train `model` with the classic ImageNet SGD recipe.

    Uses cross-entropy loss, SGD (lr 0.1, momentum 0.9, weight decay
    1e-4) and a MultiStepLR schedule that multiplies the learning rate
    by 0.1 at epochs 30, 60 and 90. Prints the average loss and the
    training accuracy once per epoch. Runs on GPU when available.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                                momentum=0.9, weight_decay=1e-4)

    # Step-wise learning-rate decay.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[30, 60, 90], gamma=0.1)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        n_correct = 0
        n_seen = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate epoch statistics.
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

        scheduler.step()

        print(f"Epoch {epoch+1}: Loss={running_loss/len(train_loader):.4f}, "
              f"Acc={100.*n_correct/n_seen:.2f}%")

ResNet的变体

1. Pre-activation ResNet

将BN和ReLU放在卷积之前:

class PreActBasicBlock(nn.Module):
    """Pre-activation basic block (He et al., 2016, "Identity Mappings").

    BN and ReLU precede each convolution, and the residual sum is NOT
    followed by an activation. When a downsample projection exists, it
    is applied to the pre-activated tensor, per the paper.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1,
                               bias=False)

        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1,
                               bias=False)

        self.downsample = downsample

    def forward(self, x):
        pre = self.relu1(self.bn1(x))
        # Project the *pre-activated* features when shapes change;
        # otherwise the shortcut is the raw input.
        shortcut = self.downsample(pre) if self.downsample else x
        residual = self.conv2(self.relu2(self.bn2(self.conv1(pre))))
        return residual + shortcut

2. ResNeXt

使用分组卷积增加基数(cardinality):

class ResNeXtBlock(nn.Module):
    """ResNeXt bottleneck block: split-transform-merge via grouped conv."""

    expansion = 4  # channel widening applied by the final 1x1 conv

    def __init__(self, in_channels, out_channels, cardinality=32, stride=1, downsample=None):
        super().__init__()
        # Wider bottleneck than plain ResNet (2x the nominal width).
        width = out_channels * 2

        self.conv1 = nn.Conv2d(in_channels, width, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)

        # Grouped 3x3 conv realises `cardinality` parallel transformations.
        self.conv2 = nn.Conv2d(width, width, 3, stride, 1,
                               groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(width)

        self.conv3 = nn.Conv2d(width, out_channels * self.expansion, 1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)

实际应用

ResNet在以下任务中表现优异:

  • 图像分类:ImageNet、CIFAR-10
  • 目标检测:Faster R-CNN、YOLO的基础网络
  • 语义分割:DeepLab系列的骨干网络
  • 人脸识别:ArcFace等方法的backbone

总结

ResNet通过残差连接有效解决了深层网络的训练难题,其设计理念影响了后续众多网络架构。理解ResNet对于学习现代计算机视觉模型至关重要。

参考资源

© 2019-2026 ovo$^{mc^2}$ All Rights Reserved. | 站点总访问 28969 次 | 访客 19045
Theme by hiero