前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于PyTorch实现联邦学习的基本算法FedAvg

基于PyTorch实现联邦学习的基本算法FedAvg

作者头像
Cyril-KI
发布2022-11-08 16:35:20
8060
发布2022-11-08 16:35:20
举报
文章被收录于专栏:KI的算法杂记

I. 前言

在之前的一篇文章联邦学习基本算法FedAvg的代码实现中利用numpy手搭神经网络实现了FedAvg,相比于自己造轮子,还是建议优先使用PyTorch。

II. 数据介绍

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

本文选用的数据集为中国北方某城市十个区/县从2016年到2019年三年的真实用电负荷数据,采集时间间隔为1小时,即每一天都有24个负荷值。

我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

除了电力负荷数据以外,还有一个备选数据集:风功率数据集。两个数据集通过参数type指定:type == 'load'表示负荷数据,type == 'wind'表示风功率数据。

特征构造

用某一时刻前24个时刻的负荷值以及该时刻的相关气象数据(如温度、湿度、压强等)来预测该时刻的负荷值。

对于风功率数据,同样使用某一时刻前24个时刻的风功率值以及该时刻的相关气象数据来预测该时刻的风功率值。

各个地区应该就如何制定特征集达成一致意见,本文使用的各个地区上的数据的特征是一致的,可以直接使用。

III. 联邦学习

1. 整体框架

原始论文中提出的FedAvg的框架为:

客户端模型采用PyTorch搭建:

代码语言:javascript
复制
class ANN(nn.Module):
    def __init__(self, input_dim, name, B, E, type, lr):
        super(ANN, self).__init__()
        self.name = name
        self.B = B
        self.E = E
        self.len = 0
        self.type = type
        self.lr = lr
        self.loss = 0
        self.fc1 = nn.Linear(input_dim, 20)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)

    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)

        return x

2. 服务器端

服务器端执行以下步骤:

1. 初始化参数

2. 对第t轮训练来说:首先计算出

,然后随机选择m个客户端,对这m个客户端做如下操作(所有客户端并行执行):更新本地的

得到

。所有客户端更新结束后,将

传到服务器,服务器整合所有

得到最新的全局参数

3. 服务器将最新的

分发给所有客户端,然后进行下一轮的更新。

简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后将更新后的参数传给服务器,服务器汇总所有客户端的参数形成最新的参数,然后将最新的参数再次分发给所有客户端,进行下一轮更新。

3. 客户端

如果某个客户端被选中,那么它将利用本地数据来训练更新服务器传递的初始参数。无论是否被选中,客户端最终都将模型传递给服务器(未更新的客户端的模型是上一轮的全局模型)。

IV. 代码实现

1. 初始化

代码语言:javascript
复制
class FedAvg:
    def __init__(self, options):
        self.C = options['C']
        self.E = options['E']
        self.B = options['B']
        self.K = options['K']
        self.r = options['r']
        self.input_dim = options['input_dim']
        self.type = options['type']
        self.lr = options['lr']
        self.clients = options['clients']
        self.nn = ANN(input_dim=self.input_dim, name='server', B=B, E=E, type=self.type, lr=self.lr).to(device)
        self.nns = []
        for i in range(K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.clients[i]
            self.nns.append(temp)

参数:

  • K,客户端数量,本文为10个,也就是10个地区。
  • C:选择率,每一轮通信时都只是选择C * K个客户端。
  • E:客户端更新本地模型的参数时,在本地数据集上训练E轮。
  • B:客户端更新本地模型的参数时,本地数据集的batch_size=B。
  • r:服务器端和客户端一共进行r轮通信。
  • clients:客户端集合。
  • type:指定数据类型,负荷预测or风功率预测。
  • lr:学习率。
  • input_dim:数据输入维度,负荷预测为30,风功率预测为28。
  • nn:全局模型。
  • nns: 客户端模型集合。

2. 服务器端

代码语言:javascript
复制
def server(self):
     for t in range(self.r):
          print('第', t + 1, '轮通信:')
          m = np.max([int(self.C * self.K), 1])
          # sampling
          index = random.sample(range(0, self.K), m)
          # local updating
          self.client_update(index)
          # aggregation
          self.aggregation()
          # dispatch
          self.dispatch()

     # return global model
     return self.nn

其中client_update(index)表示对选中的客户端进行更新,index为被选中的客户端的序号集合:

代码语言:javascript
复制
def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

服务器聚合客户端模型:

代码语言:javascript
复制
def aggregation(self):
     s = 0
     for j in range(self.K):
          # normal
          s += self.nns[j].len
          
     params = {}
     with torch.no_grad():
          for k, v in self.nns[0].named_parameters():
               params[k] = copy.deepcopy(v)
               params[k].zero_()
     for j in range(self.K):
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    params[k] += v * (self.nns[j].len / s)
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               v.copy_(params[k])

简单来说就是根据客户端样本个数来决定其在最终聚合时所占的权重。聚合公式:

其中

表示第

个客户端的本地数据量。

当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。论文Electricity Consumer Characteristics Identification: A Federated Learning Approach(https://ieeexplore.ieee.org/document/9380668)中总结了三种汇总方式:

  • normal:原始论文中的方式,即根据样本数量来决定客户端参数在最终组合时所占比例。
  • LA:根据客户端模型的损失占所有客户端损失和的比重来决定最终组合时参数所占比例。
  • LS:根据损失与样本数量的乘积所占的比重来决定。

值得注意的是,虽然服务器端每次只是选择K个客户端中的m个来进行更新,但在最终汇总的却是所有客户端模型参数。GitHub上某些FedAvg的代码实现中只对被选中的模型进行了聚合,不过本文还是决定以原始论文中的算法框架为准,对所有客户端进行聚合。

聚合结束后服务器端向客户端发送更新后的模型:

代码语言:javascript
复制
def dispatch(self):
     params = {}
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               params[k] = copy.deepcopy(v)
     for j in range(self.K):
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    v.copy_(params[k])

3. 客户端

客户端更新模型代码:

代码语言:javascript
复制
def train(ann):
    ann.train()
    if ann.type == 'load':
        Dtr, Dte = nn_seq(ann.name, ann.B, ann.type)
    else:
        Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type)
    ann.len = len(Dtr)
    # print(len(Dtr))
    loss_function = nn.MSELoss().to(device)
    loss = 0
    optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr)
    for epoch in range(ann.E):
        cnt = 0
        for (seq, label) in Dtr:
            cnt += 1
            seq = seq.to(device)
            label = label.to(device)
            y_pred = ann(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('epoch', epoch, ':', loss.item())

    return ann

4. 测试

利用最终的全局模型在各个客户端的测试集上进行测试:

代码语言:javascript
复制
def global_test(self):
     model = self.nn
     model.eval()
     c = clients if self.type == 'load' else clients_wind
     for client in c:
          model.name = client
          test(model)

V. 实验及结果

本次实验的参数为:

K

C

E

B

r

type

lr

10

0.5

50

50

5

load

0.08

代码语言:javascript
复制
if __name__ == '__main__':
    K, C, E, B, r = 10, 0.5, 50, 50, 5
    type = 'load'
    input_dim = 30 if type == 'load' else 28
    _client = clients if type == 'load' else clients_wind
    lr = 0.08
    options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client,
               'input_dim': input_dim, 'lr': lr}
    fedavg = FedAvg(options)
    fedavg.server()
    fedavg.global_test()

实验结果(MAPE / %):

编号

1

2

3

4

5

6

7

8

9

10

本地

5.33

4.11

3.03

4.20

3.02

2.70

2.94

2.99

2.30

4.10

numpy

6.58

4.19

3.17

5.13

3.58

4.69

4.71

3.75

2.94

4.77

PyTorch

6.84

4.54

3.56

5.11

3.75

4.47

4.30

3.90

3.15

4.58

其中本地表示十个客户端仅利用本地数据进行训练得到的预测结果,numpy和PyTorch分别表示利用numpy和PyTorch实现FedAvg后全局模型在各个客户端上的预测结果。

可以发现:

  • 由于本地数据量充足,因此本地模型表现最好。
  • numpy和PyTorch预测精度相差不大,并且很接近本地模型,这说明十个地区上数据分布类似。
  • 虽然numpy和PyTorch预测精度相差不大,但采用PyTorch实现FedAvg更简单,建议采用后者。
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-02-16,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 KI的算法杂记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 特征构造
  • 1. 整体框架
  • 客户端模型采用PyTorch搭建:
  • 2. 服务器端
  • 3. 客户端
  • 1. 初始化
  • 2. 服务器端
  • 3. 客户端
  • 4. 测试
相关产品与服务
云服务器
云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档