In this post we reproduce MetaBalance, a method proposed by Meta AI (Facebook) for balancing the gradients of different tasks in a multi-task model. In multi-task learning, the losses of different tasks can produce gradients of very different magnitudes during backpropagation; MetaBalance rescales each task's gradient so that the magnitudes stay comparable. The reproduced code is given below. Note that the gradient scaling is applied only to the shared layers; the task-specific tower parameters are left untouched.
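For reference, the core update rule (exactly what the code below computes for each shared parameter) keeps an exponential moving average $m_i$ of each task's gradient norm and rescales task $i$'s gradient toward the magnitude of the main-task gradient $g_0$:

$$m_i \leftarrow \beta\, m_i + (1 - \beta)\,\lVert g_i \rVert, \qquad g_i \leftarrow r\, g_i\, \frac{m_0}{m_i + \epsilon} + (1 - r)\, g_i$$

where $r$ is the relax factor (0.7 by default) and $\beta$ is the moving-average coefficient (0.9 by default). With $r = 1$ every task's gradient is pulled fully to the main task's magnitude, while $r = 0$ leaves the gradients unchanged.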
We have covered the MetaBalance paper in an earlier post:
WWW'22「Meta」MetaBalance: Dynamically adjusting auxiliary-task gradients to improve multi-task recommendation performance
The implementation of MetaBalance follows; for details on how to use it, see the repository:
https://github.com/dqdallen/Torch-RecHub
"""The metaoptimizer module, it provides a class MetaBalance
MetaBalance is used to scale the gradient and balance the gradient of each task
Authors: Qida Dong, dongjidan@126.com
"""
import torch
from torch.optim.optimizer import Optimizer
import numpy as np

# Optional: set the random seeds of torch and numpy.
# With fixed seeds and the following cuDNN settings,
# the output is the same whenever the input is fixed.
# torch.manual_seed(0)
# np.random.seed(0)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
class MetaBalance(Optimizer):
    """MetaBalance Optimizer.

    This optimizer scales the gradients so that the gradient magnitudes
    of all tasks are balanced.

    Args:
        parameters (list): the parameters of the model
        relax_factor (float, optional): the relax factor of the gradient scaling (default: 0.7)
        beta (float, optional): the coefficient of the moving average (default: 0.9)
    """

    def __init__(self, parameters, relax_factor=0.7, beta=0.9):
        if relax_factor < 0. or relax_factor >= 1.:
            raise ValueError(f'Invalid relax_factor: {relax_factor}, it should satisfy 0. <= relax_factor < 1.')
        if beta < 0. or beta >= 1.:
            raise ValueError(f'Invalid beta: {beta}, it should satisfy 0. <= beta < 1.')
        rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
        super(MetaBalance, self).__init__(parameters, rel_beta_dict)
    @torch.no_grad()
    def step(self, losses):
        """Balance the gradients of the given task losses.

        Args:
            losses (list): the losses of all tasks; the first element must be
                the loss of the main task, the rest come from auxiliary tasks.
        """
        for idx, loss in enumerate(losses):
            # backpropagate one task at a time; the graph must be retained
            # because the remaining losses still need to backpropagate
            loss.backward(retain_graph=True)
            for group in self.param_groups:
                for gp in group['params']:
                    if gp.grad is None:
                        # this parameter received no gradient from this loss
                        continue
                    if gp.grad.is_sparse:
                        raise RuntimeError('MetaBalance does not support sparse gradients')
                    # the state stores the accumulated balanced gradient
                    state = self.state[gp]
                    if len(state) == 0:
                        # one moving-average norm per task
                        gp.norms = [0] * len(losses)
                    # update the moving average of this task's gradient norm
                    beta = group['beta']
                    gp.norms[idx] = gp.norms[idx] * beta + (1 - beta) * torch.norm(gp.grad)
                    # scale the gradient toward the main task's magnitude;
                    # relax_factor controls the strength of the scaling
                    relax_factor = group['relax_factor']
                    gp.grad = gp.grad * gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor \
                        + gp.grad * (1. - relax_factor)
                    # accumulate the scaled gradient of each task
                    if idx == 0:
                        state['sum_gradient'] = torch.zeros_like(gp.data)
                    state['sum_gradient'] += gp.grad
                    # clear the gradient so the next task starts from zero
                    gp.grad.detach_()
                    gp.grad.zero_()
                    # after the last task, write the accumulated gradient back
                    if idx == len(losses) - 1:
                        gp.grad = state['sum_gradient']
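Note that `step` only rewrites the `.grad` fields of the shared parameters; the actual weight update is still performed by a regular optimizer. Below is a minimal usage sketch with a toy shared-bottom model. The model, data, and hyperparameters are illustrative assumptions, not taken from the repository; see the Torch-RecHub link above for the actual usage.

import torch
import torch.nn as nn
import torch.nn.functional as F

shared = nn.Linear(16, 8)  # shared bottom layer
towers = nn.ModuleList([nn.Linear(8, 1) for _ in range(2)])  # per-task towers

# MetaBalance only balances the gradients of the shared parameters;
# a standard optimizer applies the actual weight update for everything.
metabalance = MetaBalance(list(shared.parameters()), relax_factor=0.7, beta=0.9)
optimizer = torch.optim.Adam(
    list(shared.parameters()) + list(towers.parameters()), lr=1e-3)

x = torch.randn(32, 16)
ys = [torch.randn(32, 1), torch.randn(32, 1)]  # dummy targets, main task first

optimizer.zero_grad()
h = shared(x)
losses = [F.mse_loss(tower(h), y) for tower, y in zip(towers, ys)]
metabalance.step(losses)  # backpropagates each loss and balances shared grads
optimizer.step()          # applies the update with the balanced gradients

Each tower still receives its own unscaled gradient (each loss only touches its own tower), so only the shared layer's gradients are balanced, matching the note above that the scaling does not affect the task-specific towers.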
We ran experiments on a subsampled AliExpress-US dataset with and without MetaBalance, and MetaBalance does bring an improvement; feel free to try it yourself. We also observed that models with an explicit shared layer benefit more noticeably from MetaBalance than MMOE does.
Dataset: https://cowtransfer.com/s/fe32c8d1043b4d