联邦学习:隐私保护的机器学习新范式

🎙️ 语音朗读 当前: 晓晓 (温柔女声)

前言

联邦学习(Federated Learning)是一种分布式机器学习范式,允许在不直接共享原始数据的情况下进行模型训练,有效保护用户隐私。本文将深入解析联邦学习的原理、算法和实现。

联邦学习核心思想

联邦学习的核心是“数据不动,模型动”:

  • 数据保留在本地设备
  • 只共享模型参数或梯度
  • 中央服务器聚合更新
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

class FederatedClient:
    """Federated learning client.

    Holds a local copy of the model and the client's private training
    data; exchanges only model parameters with the central server, never
    raw data.
    """

    def __init__(self, client_id, model, train_data, device='cpu'):
        self.client_id = client_id    # unique identifier for this client
        self.model = model            # local model instance
        self.train_data = train_data  # iterable of (data, target) batches
        self.device = device          # device used for local training

    def set_model_params(self, global_params):
        """Load parameters received from the global model."""
        self.model.load_state_dict(global_params)

    def get_model_params(self):
        """Return the local model parameters (state_dict)."""
        return self.model.state_dict()

    def local_train(self, epochs=5, lr=0.01, batch_size=32):
        """Run plain SGD training on the local data.

        Args:
            epochs: number of passes over the local data.
            lr: SGD learning rate.
            batch_size: unused here; batching is defined by ``train_data``
                itself (kept for interface compatibility).

        Returns:
            Sum of per-batch losses accumulated over all epochs.
        """
        self.model.train()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        for _ in range(epochs):
            for data, target in self.train_data:
                data, target = data.to(self.device), target.to(self.device)

                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

        return total_loss

    def compute_gradient_update(self, global_params):
        """Train one local epoch and return the pseudo-gradient.

        The update is ``global - local`` per parameter, i.e. the delta the
        server can subtract-apply to the global model (FedSGD style).

        Args:
            global_params: state_dict of the current global model.

        Returns:
            OrderedDict mapping parameter names to update tensors.
        """
        self.set_model_params(global_params)

        # Snapshot global parameters before local training mutates the model
        # (state_dict tensors are updated in place by the optimizer).
        original_params = OrderedDict(
            (name, param.clone()) for name, param in global_params.items()
        )

        self.local_train(epochs=1)

        # Parameter difference: original minus trained.
        return OrderedDict(
            (name, original_params[name] - param)
            for name, param in self.model.state_dict().items()
        )

class FederatedServer:
    """Central federated-learning server.

    Holds the global model and aggregates parameter updates received
    from participating clients.
    """

    def __init__(self, model, num_clients):
        self.global_model = model        # the shared global model
        self.num_clients = num_clients   # total registered clients

    def aggregate(self, client_updates, client_weights=None):
        """Weighted-average client state_dicts into the global model.

        Args:
            client_updates: list of client state_dicts (identical keys).
            client_weights: per-client weights (typically data-size
                fractions). Defaults to uniform over the updates actually
                received.

        Returns:
            The new global model state_dict.
        """
        if client_weights is None:
            # Uniform weighting over the updates actually received.
            # (Bug fix: the original divided by self.num_clients, so when
            # fewer clients participated the weights did not sum to 1.)
            client_weights = [1.0 / len(client_updates)] * len(client_updates)

        # Weighted average, key by key.
        global_state = OrderedDict()
        for key in client_updates[0].keys():
            global_state[key] = sum(
                update[key] * weight
                for update, weight in zip(client_updates, client_weights)
            )

        # Install the aggregate as the new global model.
        self.global_model.load_state_dict(global_state)

        return self.global_model.state_dict()

    def federated_averaging(self, client_updates, sample_rates):
        """FedAvg: average updates weighted by each client's sample count."""
        total_samples = sum(sample_rates)

        aggregated_params = OrderedDict()
        for key in client_updates[0].keys():
            weighted_sum = torch.zeros_like(client_updates[0][key])
            for update, rate in zip(client_updates, sample_rates):
                weighted_sum += update[key] * (rate / total_samples)
            aggregated_params[key] = weighted_sum

        return aggregated_params

FedAvg算法实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def fedavg_training(server, clients, rounds=100, local_epochs=5):
    """FedAvg training loop: broadcast, local train, aggregate, evaluate.

    Args:
        server: FederatedServer holding the global model.
        clients: list of FederatedClient instances.
        rounds: number of communication rounds.
        local_epochs: local epochs each client trains per round.

    Returns:
        The trained global model.

    NOTE(review): relies on ``evaluate`` and ``test_data`` which are not
    defined in this file — they must be supplied by the surrounding code.
    """
    for round_idx in range(rounds):
        print(f"\n=== Round {round_idx + 1} ===")

        # Current global parameters to broadcast this round.
        global_params = server.global_model.state_dict()

        # Client selection (here: all clients, simulating full participation).
        selected_clients = clients

        client_updates = []
        sample_rates = []

        for client in selected_clients:
            # Sync client to the global model, then train locally.
            client.set_model_params(global_params)
            client.local_train(epochs=local_epochs)

            # Collect the trained parameters and the client's data size
            # (used as its aggregation weight).
            client_updates.append(client.get_model_params())
            sample_rates.append(len(client.train_data.dataset))

        # Server-side weighted aggregation (weights are data-size fractions).
        server.aggregate(
            client_updates,
            [rate / sum(sample_rates) for rate in sample_rates],
        )

        # Evaluate the new global model on held-out data.
        accuracy = evaluate(server.global_model, test_data)
        print(f"Global Model Accuracy: {accuracy:.2f}%")

    return server.global_model

差分隐私

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np

class DPFederatedClient:
    """Federated client with differential-privacy-style training.

    Applies per-step gradient clipping followed by Gaussian noise
    (DP-SGD flavour) before each optimizer step.
    """

    def __init__(self, client_id, model, train_data, noise_multiplier=1.0, max_grad_norm=1.0):
        self.client_id = client_id              # unique client identifier
        self.model = model                      # local model instance
        self.train_data = train_data            # iterable of (data, target) batches
        self.noise_multiplier = noise_multiplier  # Gaussian noise scale factor
        self.max_grad_norm = max_grad_norm      # clipping bound for the global grad norm

    def train_with_dp(self, epochs=5, lr=0.01):
        """Local SGD training with gradient clipping and noise injection.

        Args:
            epochs: number of passes over the local data.
            lr: SGD learning rate.
        """
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for _ in range(epochs):
            for data, target in self.train_data:
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()

                # Bound each step's gradient norm, then add calibrated noise
                # before the parameters are updated.
                self._clip_gradients()
                self._add_noise()

                optimizer.step()

    def _clip_gradients(self):
        """Clip the global L2 norm of all gradients to max_grad_norm."""
        # Global norm = sqrt(sum of squared per-parameter norms).
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        # Scale down only when the norm exceeds the bound.
        clip_coef = self.max_grad_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in self.model.parameters():
                if p.grad is not None:
                    p.grad.data.mul_(clip_coef)

    def _add_noise(self):
        """Add Gaussian noise scaled by noise_multiplier * max_grad_norm.

        NOTE(review): standard DP-SGD also divides the noise by the batch
        size; confirm the intended calibration before relying on formal
        privacy guarantees.
        """
        for p in self.model.parameters():
            if p.grad is not None:
                noise = torch.randn_like(p.grad) * self.noise_multiplier * self.max_grad_norm
                p.grad.data.add_(noise)

FedProx算法

处理联邦学习中的数据异构性问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class FedProxClient(FederatedClient):
    """FedProx client.

    Adds a proximal term mu/2 * ||w - w_global||^2 to the local loss to
    limit client drift under heterogeneous (non-IID) data.
    """

    def __init__(self, client_id, model, train_data, mu=0.01):
        super().__init__(client_id, model, train_data)
        self.mu = mu  # proximal-term coefficient

    def local_train_prox(self, global_params, epochs=5, lr=0.01):
        """Local training with the FedProx proximal penalty.

        Args:
            global_params: state_dict of the current global model.
            epochs: number of local passes.
            lr: SGD learning rate.

        Returns:
            The updated local state_dict.
        """
        self.set_model_params(global_params)

        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for _ in range(epochs):
            for data, target in self.train_data:
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)

                # Proximal term penalizes deviation from the global model.
                prox_loss = self._compute_prox_term(global_params)
                total_loss = loss + self.mu * prox_loss

                total_loss.backward()
                optimizer.step()

        return self.get_model_params()

    def _compute_prox_term(self, global_params):
        """Return (1/2) * ||w - w_global||^2 over trainable parameters.

        Bug fix: the original zipped named_parameters() positionally with
        global_params.items(); a state_dict may contain buffers that
        named_parameters() skips, which would misalign the pairs. Look up
        each parameter by name instead.
        """
        prox_term = 0.0
        for name, param in self.model.named_parameters():
            prox_term = prox_term + torch.sum((param - global_params[name]) ** 2)
        return prox_term / 2

联邦学习通信优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class CompressedFederatedClient(FederatedClient):
    """Federated client using Top-K sparsification to cut communication.

    Only the largest-magnitude fraction of each parameter tensor is sent;
    the server reconstructs a sparse tensor from (values, indices, shape).
    """

    def __init__(self, client_id, model, train_data, compression_ratio=0.1):
        super().__init__(client_id, model, train_data)
        self.compression_ratio = compression_ratio  # fraction of entries kept

    def send_compressed_update(self, global_params):
        """Train one local epoch and return a Top-K-compressed state_dict.

        Returns:
            Dict mapping parameter name to {'values', 'indices', 'shape'}.
        """
        self.set_model_params(global_params)
        self.local_train(epochs=1)

        local_params = self.get_model_params()
        compressed = {}

        for name, param in local_params.items():
            # Keep only the k largest-magnitude entries of the flattened tensor.
            flat_param = param.flatten()
            k = int(len(flat_param) * self.compression_ratio)

            _, indices = torch.topk(flat_param.abs(), k)

            # Bug fix: the original gathered values through a boolean mask
            # (ascending index order) while 'indices' is in magnitude order,
            # so decompression scattered values to the wrong positions.
            # Index directly so values[i] corresponds to indices[i].
            compressed[name] = {
                'values': flat_param[indices],
                'indices': indices,
                'shape': param.shape
            }

        return compressed

    def decompress_update(self, compressed_update):
        """Reconstruct dense tensors from a compressed update (zeros elsewhere)."""
        decompressed = {}

        for name, data in compressed_update.items():
            flat = torch.zeros(data['shape'].numel())
            flat[data['indices']] = data['values']
            decompressed[name] = flat.view(data['shape'])

        return decompressed

联邦学习类型

类型 数据分布 场景
横向联邦 特征相同,样本不同 跨银行用户数据
纵向联邦 样本相同,特征不同 电商+银行合作
联邦迁移 异构数据 跨域学习

实际应用场景

联邦学习在以下领域应用广泛:

  • 移动键盘预测:Gboard输入法
  • 健康医疗:跨医院医学数据协作
  • 金融风控:银行间的风控模型
  • 边缘计算:IoT设备上的分布式学习

总结

联邦学习为隐私敏感场景下的机器学习提供了可行的解决方案,通过在数据不出本地的情况下进行协作学习,平衡了数据利用和隐私保护的需求。

参考资源

© 2019-2026 ovo$^{mc^2}$ All Rights Reserved. | 站点总访问 28969 次 | 访客 19045
Theme by hiero