I. 前言
在之前的一篇文章联邦学习基本算法FedAvg的代码实现中利用numpy手搭神经网络实现了FedAvg,相比于自己造轮子,还是建议优先使用PyTorch。
II. 数据介绍
联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。
本文选用的数据集为中国北方某城市十个区/县从2016年到2019年三年的真实用电负荷数据,采集时间间隔为1小时,即每一天都有24个负荷值。
我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。
除了电力负荷数据以外,还有一个备选数据集:风功率数据集。两个数据集通过参数type指定:type == 'load'表示负荷数据,type == 'wind'表示风功率数据。
用某一时刻前24个时刻的负荷值以及该时刻的相关气象数据(如温度、湿度、压强等)来预测该时刻的负荷值。
对于风功率数据,同样使用某一时刻前24个时刻的风功率值以及该时刻的相关气象数据来预测该时刻的风功率值。
各个地区应该就如何制定特征集达成一致意见,本文使用的各个地区上的数据的特征是一致的,可以直接使用。
III. 联邦学习
原始论文中提出的FedAvg的框架为:
class ANN(nn.Module):
def __init__(self, input_dim, name, B, E, type, lr):
super(ANN, self).__init__()
self.name = name
self.B = B
self.E = E
self.len = 0
self.type = type
self.lr = lr
self.loss = 0
self.fc1 = nn.Linear(input_dim, 20)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.dropout = nn.Dropout()
self.fc2 = nn.Linear(20, 20)
self.fc3 = nn.Linear(20, 20)
self.fc4 = nn.Linear(20, 1)
def forward(self, data):
x = self.fc1(data)
x = self.sigmoid(x)
x = self.fc2(x)
x = self.sigmoid(x)
x = self.fc3(x)
x = self.sigmoid(x)
x = self.fc4(x)
x = self.sigmoid(x)
return x
服务器端执行以下步骤:
1. 初始化参数
2. 对第t轮训练来说:首先计算出
,然后随机选择m个客户端,对这m个客户端做如下操作(所有客户端并行执行):更新本地的
得到
。所有客户端更新结束后,将
传到服务器,服务器整合所有
得到最新的全局参数
。
3. 服务器将最新的
分发给所有客户端,然后进行下一轮的更新。
简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后将更新后的参数传给服务器,服务器汇总所有客户端的参数形成最新的参数,然后将最新的参数再次分发给所有客户端,进行下一轮更新。
如果某个客户端被选中,那么它将利用本地数据来训练更新服务器传递的初始参数。无论是否被选中,客户端最终都将模型传递给服务器(未更新的客户端的模型是上一轮的全局模型)。
IV. 代码实现
class FedAvg:
def __init__(self, options):
self.C = options['C']
self.E = options['E']
self.B = options['B']
self.K = options['K']
self.r = options['r']
self.input_dim = options['input_dim']
self.type = options['type']
self.lr = options['lr']
self.clients = options['clients']
self.nn = ANN(input_dim=self.input_dim, name='server', B=B, E=E, type=self.type, lr=self.lr).to(device)
self.nns = []
for i in range(K):
temp = copy.deepcopy(self.nn)
temp.name = self.clients[i]
self.nns.append(temp)
参数:
def server(self):
for t in range(self.r):
print('第', t + 1, '轮通信:')
m = np.max([int(self.C * self.K), 1])
# sampling
index = random.sample(range(0, self.K), m)
# local updating
self.client_update(index)
# aggregation
self.aggregation()
# dispatch
self.dispatch()
# return global model
return self.nn
其中client_update(index)表示对选中的客户端进行更新,index为被选中的客户端的序号集合:
def client_update(self, index): # update nn
for k in index:
self.nns[k] = train(self.nns[k])
服务器聚合客户端模型:
def aggregation(self):
s = 0
for j in range(self.K):
# normal
s += self.nns[j].len
params = {}
with torch.no_grad():
for k, v in self.nns[0].named_parameters():
params[k] = copy.deepcopy(v)
params[k].zero_()
for j in range(self.K):
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
params[k] += v * (self.nns[j].len / s)
with torch.no_grad():
for k, v in self.nn.named_parameters():
v.copy_(params[k])
简单来说就是根据客户端样本个数来决定其在最终聚合时所占的权重。聚合公式:
其中
表示第
个客户端的本地数据量。
当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。论文Electricity Consumer Characteristics Identification: A Federated Learning Approach(https://ieeexplore.ieee.org/document/9380668)中总结了三种汇总方式:
值得注意的是,虽然服务器端每次只是选择K个客户端中的m个来进行更新,但在最终汇总的却是所有客户端模型参数。GitHub上某些FedAvg的代码实现中只对被选中的模型进行了聚合,不过本文还是决定以原始论文中的算法框架为准,对所有客户端进行聚合。
聚合结束后服务器端向客户端发送更新后的模型:
def dispatch(self):
params = {}
with torch.no_grad():
for k, v in self.nn.named_parameters():
params[k] = copy.deepcopy(v)
for j in range(self.K):
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
v.copy_(params[k])
客户端更新模型代码:
def train(ann):
ann.train()
if ann.type == 'load':
Dtr, Dte = nn_seq(ann.name, ann.B, ann.type)
else:
Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type)
ann.len = len(Dtr)
# print(len(Dtr))
loss_function = nn.MSELoss().to(device)
loss = 0
optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr)
for epoch in range(ann.E):
cnt = 0
for (seq, label) in Dtr:
cnt += 1
seq = seq.to(device)
label = label.to(device)
y_pred = ann(seq)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch', epoch, ':', loss.item())
return ann
利用最终的全局模型在各个客户端的测试集上进行测试:
def global_test(self):
model = self.nn
model.eval()
c = clients if self.type == 'load' else clients_wind
for client in c:
model.name = client
test(model)
V. 实验及结果
本次实验的参数为:
K | C | E | B | r | type | lr |
---|---|---|---|---|---|---|
10 | 0.5 | 50 | 50 | 5 | load | 0.08 |
if __name__ == '__main__':
K, C, E, B, r = 10, 0.5, 50, 50, 5
type = 'load'
input_dim = 30 if type == 'load' else 28
_client = clients if type == 'load' else clients_wind
lr = 0.08
options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client,
'input_dim': input_dim, 'lr': lr}
fedavg = FedAvg(options)
fedavg.server()
fedavg.global_test()
实验结果(MAPE / %):
编号 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
本地 | 5.33 | 4.11 | 3.03 | 4.20 | 3.02 | 2.70 | 2.94 | 2.99 | 2.30 | 4.10 |
numpy | 6.58 | 4.19 | 3.17 | 5.13 | 3.58 | 4.69 | 4.71 | 3.75 | 2.94 | 4.77 |
PyTorch | 6.84 | 4.54 | 3.56 | 5.11 | 3.75 | 4.47 | 4.30 | 3.90 | 3.15 | 4.58 |
其中本地表示十个客户端仅利用本地数据进行训练得到的预测结果,numpy和PyTorch分别表示利用numpy和PyTorch实现FedAvg后全局模型在各个客户端上的预测结果。
可以发现: