编码器和领域分类器的训练目标是相反的,我们可以使用对抗网络(Adversarial Networks)的模式来进行训练。而另一种更加简单的方法就是梯度反转了。
我们来看下图。模型的输入\boldsymbol{x} 经过编码器G_{f} 得到特征向量\boldsymbol{f} ,随后\boldsymbol{f} 被送到两个网络中:(1) 标记分类器G_{y} 和 (2) 领域分类器。标记分类器输出数据标记y ,而领域分类器则预测特征向量的来源的领域d 。
在上面,编码器G_{f} 和领域分类器G_{d} 的训练目标是对抗的,因此文章在二者之间添加了一个梯度反转层(gradient reversal layer, GRL)。
众所周知,反向传播是指将损失(预测值和真实值的差距)逐层向后传递,然后每层网络都会根据传回来的误差计算梯度,进而更新本层网络的参数。而GRL所做的就是,就是将传到本层的误差乘以一个负数(-\lambda ),这样就会使得GRL前后的网络其训练目标相反,以实现对抗的效果。
下面是在pytorch实现的代码。
class grl_func(torch.autograd.Function):
def __init__(self):
super(grl_func, self).__init__()
@ staticmethod
def forward(ctx, x, lambda_):
ctx.save_for_backward(lambda_)
return x.view_as(x)
@ staticmethod
def backward(ctx, grad_output):
lambda_, = ctx.saved_variables
grad_input = grad_output.clone()
return - lambda_ * grad_input, None
class GRL(nn.Module):
def __init__(self, lambda_=0.):
super(GRL, self).__init__()
self.lambda_ = torch.tensor(lambda_)
def set_lambda(self, lambda_):
self.lambda_ = torch.tensor(lambda_)
def forward(self, x):
return grl_func.apply(x, self.lambda_)
需要注意的是,-\lambda 并不是一个常数,而是由0变为1,即\lambda=\frac{2}{1+\exp (-\gamma \cdot p)}-1 其中,\gamma 是一个超参数,文章中设为10;p 随着训练的进行由0变为1,表示当前的训练步数/总的训练步数。上面的式子意味着一开始时,\lambda=0 ,领域分类损失不会回传到编码器网络中,只有领域分类器得到训练;随着训练的进行,\lambda 逐渐增加,编码器得到训练,并开始逐步生成可以混淆领域分类器的特征。