
[Hands-on Implementation] MetaBalance: Mitigating Gradient Imbalance in Multi-Task Learning

秋枫学习笔记 · Published 2022-09-19

Follow us and learn together~

1. Introduction

In this post we reproduce MetaBalance, a method proposed by Meta AI (Facebook) for balancing the gradients of the different tasks in a multi-task model. In multi-task learning, the losses of different tasks can produce gradients of very different magnitudes during backpropagation; MetaBalance rescales these gradients so that their magnitudes are similar across tasks. The implementation follows below. Note that the gradient scaling applies only to the shared layers; the task-specific tower parameters are left untouched.
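Concretely, for each shared parameter MetaBalance keeps a moving average $m_i$ of the $i$-th task's gradient norm and rescales that task's gradient $g_i$ toward the main task's magnitude. This is exactly the update the code in Section 2 implements, with relax factor $r$, moving-average coefficient $\beta$, and $\epsilon = 10^{-5}$ for numerical stability:

$$m_i \leftarrow \beta\, m_i + (1-\beta)\,\lVert g_i \rVert_2, \qquad g_i \leftarrow r \cdot \frac{m_0}{m_i + \epsilon}\, g_i + (1-r) \cdot g_i$$

Index $0$ denotes the main task, so the main-task gradient itself is left (almost) unchanged, while auxiliary gradients are pulled toward its norm.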

We have previously covered the MetaBalance paper in detail; see:

WWW'22 (Meta) MetaBalance: Dynamically Adjusting Auxiliary-Task Gradients to Improve Multi-Task Recommendation

2. Code

This section implements MetaBalance. For details on how to use it, see the repository:

https://github.com/dqdallen/Torch-RecHub

"""The metaoptimizer module, it provides a class MetaBalance 
MetaBalance is used to scale the gradient and balance the gradient of each task
Authors: Qida Dong, dongjidan@126.com
"""
import torch
from torch.optim.optimizer import Optimizer
import numpy as np
import time


# set random seeds of torch and numpy
# torch.manual_seed(0)
# np.random.seed(0)

# with fixed random seeds and the following settings,
# the output is reproducible for a fixed input
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False


class MetaBalance(Optimizer):
  """MetaBalance Optimizer
  Scales each task's gradient on the shared parameters so that the
  auxiliary-task gradients match the magnitude of the main-task gradient

  Args:
    parameters (list): the shared parameters of the model
    relax_factor (float, optional): the relax factor for gradient scaling (default: 0.7)
    beta (float, optional): the moving-average coefficient (default: 0.9)

  """
  def __init__(self, parameters, relax_factor=0.7, beta=0.9):
    if relax_factor < 0. or relax_factor >= 1.:
      raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
    if beta < 0. or beta >= 1.:
      raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
    rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
    super(MetaBalance, self).__init__(parameters, rel_beta_dict)

  @torch.no_grad()
  def step(self, losses):
    """
    Args:
      losses (list[loss_func]): it contains some losses from each auxiliary task and main task
                    the first one is main task
    """
    for idx, loss in enumerate(losses):
      loss.backward(retain_graph=True)
      for group in self.param_groups:
        for gp in group['params']:
          if gp.grad is None:
            # this parameter received no gradient from the current loss;
            # skip it rather than dropping the remaining parameters
            continue
          if gp.grad.is_sparse:
            raise RuntimeError('MetaBalance does not support sparse gradients')
          # per-parameter state: one moving-average norm per task
          state = self.state[gp]
          if len(state) == 0:
            gp.norms = [0] * len(losses)
          # update the moving average of this task's gradient norm
          beta = group['beta']
          gp.norms[idx] = gp.norms[idx] * beta + (1 - beta) * torch.norm(gp.grad)
          # rescale the gradient toward the main task's magnitude,
          # softened by the relax factor
          relax_factor = group['relax_factor']
          gp.grad = gp.grad * gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
          # accumulate the balanced gradients over all tasks
          if idx == 0:
            state['sum_gradient'] = torch.zeros_like(gp.data)
          state['sum_gradient'] += gp.grad

          # clear this task's gradient so the next backward pass
          # does not accumulate into it
          gp.grad.detach_()
          gp.grad.zero_()
          # after the last loss, expose the accumulated balanced gradient
          if idx == len(losses) - 1:
            gp.grad = state['sum_gradient']
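
Below is a minimal usage sketch. The model, targets, data loader, and the 'tower' naming convention (`MyMultiTaskModel`, `loader`) are hypothetical placeholders for illustration, not Torch-RecHub's actual API: MetaBalance is constructed over the shared parameters only, while a standard optimizer performs the actual update.

import torch

# hypothetical two-task model with a shared bottom and per-task towers
model = MyMultiTaskModel()
# MetaBalance should only see the shared parameters
shared_params = [p for name, p in model.named_parameters() if 'tower' not in name]
meta_opt = MetaBalance(shared_params, relax_factor=0.7, beta=0.9)
# a standard optimizer performs the actual update on all parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, y_main, y_aux in loader:  # hypothetical data loader
  optimizer.zero_grad()
  y_pred_main, y_pred_aux = model(x)
  # the main task loss must come first
  losses = [torch.nn.functional.binary_cross_entropy(y_pred_main, y_main),
            torch.nn.functional.binary_cross_entropy(y_pred_aux, y_aux)]
  meta_opt.step(losses)  # backpropagates every loss, writes balanced grads
  optimizer.step()       # one update using the balanced gradients

After `meta_opt.step(losses)`, the shared parameters hold the balanced, accumulated gradient, while each tower's parameters hold the ordinary gradient of its own task loss, so a single `optimizer.step()` updates everything consistently.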

3. Results

We ran experiments on a sampled AliExpress-US dataset with and without MetaBalance, and MetaBalance does prove effective; interested readers are encouraged to try it themselves. We also observed that models with an explicit shared layer benefit more noticeably from MetaBalance than MMOE does.

Dataset: https://cowtransfer.com/s/fe32c8d1043b4d
