前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >联邦元学习算法Per-FedAvg的PyTorch实现

联邦元学习算法Per-FedAvg的PyTorch实现

作者头像
Cyril-KI
发布2022-11-09 14:54:08
7910
发布2022-11-09 14:54:08
举报
文章被收录于专栏:KI的算法杂记KI的算法杂记

I. 前言

Per-FedAvg的原理请见:arXiv | Per-FedAvg:一种联邦元学习方法

II. 数据介绍

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

数据集为某城市十个地区的风电功率,我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

III. Per-FedAvg

Per-FedAvg算法伪代码:

1. 服务器端

服务器端和FedAvg一致,这里不再详细介绍了,可以看看前面几篇文章。

2. 客户端

对于每个客户端,我们定义它的元函数

为了在本地训练中对元函数进行更新,我们需要计算其梯度:

代码实现如下:

代码语言:javascript
复制
def train(args, model):
    model.train()
    Dtr, Dte, m, n = nn_seq(model.name, args.B)
    model.len = len(Dtr)
    print('training...')
    data = [x for x in iter(Dtr)]
    for epoch in range(args.E):
        model = one_step(args, data, model, lr=args.alpha)
        model = one_step(args, data, model, lr=args.beta)

    return model


def one_step(args, data, model, lr):
    ind = np.random.randint(0, high=len(data), size=None, dtype=int)
    seq, label = data[ind]
    seq = seq.to(args.device)
    label = label.to(args.device)
    y_pred = model(seq)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_function = nn.MSELoss().to(args.device)
    loss = loss_function(y_pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return model

3. 本地梯度下降

得到初始模型后,需要在本地进行1轮迭代更新:

代码语言:javascript
复制
def local_adaptation(args, model):
    model.train()
    Dtr, Dte = nn_seq_wind(model.name, 50)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.alpha)
    loss_function = nn.MSELoss().to(args.device)
    loss = 0
    for epoch in range(1):
        for seq, label in Dtr:
            seq, label = seq.to(args.device), label.to(args.device)
            y_pred = model(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # print('local_adaptation loss', loss.item())

    return model

IV. 完整代码

完整代码及数据:https://github.com/ki-ljl/Per-FedAvg,点击阅读原文即可跳转至代码下载界面。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-03-12,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 KI的算法杂记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 服务器端
  • 2. 客户端
  • 3. 本地梯度下降
相关产品与服务
联邦学习
联邦学习(Federated Learning,FELE)是一种打破数据孤岛、释放 AI 应用潜能的分布式机器学习技术,能够让联邦学习各参与方在不披露底层数据和底层数据加密(混淆)形态的前提下,通过交换加密的机器学习中间结果实现联合建模。该产品兼顾AI应用与隐私保护,开放合作,协同性高,充分释放大数据生产力,广泛适用于金融、消费互联网等行业的业务创新场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档