首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >多任务学习模型MMoE详解 Multi-gate Mixture-of-Experts 与代码实现

多任务学习模型MMoE详解 Multi-gate Mixture-of-Experts 与代码实现

原创
作者头像
大鹅
发布2021-06-09 15:39:05
7.1K0
发布2021-06-09 15:39:05
举报

背景

在线上推荐预测任务时往往需要预测用户的多个行为,如关注、点赞、停留时间等,从而调整策略进行权衡。其中涉及到多任务学习,本篇将会大概整理一些常用的模型如MMoE, ESMM, SNR方便理解与学习。

MMoE

背景与动机

在工业界基于神经网络的多任务学习在推荐等场景业务应用广泛,比如在推荐系统中对用户推荐物品时,不仅要推荐用户感兴趣的物品,还要尽可能地促进转化和购买,因此要对用户评分和购买两种目标同时建模。阿里之前提出的ESMM模型属于同时对点击率和转换率进行建模,提出的模型是典型的shared-bottom结构。多任务学习中有个问题就是如果子任务差异很大,往往导致多任务模型效果不佳。今天要介绍的这篇文章是谷歌的一个内容推荐团队考虑了多任务之间的区别提出了MMoE模型,并取得了不错的效果。

多任务模型通过学习不同任务的联系和差异,可提高每个任务的学习效率和质量。多任务学习的的框架广泛采用shared-bottom的结构,不同任务间共用底部的隐层。这种结构本质上可以减少过拟合的风险,但是效果上可能受到任务差异和数据分布带来的影响。也有一些其他结构,比如两个任务的参数不共用,但是通过对不同任务的参数增加L2范数的限制;也有一些对每个任务分别学习一套隐层然后学习所有隐层的组合。和shared-bottom结构相比,这些模型对增加了针对任务的特定参数,在任务差异会影响公共参数的情况下对最终效果有提升。缺点就是模型增加了参数量所以需要更大的数据量来训练模型,而且模型更复杂并不利于在真实生产环境中实际部署使用。

因此,论文中提出了一个Multi-gate Mixture-of-Experts(MMoE)的多任务学习结构。MMoE模型刻画了任务相关性,基于共享表示来学习特定任务的函数,避免了明显增加参数的缺点。

模型介绍

MMoE模型的结构(下图c)基于广泛使用的Shared-Bottom结构(下图a)和MoE结构,其中图(b)是图(c)的一种特殊情况,下面依次介绍。

image.png
image.png

Shared-Bottom Multi-task Model

如上图a所示,shared-bottom网络(表示为函数f)位于底部,多个任务共用这一层。往上,K个子任务分别对应一个tower network(表示为h^k),每个子任务的输出y_k=h^k(f(x))

Mixture-of-Experts(MoE)

MoE模型可以形式化表示为y=\sum^n_{i=1}g_i(x)f_i(x) , 其中\sum_{i=1}^ng_i(x)=1, 且f_i,i=1,...,n是n个expert network(expert network可认为是一个神经网络)。

g是组合experts结果的gating network,具体来说g产生n个experts上的概率分布,最终的输出是所有experts的带权加和。显然,MoE可看做基于多个独立模型的集成方法。这里注意MoE并不对应上图中的b部分。

后面有些文章将MoE作为一个基本的组成单元,将多个MoE结构堆叠在一个大网络中。比如一个MoE层可以接受上一层MoE层的输出作为输入,其输出作为下一层的输入使用。

Multi-gate Mixture-of-Experts(MMoE)

MMoE目的就是相对于shared-bottom结构不明显增加模型参数的要求下捕捉任务的不同。其核心思想是将shared-bottom网络中的函数f替换成MoE层,如上图c所示,形式化表达为:

y_k=h^k(f^k(x)),f^k(x)=\sum^n_{i=1}g^k(x)_if_i(x)

其中g^k(x)=softmax(W_{gk}x) ,输入就是input feature,输出是所有experts上的权重。

一方面,因为gating networks通常是轻量级的,而且expert networks是所有任务共用,所以相对于论文中提到的一些baseline方法在计算量和参数量上具有优势。

另一方面,相对于所有任务公共一个门控网络(One-gate MoE model,如上图b),这里MMoE(上图c)中每个任务使用单独的gating networks。每个任务的gating networks通过最终输出权重不同实现对experts的选择性利用。不同任务的gating networks可以学习到不同的组合experts的模式,因此模型考虑到了捕捉到任务的相关性和区别。

模型训练

模型的可训练性,就是模型对于超参数和初始化是否足够鲁棒。作者在人工合成数据集上进行了实验,观察不同随机种子和模型初始化方法对loss的影响。这里简单介绍下两个现象:

第一,Shared-Bottom models的效果方差要明显大于基于MoE的方法,说明Shared-Bottom模型有很多偏差的局部最小点;

第二,如果任务相关度非常高,则OMoE和MMoE的效果近似,但是如果任务相关度很低,则OMoE的效果相对于MMoE明显下降,说明MMoE中的multi-gate的结构对于任务差异带来的冲突有一定的缓解作用。

整体来看,这篇文章是对多任务学习的一个扩展,通过门控网络的机制来平衡多任务的做法在真实业务场景中具有借鉴意义。

模型总结与应用实践

MoE与MMoE两者的共同点都是把原先Hard-parameter sharing中底层全连接层网络划分成了多个子网络Expert,这样的做法更多是模仿了集成学习中的思想,即同等规模下单个网络无法有效学习到所有任务之间通用的表达但通过划分得到多个子网络后每个子网络总能学到某个任务中一些相关独特的表达,再通过Gate的输出(Softmax)加权各个Expert输出,送入各自多层全连接就能将特定任务学习地较好。

MoE只有一个Gate输出,而MMoE是有多个输出。所以不同点在于MMoE针对不同任务均设置了一个对应的Gate,这样的好处是在不添加大量的新参数的情况下学习任务特定的函数去平衡共享的表达来对任务之间的关系进行更明确地建模

在将MMoE应用在Youtube的论文可以得到不同的Expert在不同任务中的重要性不同(可通过看各个Gate的输出来判断每个任务对应哪些Expert比较重要),因此如果想要某个Expert与某个任务之间的相关性越高,可以在输入Gate之前加入一些预设好的任务和Expert权值关系,或者直接自定义Softmax函数,让占比大的Expert输出更大。该论文中提出了wide&deep的框架并有效结合了MMoE的优势,wide部分引入一个浅层网络来缓和选择偏见问题,这被证明是一个很有效的解决方案。

应用实践方面,知乎在2019年利用MMoE替换了Hard-parameter sharing并取得了用户互动率提升100%的巨大成绩,而互动率直接影响的就是用户体验。

知乎后期的努力方向也主要是使用各种策略优化方法来最大化模型的价值,也就是更好地改善用户的体验。一个好的多任务学习方法应该存在一种最合理的方式去对目标进行权衡和融合,才能得到用户和平台收益的最大化。这就是知乎正在尝试的方法:即对用户进行分群,利用用户对不同内容的不同层次的满意度来动态地调整每个目标的权重,最终融合输出,给出一个最终的排序。

代码实现

import tensorflow as tf

from deepctr.feature_column import build_input_features, input_from_feature_columns
from deepctr.layers.utils import combined_dnn_input
from deepctr.layers.core import PredictionLayer, DNN

from tensorflow.python.keras.initializers import glorot_normal
from tensorflow.python.keras.layers import Layer


class MMOELayer(Layer):

    def __init__(self, num_tasks, num_experts, output_dim, seed=1024, **kwargs):
        self.num_experts = num_experts
        self.num_tasks = num_tasks
        self.output_dim = output_dim
        self.seed = seed
        super(MMOELayer, self).__init__(**kwargs)

    def build(self, input_shape):
        input_dim = int(input_shape[-1])
        self.expert_kernel = self.add_weight(
            name='expert_kernel',
            shape=(input_dim, self.num_experts * self.output_dim),
            dtype=tf.float32,
            initializer=glorot_normal(seed=self.seed))
        self.gate_kernels = []
        for i in range(self.num_tasks):
            self.gate_kernels.append(self.add_weight(
                name='gate_weight_'.format(i),
                shape=(input_dim, self.num_experts),
                dtype=tf.float32,
                initializer=glorot_normal(seed=self.seed)))
        super(MMOELayer, self).build(input_shape)

    def call(self, inputs, **kwargs):
        outputs = []
        expert_out = tf.tensordot(inputs, self.expert_kernel, axes=(-1, 0))
        expert_out = tf.reshape(expert_out, [-1, self.output_dim, self.num_experts])
        for i in range(self.num_tasks):
            gate_out = tf.tensordot(inputs, self.gate_kernels[i], axes=(-1, 0))
            gate_out = tf.nn.softmax(gate_out)
            gate_out = tf.tile(tf.expand_dims(gate_out, axis=1), [1, self.output_dim, 1])
            output = tf.reduce_sum(tf.multiply(expert_out, gate_out), axis=2)
            outputs.append(output)
        return outputs

    def get_config(self):

        config = {'num_tasks': self.num_tasks,
                  'num_experts': self.num_experts,
                  'output_dim': self.output_dim}
        base_config = super(MMOELayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return [input_shape[0], self.output_dim] * self.num_tasks


def MMOE(dnn_feature_columns, num_tasks, tasks, num_experts=4, expert_dim=8, dnn_hidden_units=(128, 128),
         l2_reg_embedding=1e-5, l2_reg_dnn=0, task_dnn_units=None, seed=1024, dnn_dropout=0, dnn_activation='relu'):
    """Instantiates the Multi-gate Mixture-of-Experts architecture.

    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
    :param num_tasks: integer, number of tasks, equal to number of outputs, must be greater than 1.
    :param tasks: list of str, indicating the loss of each tasks, ``"binary"`` for  binary logloss, ``"regression"`` for regression loss. e.g. ['binary', 'regression']
    :param num_experts: integer, number of experts.
    :param expert_dim: integer, the hidden units of each expert.
    :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of shared-bottom DNN
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
    :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
    :param task_dnn_units: list,list of positive integer or empty list, the layer number and units in each layer of task-specific DNN
    :param seed: integer ,to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param dnn_activation: Activation function to use in DNN

    :return: a Keras model instance
    """
    if num_tasks <= 1:
        raise ValueError("num_tasks must be greater than 1")
    if len(tasks) != num_tasks:
        raise ValueError("num_tasks must be equal to the length of tasks")
    for task in tasks:
        if task not in ['binary', 'regression']:
            raise ValueError("task must be binary or regression, {} is illegal".format(task))

    features = build_input_features(dnn_feature_columns)

    inputs_list = list(features.values())

    sparse_embedding_list, dense_value_list = input_from_feature_columns(features, dnn_feature_columns,
                                                                         l2_reg_embedding, seed)
    dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
    dnn_out = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout,
                  False, seed=seed)(dnn_input)
    mmoe_outs = MMOELayer(num_tasks, num_experts, expert_dim)(dnn_out)
    if task_dnn_units != None:
        mmoe_outs = [DNN(task_dnn_units, dnn_activation, l2_reg_dnn, dnn_dropout, False, seed)(mmoe_out) for mmoe_out in
                     mmoe_outs]

    task_outputs = []
    for mmoe_out, task in zip(mmoe_outs, tasks):
        logit = tf.keras.layers.Dense(
            1, use_bias=False, activation=None)(mmoe_out)
        output = PredictionLayer(task)(logit)
        task_outputs.append(output)

    model = tf.keras.models.Model(inputs=inputs_list,
                                  outputs=task_outputs)
    return model

SNR

经典的Shared-Bottom网络结构存在一个明显的问题:当共同训练学习的多个任务之间联系不强的时,会严重损害各自任务的效果。因为相对于多个目标各自训练独立模型而言,Shared-Bottom的网络结构会在共享的网络底层引入了Bias。

上面提到的MMoE模型存在的一个问题,它只能够针对共享的experts子网络进行有限的组合。因此,在MMoE模型结构的基础上,SNR模型来实现更灵活的网络参数共享。与MMoE类似,SNR模型将共享的底层网络模块化为子网络。不同的是,SNR模型使用编码变量来控制子网络之间的连接,并且设计了两种类型的连接方式:SNR-Trans和SNR-Aver。

实验效果如下:

ESMM模型

ESMM背景

任务序列依赖关系建模方法中极具有代表性的是阿里妈妈在2018年提出来的Entire Space Multi-Task Model(ESMM)。

基于 Multi-Task Learning 的思路,它有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。CVR预估和CTR任务相比,有两个不同:

(1)Sample Selection Bias转化是在点击之后才“有可能”发生的动作,传统CVR模型通常以点击数据为训练集,其中点击未转化为负例,点击并转化为正例。但是训练好的模型实际使用时,则是对整个空间的样本进行预估,而非只对点击样本进行预估。

(2)Data Sparsity作为CVR训练数据的点击样本远小于CTR预估训练使用的曝光样本。

一些策略可以缓解这两个问题,但都没有从实质上解决上面任一个问题。

在推荐系统中,不同任务之间通常存在一种序列依赖关系。在电商中的多目标预估一般是点击率和转化率,其中购买这个行为只有在点击发生后才会发生。因此这是一种序列依赖关系,可以被利用来解决一些任务预估中存在的样本选择偏差SSB训练数据稀疏DS问题。

第一个问题主要是训练转化率预估模型时采用的是点击+转化数据,而用户登录后直接看到的并不是具体的商品详情页,而是首页或者列表页,因此转化率预估模型需要在产品曝光的场景下进行预估,这就导致了训练场景与预估场景(全样本空间)不一致的问题,不同场景肯定会产生有偏的预估结果,进而导致应用效果的损失。

第二个问题是指转化率预估模型的训练样本通常远小于点击率预估模型,如下图所示。

ESMM模型简述

ESMM的模型结构如下:

从模型结构上看,底层的嵌入层是转化率部分和点击率部分共享的,共享的目的主要是解决转化率预估任务正样本稀疏的问题,利用点击率的数据生成更准确的用户和物品的特征表达。

中间层是转化率和点击率部分各自利用完全隔离的神经网络拟合自己的优化目标,最终将两者相乘得到pCTCVR。

上式可转变为:

因此可以通过分别估计pCTCVR和pCTR,然后通过两者相除来解决。同时两者均可在全样本空间进行训练和预估。但预估时会出现前者大于后者的情况,导致pCVR预估值大于1。为了解决这个问题,引入了pCTCVR和pCTR两个辅助任务,并巧妙地将除法改为乘法,训练时,Loss为两者相加。

使用交叉熵损失函数,在CTR任务中,有点击行为的曝光事件标记为正样本,没有点击行为发生的曝光事件标记为负样本;在CTCVR任务中,同时有点击和购买行为的曝光事件标记为正样本,否则标记为负样本。

这样将三个目标同时融合进一个统一的模型,可以一次性得出所有三个优化目标的值,解决了“训练空间和预测空间不一致”和“同时利用点击和转化数据进行全局优化”两个关键问题。因此ESMM的思想是一种较为通用的序列依赖关系建模思路。

ESMM总结

综上所述,ESMM模型是一个新颖的CVR预估方法,其首创了利用用户行为的序列特性在完整样本空间建模,避免了传统CVR模型经常遭遇的样本选择偏差和训练数据稀疏的问题,取得了显著的效果。同时,ESMM模型中的子网络可以替换为任意的学习模型,因此ESMM的框架可以非常容易地和其他学习模型集成,从而吸收其他学习模型的优势,进一步提升学习效果。此外,ESMM建模的思想也比较容易被泛化到电商中多阶段行为的全链路预估场景,如 排序→展现→点击→转化 的行为链路预估,想象空间巨大。

ESM2

ESM2是阿里最新提出的序列依赖关系建模框架,是ESMM的改进版。由于ESMM在CVR预估场景中仍然面临一定的样本稀疏问题,因为点击到购买的样本相对于点击的样本非常少。但幸运的是我们可以利用一个用户在购买某个商品之前产生的一系列其它行为信息,比如将商品加入购物车或心愿单,如图。

那么可以把加入购物车或心愿单的行为定义为Deterministic Action (DAction),表示购买目的很明确的一类行为。而其他对购买相关性不是很大的行为称作 Other Action (OAction)。那原来的Impression→Click→Buy 购物过程就变为:Impression→Click→DAction/OAction→Buy过程。

ESM2的模型结构如下:

通过引入一些中间行为即可将两个任务拆分成更多子任务:点击率、点击到DAction的概率、DAction到购买的概率和OAction到购买的概率。这样做的好处就是更好地缓解样本稀疏问题。

从模型结构中可以看到,总共有四个子任务,但是损失函数一共有三个,分别是:Impression→Click、Impression→Click→DAction和Impression→Click→DAction/OAction→Buy

由二分类交叉熵损失函数形式可得第一个损失函数表达式为:

第二个损失函数表达式为:

第三个损失函数表达式为:

在论文中,总的损失函数由以上三个损失函数加权相加得到,文中的权值均为1,在实际业务场景中也可以根据经验进行动态调整。

在CVR和CTCVR的数据集中测试,该模型显示比当前的SOTA模型在各个指标上效果都更优,且更有效地解决了SSB和DS问题。

PLE Progressive Layered Extraction

腾讯PCG在RecSys2020发表的最佳长论文PLE(Progressive Layered Extraction),是在视频推荐场景下多任务模型。相对于前面的MMOE、SNR和ESMM模型,PLE模型主要解决两个问题:

(1)MMOE中所有的Expert是被所有任务所共享的,这可能无法捕捉到任务之间更复杂的关系,从而给部分任务带来一定的噪声;

(2)不同的Expert之间没有交互,联合优化的效果有所折扣。

从图中的网络结构可以看出,CGC的底层网络主要包括shared experts和task-specific expert构成,每一个expert module都由多个子网络组成,子网络的个数和网络结构都是超参数。上层由多任务网络构成,每一个多任务网络(towerA和towerB)的输入都是由gating网络进行加权控制,每一个子任务的gating网络的输入包括两部分,其中一部分是本任务下的task-specific部分的experts和shared部分的experts组成。

上面看到了CGC网络是一种single-level的网络结构,一个比较直观的思路就是叠加多层CGC网络,从而获得更加丰富的表征能力,而PLE网络结构就是将CGC拓展到了multi-level层中。

可以看出MOE不同experts权重基本相差不大,PLE模型共享experts和独有experts的权重相差更大,说明针对不同的任务,能够有效利用共享Expert和独有Expert的信息,这也解释了为什么其能够达到比MMoE更好的训练结果。

总结

现实世界中很多问题不能分解为一个一个独立的子问题,即使可以分解,各个子问题之间也是相互关联的,通过一些共享因素或共享表示联系在一起。把现实问题当做一个个独立的单任务处理,忽略了问题之间所富含的丰富的关联信息。多任务学习就是为了解决这个问题而诞生的。

多任务学习本质是一种归纳迁移机制,利用额外的信息来源来提高当前任务的学习性能,包括提高泛化准确率、学习速率和已学习模型的可理解性。多任务学习的不同任务在共享层里的局部极小值位置是不同的,通过多任务之间不相关的部分的相互作用,有助于逃离局部极小值点;而多任务之前相关的部分则有利于底部共享层对通用特征表示的学习,因此通常多任务能够取得比单任务模型更好的效果。多任务学习未来的发展可能会出现更多新的思路或者是现有的思路相结合,后者在业内已有一些研究,比如阿里发表在KDD2020的M2GRL:引入多视图后的GRL与MTL框架结合来更好地进行推荐。另外多任务中的损失函数设计也是一个很重要的研究方向,将另外展开叙述。

因此多任务学习能提高泛化能力的可能原因主要有:

第一,不相关任务对于聚合梯度的贡献相对于其他任务来说可以视为噪声,不相关任务也可以通过作为噪声源来提高泛化能力。

第二,增加任务会影响网络参数的更新,比如增加了隐层有效的学习率。

第三,多任务网络在所有任务之间共享网络底部的隐层,或许更小的容量就可以获得同水平或更好的泛化能力。

Ref

  1. paper https://www.kdd.org/kdd2018/accepted-papers/view/modeling-task-relationships-in-multi-task-learning-with-multi-gate-mixture-
  2. https://blog.csdn.net/ty44111144ty/article/details/99068255
  3. https://zhuanlan.zhihu.com/p/55752344
  4. ESMM模型 https://blog.csdn.net/sinat_15443203/article/details/83713802#
  5. https://blog.csdn.net/sinat_15443203/article/details/83713802#:~:text=%E6%8F%90%E5%87%BA%E7%9A%84ESMM%EF%BC%88%E5%AE%8C%E6%95%B4%E7%A9%BA%E9%97%B4%E5%A4%9A%E4%BB%BB%E5%8A%A1%EF%BC%89%E6%A8%A1%E5%9E%8B%E8%83%BD%E5%A4%9F%E5%9C%A8%E5%AE%8C%E6%95%B4%E7%9A%84%E6%A0%B7%E6%9C%AC%E6%95%B0%E6%8D%AE%E7%A9%BA%E9%97%B4%EF%BC%88%E5%8D%B3%E6%9B%9D%E5%85%89%E7%9A%84%E6%A0%B7%E6%9C%AC%E7%A9%BA%E9%97%B4%EF%BC%8C%E4%B8%8B%E5%9B%BE%E6%9C%80%E5%A4%96%E5%B1%82%E5%9C%88%EF%BC%89%E5%90%8C%E6%97%B6%E5%AD%A6%E4%B9%A0%E7%82%B9%E5%87%BB%E7%8E%87%28post-view%20click-through%20rate%2C%20CTR%29%E5%92%8C%E8%BD%AC%E5%8C%96%E7%8E%87%28post-click%20conversion%20rate%2C,CVR%29%E3%80%82%20%E7%94%A8%E6%88%B7%E5%9C%A8%E8%B4%AD%E7%89%A9%E6%97%B6%E9%83%BD%E9%81%B5%E5%BE%AA%E4%B8%80%E4%B8%AA%E9%A1%BA%E5%BA%8F%EF%BC%9Aimpression%20%E2%86%92%20click%20%E2%86%92%20conversion%E3%80%82
  6. 多任务学习汇总 https://blog.csdn.net/m0_52122378/article/details/113593246
  7. 多任务学习应用 https://cloud.tencent.com/developer/article/1652175
  8. Zhe Zhao, Lichan Hong, et al. Recommending What Video to Watch: A Multi-task Ranking System.
  9. 进击的推荐系统:多目标学习如何让知乎用户互动率提升100%?

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 背景
  • MMoE
    • 背景与动机
      • 模型介绍
        • Shared-Bottom Multi-task Model
        • Mixture-of-Experts(MoE)
        • Multi-gate Mixture-of-Experts(MMoE)
      • 模型训练
        • 模型总结与应用实践
          • 代码实现
          • SNR
          • ESMM模型
            • ESMM背景
              • ESMM模型简述
                • ESMM总结
                  • Ref
              • ESM2
              • PLE Progressive Layered Extraction
              • 总结
              相关产品与服务
              云服务器
              云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档