1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
| import torch import torch.nn as nn import torch.nn.functional as F from collections import OrderedDict
class FederatedClient:
    """Federated-learning client: holds a local model and a private dataset.

    The client receives global parameters from the server, trains locally,
    and sends back either its full parameters or a pseudo-gradient update.
    """

    def __init__(self, client_id, model, train_data, device='cpu'):
        self.client_id = client_id
        self.model = model
        # Iterable of (data, target) batches — assumed pre-batched (e.g. a DataLoader).
        self.train_data = train_data
        self.device = device

    def set_model_params(self, global_params):
        """Load parameters broadcast by the server into the local model."""
        self.model.load_state_dict(global_params)

    def get_model_params(self):
        """Return the local model's current state dict."""
        return self.model.state_dict()

    def local_train(self, epochs=5, lr=0.01, batch_size=32):
        """Run local SGD training on the client's private data.

        Args:
            epochs: number of full passes over ``train_data``.
            lr: SGD learning rate.
            batch_size: unused — ``train_data`` is already batched; kept only
                for interface compatibility with existing callers.

        Returns:
            float: sum of per-batch losses accumulated over all epochs
            (not averaged).
        """
        self.model.train()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        for _ in range(epochs):
            for data, target in self.train_data:
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
        return total_loss

    def compute_gradient_update(self, global_params, epochs=1, lr=0.01):
        """Compute the pseudo-gradient update: global params minus locally
        trained params (positive along the descent direction).

        Generalized: ``epochs`` and ``lr`` were previously hard-coded
        (epochs=1 and local_train's default lr); the defaults preserve the
        original behavior exactly.

        Args:
            global_params: state dict broadcast by the server.
            epochs: local epochs to train before differencing.
            lr: learning rate for the local training pass.

        Returns:
            OrderedDict mapping parameter name -> (original - trained) tensor.
        """
        self.set_model_params(global_params)
        # Snapshot the starting point. detach().clone() under no_grad so the
        # snapshot neither tracks autograd history nor shares storage with
        # tensors that local training will mutate in place.
        with torch.no_grad():
            original_params = OrderedDict(
                (name, param.detach().clone())
                for name, param in global_params.items()
            )
        self.local_train(epochs=epochs, lr=lr)
        # Difference under no_grad: state-dict entries may require grad, and
        # the update tensors must not carry an autograd graph.
        with torch.no_grad():
            gradient_updates = OrderedDict(
                (name, original_params[name] - param)
                for name, param in self.model.state_dict().items()
            )
        return gradient_updates
class FederatedServer:
    """Federated-learning central server: owns the global model and
    aggregates parameter updates from participating clients."""

    def __init__(self, model, num_clients):
        self.global_model = model
        # Size of the full client population (bookkeeping only; aggregation
        # weights are derived from the clients that actually participate).
        self.num_clients = num_clients

    def aggregate(self, client_updates, client_weights=None):
        """Weighted-average client state dicts into the global model.

        Args:
            client_updates: list of state dicts, one per participating client.
            client_weights: per-client weights (typically data-size
                fractions). Defaults to uniform weights over the
                *participating* clients. BUG FIX: this previously defaulted
                to ``1/num_clients`` over the whole population, but ``zip``
                truncates to ``len(client_updates)``, so whenever fewer than
                ``num_clients`` clients participated the weights summed to
                less than 1 and the averaged parameters were under-scaled.

        Returns:
            The new global model state dict.

        Raises:
            ValueError: if ``client_updates`` is empty.
        """
        if not client_updates:
            raise ValueError("client_updates must be non-empty")
        if client_weights is None:
            n = len(client_updates)
            client_weights = [1.0 / n] * n
        global_state = OrderedDict()
        for key in client_updates[0]:
            global_state[key] = sum(
                update[key] * weight
                for update, weight in zip(client_updates, client_weights)
            )
        self.global_model.load_state_dict(global_state)
        return self.global_model.state_dict()

    def federated_averaging(self, client_updates, sample_rates):
        """FedAvg: sample-count-weighted average of client state dicts.

        Unlike :meth:`aggregate`, this returns the averaged parameters
        without modifying the global model.

        Fix: accumulate out-of-place (``sum`` of scaled tensors) rather than
        ``+=`` into an integer ``zeros_like`` buffer, which raised a dtype
        error for integer state entries such as BatchNorm's
        ``num_batches_tracked``.

        Args:
            client_updates: list of state dicts, one per client.
            sample_rates: per-client sample counts used as weights.

        Returns:
            OrderedDict of averaged parameters.
        """
        total_samples = sum(sample_rates)
        aggregated_params = OrderedDict()
        for key in client_updates[0]:
            aggregated_params[key] = sum(
                update[key] * (rate / total_samples)
                for update, rate in zip(client_updates, sample_rates)
            )
        return aggregated_params
|