Neural Architecture Search: A Deep Dive into AutoML


Preface

Neural Architecture Search (NAS) is a core technique of automated machine learning (AutoML): instead of hand-designing neural network architectures, an algorithm searches for them. This article walks through the principles, the main methods, and concrete implementations of NAS.

NAS Overview
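
A NAS system iterates over three components: a search space defines which architectures can be expressed, a search strategy proposes candidate architectures from that space, and a performance-evaluation step scores each candidate so the strategy can improve its next proposals: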

```mermaid
graph LR
    A[Search space] --> B[Search strategy]
    B --> C[Performance evaluation]
    C --> D[Architecture optimization]
    D --> B
```
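
To make the loop concrete, here is a minimal random-search baseline over the space defined in the next section. `build_model` and `evaluate` are hypothetical stand-ins for training a candidate briefly and measuring its validation accuracy:

```python
import random

def random_search(search_space, num_trials=20):
    """Minimal NAS baseline: sample architectures at random, keep the best."""
    best_arch, best_score = None, float('-inf')

    for _ in range(num_trials):
        # Sample one architecture description from the search space
        arch = {
            'ops': [random.choice(search_space.operations)
                    for _ in range(random.choice(search_space.depth_choices))],
            'channels': random.choice(search_space.channel_choices),
        }

        # Hypothetical helpers: train the candidate briefly, then score it
        model = build_model(arch)   # assumed to exist
        score = evaluate(model)     # e.g. validation accuracy

        if score > best_score:
            best_arch, best_score = arch, score

    return best_arch, best_score
```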

Defining the Search Space
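
The search space fixes everything the algorithm is allowed to choose: the candidate operation on each edge, the channel width, and the network depth. The sketch below defines such a space, plus a `NASCell` that wires candidate operations into a small DAG with learnable selection weights.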

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SearchSpace:
    """Search space definition."""

    def __init__(self):
        # Candidate operation types
        self.operations = [
            'conv3x3',            # 3x3 convolution
            'conv5x5',            # 5x5 convolution
            'conv7x7',            # 7x7 convolution
            'depthwise_conv3x3',  # depthwise convolution
            'max_pool3x3',        # max pooling
            'avg_pool3x3',        # average pooling
            'skip_connect',       # identity mapping
            'sep_conv3x3',        # separable convolution
            'sep_conv5x5',        # separable convolution
            'none',               # no operation
        ]

        # Channel-width choices
        self.channel_choices = [16, 32, 64, 128, 256, 512]

        # Depth choices
        self.depth_choices = [2, 4, 6, 8, 12, 16]

    def get_num_operations(self):
        return len(self.operations)

    def get_operation(self, idx):
        return self.operations[idx]


class NASCell(nn.Module):
    """Basic searchable cell: a small DAG whose edges and ops are learnable."""

    def __init__(self, in_channels, out_channels, num_nodes, search_space):
        super().__init__()
        # This sketch keeps the channel width constant
        # (out_channels == in_channels is assumed)
        self.num_nodes = num_nodes
        self.search_space = search_space

        # Logits over incoming edges for each node
        self.input_probs = nn.Parameter(
            torch.randn(num_nodes, num_nodes) * 0.01
        )

        # Logits over candidate operations for each node
        self.op_probs = nn.Parameter(
            torch.randn(num_nodes, len(search_space.operations)) * 0.01
        )

        # Instantiate every candidate op per node up front, so parameters are
        # registered with the module and trained (building them inside
        # forward() would create fresh, untracked modules on every call)
        self.ops = nn.ModuleList([
            nn.ModuleList([
                self._make_operation(idx, in_channels)
                for idx in range(search_space.get_num_operations())
            ])
            for _ in range(num_nodes)
        ])

    def forward(self, inputs):
        """
        Args:
            inputs: list of num_nodes tensors with identical shapes
        """
        assert len(inputs) == self.num_nodes

        # Soft weights over operations and over incoming edges
        op_weights = F.softmax(self.op_probs, dim=-1)
        edge_weights = F.softmax(self.input_probs, dim=-1)

        # Build the DAG node by node
        states = list(inputs)

        for node_idx in range(self.num_nodes):
            # Mix the representations of input states 0..node_idx
            aggregated = torch.zeros_like(states[0])
            for prev_idx in range(node_idx + 1):
                aggregated += edge_weights[node_idx, prev_idx] * states[prev_idx]

            # Hard selection for this sketch: pick the highest-weight op.
            # argmax is not differentiable w.r.t. op_probs; DARTS (next
            # section) replaces it with a soft mixture
            op_idx = torch.argmax(op_weights[node_idx]).item()
            states.append(self.ops[node_idx][op_idx](aggregated))

        # The output concatenates the newly created nodes along channels
        return torch.cat(states[self.num_nodes:], dim=1)

    def _make_operation(self, op_idx, channels):
        op_name = self.search_space.get_operation(op_idx)

        if op_name == 'conv3x3':
            return nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
            )
        elif op_name == 'conv5x5':
            return nn.Sequential(
                nn.Conv2d(channels, channels, 5, padding=2, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
            )
        elif op_name == 'max_pool3x3':
            return nn.MaxPool2d(3, stride=1, padding=1)
        elif op_name == 'avg_pool3x3':
            return nn.AvgPool2d(3, stride=1, padding=1)
        elif op_name == 'skip_connect':
            return nn.Identity()
        else:
            # Remaining ops ('none', depthwise/separable variants) fall back
            # to the identity to keep this sketch short
            return nn.Identity()
```
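
A quick smoke test of the cell, with illustrative shapes (four input states, 16 channels):

```python
# Smoke test: run a NASCell on four dummy input states
space = SearchSpace()
cell = NASCell(in_channels=16, out_channels=16, num_nodes=4, search_space=space)
inputs = [torch.randn(2, 16, 32, 32) for _ in range(4)]
out = cell(inputs)
print(out.shape)  # torch.Size([2, 64, 32, 32]): 4 new nodes concatenated on channels
```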

DARTS: Differentiable Architecture Search
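
DARTS makes the search differentiable in two steps. First, the discrete choice of an operation on each edge $(i,j)$ is relaxed into a softmax-weighted mixture over all candidates $\mathcal{O}$:

$$\bar{o}^{(i,j)}(x)=\sum_{o\in\mathcal{O}}\frac{\exp\bigl(\alpha_{o}^{(i,j)}\bigr)}{\sum_{o'\in\mathcal{O}}\exp\bigl(\alpha_{o'}^{(i,j)}\bigr)}\,o(x)$$

Second, the architecture parameters $\alpha$ and the network weights $w$ are optimized as a bilevel problem, with $\alpha$ chosen to minimize the validation loss under the best weights for the current architecture:

$$\min_{\alpha}\;\mathcal{L}_{\mathrm{val}}\bigl(w^{*}(\alpha),\alpha\bigr)\qquad\text{s.t.}\qquad w^{*}(\alpha)=\arg\min_{w}\;\mathcal{L}_{\mathrm{train}}(w,\alpha)$$

After the search, a discrete architecture is derived by keeping the strongest operation on each edge.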

```python
class DARTSCell(nn.Module):
    """DARTS cell: every edge computes a softmax-weighted mixture of candidate ops."""

    def __init__(self, in_channels, out_channels, num_nodes=4, is_reduction=False):
        super().__init__()

        self.num_nodes = num_nodes
        self.is_reduction = is_reduction

        # Preprocess the two input states; a reduction cell halves the
        # resolution here, so all candidate ops can keep stride 1
        stride = 2 if is_reduction else 1
        self.preprocess0 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.preprocess1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        # Candidate operations
        self._ops = nn.ModuleList([
            nn.Sequential(  # sep_conv_3x3
                nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False,
                          groups=out_channels),
                nn.BatchNorm2d(out_channels),
                nn.Conv2d(out_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            ),
            nn.Sequential(  # sep_conv_5x5
                nn.Conv2d(out_channels, out_channels, 5, padding=2, bias=False,
                          groups=out_channels),
                nn.BatchNorm2d(out_channels),
                nn.Conv2d(out_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            ),
            nn.Sequential(  # avg_pool_3x3
                nn.AvgPool2d(3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
            ),
            nn.Sequential(  # max_pool_3x3
                nn.MaxPool2d(3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
            ),
            nn.Sequential(  # skip_connect
                nn.Identity(),
                nn.BatchNorm2d(out_channels, affine=False),
            ),
        ])

        # Architecture parameters: one set of op logits per edge
        num_edges = sum(i + 2 for i in range(num_nodes))
        self.edge_weights = nn.Parameter(
            torch.randn(num_edges, len(self._ops)) * 1e-3
        )

    def forward(self, s0, s1):
        # Preprocess the two input states
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        edge_idx = 0

        for node_idx in range(self.num_nodes):
            node_result = torch.zeros_like(s0)

            # Each node aggregates a mixed op from every earlier state
            for prev_idx in range(node_idx + 2):
                # Softmax over the candidate ops on this edge
                weights = F.softmax(self.edge_weights[edge_idx], dim=0)
                for op_idx, op in enumerate(self._ops):
                    node_result = node_result + weights[op_idx] * op(states[prev_idx])
                edge_idx += 1

            states.append(node_result)

        # The original DARTS concatenates the intermediate nodes; summing them
        # here keeps the channel count constant and the bookkeeping simple
        return sum(states[-self.num_nodes:])


class DARTSNetwork(nn.Module):
    """Small DARTS-style network."""

    def __init__(self, num_classes=10, num_layers=8, channels=36):
        super().__init__()

        self.num_layers = num_layers

        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )

        # Stack of normal cells with two reduction cells at 1/3 and 2/3 depth
        self.cells = nn.ModuleList()
        for i in range(num_layers):
            if i in (num_layers // 3, 2 * num_layers // 3):
                # Reduction cell: halve resolution, double channels
                self.cells.append(
                    DARTSCell(channels, channels * 2, is_reduction=True)
                )
                channels *= 2
            else:
                # Normal cell
                self.cells.append(DARTSCell(channels, channels))

        # Classification head
        self.classifier = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, num_classes),
        )

    def forward(self, x):
        x = self.stem(x)
        s0 = s1 = x

        for cell in self.cells:
            s0, s1 = s1, cell(s0, s1)
            if cell.is_reduction:
                # Re-align s0 with the new resolution/width (full DARTS uses
                # a factorized reduce instead of this simplification)
                s0 = s1

        return self.classifier(s1)
```
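
In practice the bilevel problem is approximated by alternating gradient steps: update the architecture parameters on a validation batch, then the network weights on a training batch. A first-order sketch follows (the parameter split and the learning rates mirror the DARTS paper's defaults, but this helper is illustrative, not the reference implementation):

```python
model = DARTSNetwork()
criterion = nn.CrossEntropyLoss()

# Split architecture parameters (the per-edge alphas) from network weights
alphas = [p for n, p in model.named_parameters() if 'edge_weights' in n]
weights = [p for n, p in model.named_parameters() if 'edge_weights' not in n]

alpha_optimizer = torch.optim.Adam(alphas, lr=3e-4, weight_decay=1e-3)
w_optimizer = torch.optim.SGD(weights, lr=0.025, momentum=0.9, weight_decay=3e-4)

def darts_search_step(x_train, y_train, x_val, y_val):
    """One first-order DARTS step: alphas on val data, weights on train data."""
    # 1) Architecture step: minimize validation loss w.r.t. the alphas
    model.zero_grad()
    criterion(model(x_val), y_val).backward()
    alpha_optimizer.step()

    # 2) Weight step: minimize training loss w.r.t. the network weights
    model.zero_grad()
    criterion(model(x_train), y_train).backward()
    w_optimizer.step()
```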

Reinforcement-Learning Search Strategies
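
In RL-based NAS (the approach behind NASNet), an RNN controller emits an architecture one decision at a time, and the trained child network's validation accuracy serves as the reward $R$. Since $R$ is not differentiable with respect to the controller parameters $\theta$, the controller is trained with the REINFORCE policy gradient, using a baseline $b$ (e.g. a moving average of past rewards) to reduce variance:

$$\nabla_{\theta} J(\theta) \approx \frac{1}{m}\sum_{k=1}^{m}\bigl(R_k - b\bigr)\,\nabla_{\theta}\log \pi_{\theta}(a_k)$$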

```python
class Controller(nn.Module):
    """RNN controller that emits an architecture one decision at a time."""

    def __init__(self, num_ops=5, hidden_size=100, num_layers=2):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        self.fc = nn.Linear(hidden_size, num_ops)

        # Embedding of the previous action
        self.embed = nn.Embedding(num_ops, hidden_size)

        self.num_ops = num_ops
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, inputs, hidden=None):
        """
        Args:
            inputs: previous action indices, shape (batch, seq_len)
            hidden: LSTM hidden state
        """
        if hidden is None:
            batch_size = inputs.size(0)
            hidden = (
                torch.zeros(self.num_layers, batch_size, self.hidden_size,
                            device=inputs.device),
                torch.zeros(self.num_layers, batch_size, self.hidden_size,
                            device=inputs.device),
            )

        # Embed the previous action
        x = self.embed(inputs)

        # LSTM forward pass
        output, hidden = self.lstm(x, hidden)

        # Logits for the next action
        logits = self.fc(output)

        return logits, hidden


class ReinforceOptimizer:
    """REINFORCE update with an exponential-moving-average baseline."""

    def __init__(self, controller, baseline_value=0.0):
        self.controller = controller
        self.baseline = baseline_value
        self.optimizer = torch.optim.Adam(
            controller.parameters(), lr=0.01
        )

    def update(self, rewards, log_probs):
        """
        Args:
            rewards: validation accuracies of the sampled architectures (1-D tensor)
            log_probs: log-probability of sampling each architecture
        """
        # Advantage: reward minus baseline
        advantages = rewards - self.baseline

        # Update the moving-average baseline
        self.baseline = 0.9 * self.baseline + 0.1 * rewards.mean().item()

        # Policy gradient: negative log-likelihood weighted by the advantage
        policy_loss = []
        for advantage, lp in zip(advantages, log_probs):
            policy_loss.append(-lp * advantage)

        policy_loss = torch.stack(policy_loss).mean()

        # Backpropagate through the controller
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

        return policy_loss.item()
```
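
Putting the two together, one search iteration might look like the sketch below. `train_and_evaluate` is a hypothetical helper that builds the child network from the sampled op indices, trains it briefly, and returns its validation accuracy:

```python
def sample_architecture(controller, seq_len=6):
    """Sample one architecture (a sequence of op indices) and its log-prob."""
    action = torch.zeros(1, 1, dtype=torch.long)  # start token
    hidden = None
    log_prob = 0.0
    op_indices = []

    for _ in range(seq_len):
        logits, hidden = controller(action, hidden)
        dist = torch.distributions.Categorical(logits=logits[:, -1])
        choice = dist.sample()                    # shape (1,)
        log_prob = log_prob + dist.log_prob(choice)
        op_indices.append(choice.item())
        action = choice.unsqueeze(0)              # feed back as the next input

    return op_indices, log_prob

controller = Controller(num_ops=5)
reinforce = ReinforceOptimizer(controller)

log_probs, rewards = [], []
for _ in range(8):  # sample a batch of candidate architectures
    ops, lp = sample_architecture(controller)
    reward = train_and_evaluate(ops)  # hypothetical: builds, trains, scores the child net
    log_probs.append(lp)
    rewards.append(reward)

loss = reinforce.update(torch.tensor(rewards), log_probs)
```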

Supernet Training
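
One-Shot NAS sidesteps the cost of training every candidate from scratch: a single over-parameterized supernet contains all candidate architectures as subnetworks, and every candidate is evaluated with the shared supernet weights. This weight sharing is what makes methods like ENAS and Once-for-All orders of magnitude cheaper than early RL-based NAS.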

```python
class SuperNet(nn.Module):
    """Supernet for One-Shot NAS: every layer holds all candidate operations."""

    def __init__(self, channels=16, num_layers=4, num_classes=10):
        super().__init__()

        self.stem = nn.Conv2d(3, channels, 3, padding=1, bias=False)

        # Each layer owns the full set of candidate ops (weights are shared
        # by every architecture that routes through this layer)
        def candidate_ops():
            return nn.ModuleList([
                nn.Conv2d(channels, channels, 3, padding=1),
                nn.Conv2d(channels, channels, 5, padding=2),
                nn.MaxPool2d(3, stride=1, padding=1),
                nn.AvgPool2d(3, stride=1, padding=1),
                nn.Identity(),
            ])

        self.layers = nn.ModuleList([candidate_ops() for _ in range(num_layers)])

        # Path weights: one row per layer, one column per candidate op
        self.path_weights = nn.Parameter(torch.ones(num_layers, 5) / 5)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, num_classes),
        )

    def forward(self, x, sampling='softmax'):
        """Forward pass with soft or hard path sampling."""
        x = self.stem(x)
        weights = F.softmax(self.path_weights, dim=-1)

        for layer_idx, ops in enumerate(self.layers):
            if sampling == 'hard':
                # Hard sampling: run only the highest-weight path
                # (used when deriving the final architecture)
                op_idx = torch.argmax(weights[layer_idx]).item()
                x = ops[op_idx](x)
            else:
                # Soft sampling: weighted sum over all candidate paths
                x = sum(w * op(x) for w, op in zip(weights[layer_idx], ops))

        return self.classifier(x)


def train_supernet(supernet, dataloader, epochs=50):
    """Train the supernet with soft path sampling."""
    optimizer = torch.optim.Adam(supernet.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for inputs, targets in dataloader:
            optimizer.zero_grad()

            outputs = supernet(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}: Loss={loss.item():.4f}")
```

NAS Methods Compared

| Method | Search Strategy | Search Cost | Accuracy |
| --- | --- | --- | --- |
| NASNet | Reinforcement learning | Very high (thousands of GPU-days) | State of the art at publication |
| DARTS | Gradient-based | Low (a few GPU-days) | Competitive with RL-based NAS |
| ENAS | RL with weight sharing | Very low (under one GPU-day) | Competitive |
| ProxylessNAS | Gradient-based | Low; searches directly on the target task and hardware | Competitive, hardware-aware |
| Once-for-All | Progressive shrinking | One-time supernet training, then cheap specialization | Competitive across deployment targets |

Summary

By automating the design of network architectures, neural architecture search greatly lowers the barrier to building deep learning models. Methods such as DARTS have made NAS feasible for large-scale networks, driving the broader development of AutoML.

