专栏首页小小挖掘机RS Meet DL(62)-[阿里]电商推荐中的特殊特征蒸馏

RS Meet DL(62)-[阿里]电商推荐中的特殊特征蒸馏

今天介绍的论文是:《Privileged Features Distillation for E-Commerce Recommendations》 论文下载地址为:https://arxiv.org/abs/1907.05171?context=cs.IR

说说题目吧,先讲讲蒸馏(Distillation)的概念,我们知道模型最终都要应用于线上,如果太过复杂的模型会导致性能无法保证,往往会应用一个比较简单的模型。但简单的模型有时难以保证预测精度,因此一种做法是训练一个复杂的模型作为老师来指导这个简单模型的训练。这种教师-学生的训练模式,便称为蒸馏。再讲讲Privileged Features,我们这里暂且翻译为特殊特征。好了,进入正文吧。

1、背景

在淘宝的推荐系统中,整个推荐流程可以分为下面的三个阶段:

首先是候选集生成阶段(candidate generation),接下来是粗排阶段(coarse-grained ranking),最后是精排阶段(fine-grained ranking)。这里跟咱们之前接触的两阶段过程不太一样,接下来分别介绍各阶段的内容。

在候选集生成阶段,通过多路召回的方式得到候选集合,召回方式可能有协同过滤、DNN模型等等。

在粗排阶段,主要的任务是预估精排阶段返回的候选集中每个物品的点击率,然后选择最高的一些物品进入精排阶段。粗排阶段输入的特征主要有用户的行为特征(用户的历史点击/购买行为,通常通过RNN或者self-attention进行处理)、用户自身特征(如用户id、性别、年龄)、物品自身特征(如物品id、类别、品牌)。在粗排阶段,考虑到性能的关系,模型的复杂度受到了很大的限制,因此通常是用下面的双塔结构:

点击率计算公式如下:

其中Xu和Xi代表用户和物品对应的向量,Xu混合了用户本身特征和用户行为序列特征。Wu和Wi代表用户和物品侧的参数,而Φ代表从输入到输出的映射关系。在线上应用阶段,可以预先把每个物品的映射计算出来,作为词表进行保存,当一个请求到来时,只需要计算用户侧的映射即可。过程如下图所示:

由于性能的限制,在粗排阶段没有考虑用户-物品的一些交互特征,如用户过去24小时在同类别下物品的点击行为、用户在过去24小时在物品所在店铺内的点击行为。加入这些特征,如果放到用户侧,那么针对每个物品都需要计算一次用户侧的映射,如果放到物品侧,同样针对每个物品都需要计算一次物品侧的映射,这会大大加大计算复杂度。因此,这些交互特征对于粗排阶段的模型来说,通常在线上无法应用,我们就称为Privileged Features。

最后讲一下精排阶段,这一阶段我们不仅要预估CTR、还要预估CVR,因为电商领域的推荐的目标一般是提高GMV(CTR * CVR * Price,商品的Price是确定的,无需预估)。CVR的定义是用户从点击到购买的概率。那么对于用户购买来说,用户在商品详情页面停留的时间、对于评论的查看与否、是否会与商家进行交流会是一些比较有用的强特征。但是,这些特征在线上预估阶段是无法获取的,我们需要在给用户展示物品的时候就来预估CVR,所以对于CVR预估来说,用户在点击后进入到商品详情页的一些特征同样是Privileged Features。

使用这些Privileged Features,是可以提升模型的预测精度的。因此本文借鉴模型蒸馏的思想,让粗排阶段的CTR模型或者是精排阶段的CVR模型,都能够学习到一些Privileged Features的信息。下一节,咱们来具体学习一下。

2、特殊特征蒸馏(Privileged Features Distillation)

接下来,咱们以粗排阶段的CTR预估来讲一下本文中提出的蒸馏技术。

2.1 模型蒸馏 VS 特殊特征蒸馏

先来看一下模型蒸馏Model Distillation和特殊特征蒸馏Privileged Features Distillation的对比:

二者的思路都是训练一个复杂的Teacher网络和一个简单的Student网络,并通过Teacher网络来在一定程度上指导Student网络的学习。对于模型蒸馏Model Distillation来说,两个网络的输入是相同的,只是Teacher网络的模型结构更加复杂;对于Privileged Features Distillation来说,两个网络的结构是相同的,只不过Teacher网络可以输入更多的Privileged Features。

2.2 Unified Distillation(UD)

如果只使用Privileged Features Distillation,Teacher网络和Student网络均使用双塔结构的话,这其实也对模型的能力在一定程度上进行了限制。因此实际应用中,融合Model Distillation和Privileged Features Distillation,便得到Unified Distillation。其结构示意图如下:

对于Teacher网络,使用多层神经网络来进行学习,而对于Student网络,还是使用双塔结构。

2.3 模型训练

既然是用Teacher网络来指导Student网络的训练,那么常见的一种方式是,先训练好一个比较精确的Teacher网络,然后再训练Student网络。Student网络的损失函数如下:

上面的损失函数被分为两部分,两部分都是计算交叉熵。其中X*代表Privileged Features。损失的第一部分是可以称为hard loss,其label是[0,1]或者[1,0],第二部分可以称为soft loss或distillation loss,其label是Teacher网络的输出,如[0.8,0.2](0.8的概率点击,0.2的概率不点击)。

但是,如果先训练Teacher网络,在阿里的实际场景中需要数天的时间。因此,一种做法是同时训练Teacher网络和Student网络,二者的损失函数变为:

这么做虽然能够带来训练速度的提升,但有时候的效果是比较差的。这主要是由于在训练的初期,Teacher网络的精度不够,给出的结果容易误导Student网络。因此通过对参数λ的控制来调整Teacher网络对于Student网络的影响。在初期,λ比较小,Teacher网络对于Student网络的影响较小,而随着训练的进行,逐步增加λ,让Student学习到更多的Teacher网络的信息。

论文里还提出了两点值得注意。首先是更新Teacher网络的时候,把distillation loss剔除,避免Student网络影响到teacher网络。第二点就是Teacher网络和Student共享特征的embedding,这样就极大减少了参数的数量。

3、实验结果

简单看一下实验结果,这里对比了模型蒸馏、特殊特征蒸馏以及混合方式下Teacher网络和Student网络的AUC,结果如下:

可以看到,混合方式下得到了最好的AUC。其他的一些实验结果大伙可以看下论文。

本文分享自微信公众号 - 小小挖掘机(wAIsjwj),作者:石晓文

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-08-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 调试神经网络的checklist,切实可行的步骤

    这篇文章提供了可以采取的切实可行的步骤来识别和修复机器学习模型的训练、泛化和优化问题。

    石晓文
  • 三大顶会看动态图表示学习

    鉴于网络挖掘在现实生活中的丰富应用,以及近些年网络表示学习的兴起,网络嵌入已经成为学术界和工业界日益关注的研究热点。

    石晓文
  • 深度强化学习-DDPG算法原理和实现

    在之前的几篇文章中,我们介绍了基于价值Value的强化学习算法Deep Q Network。有关DQN算法以及各种改进算法的原理和实现,可以参考之前的文章: 实...

    石晓文
  • 网络——Wireshark工具

    官网下载安装:https://www.wireshark.org/download.html 基础抓包: 效果查看:

    瑞新
  • 关于人才培养的一点心得

    正式入职鹅厂已有三年,三年里带了三个徒弟,其中两个又各自带了一个。本文纯属自己的一点心得,对错不好说,欢迎对此感兴趣的小伙伴们一起交流、探讨。

    serena
  • 「R」统计检验函数汇总

    通常先用 lm() 函数对数据建立线性模型,再用 anova() 函数提取方差分析的信息更方便。

    王诗翔呀
  • 并发编程之Condition

    一、引言 在java中,对于任意一个java对象,它都拥有一组定义在java.lang.Object上监视器方法,包括wait(),wait(long time...

    lyb-geek
  • 学界 | 伯克利提出强化学习新方法,可让智能体同时学习多个解决方案

    选自BAIR Blog 作者:Haoran Tang、Tuomas Haarnoja 机器之心编译 参与:Panda 强化学习可以帮助智能体自动找到任务的解决策...

    机器之心
  • 麻省理工学院CSAIL的AI会检测出可能被劫持的IP地址

    边界网关协议(BGP)是用于在不同主机网关之间传输数据和信息的路由协议,是internet设计的基础。然而,它却始终存在缺陷:

    AiTechYun
  • 接收post请求(vue+axios)解决跨域问题(三)

    测试是否成功连接: 打开mysql 运行node服务 npm start 运行vue npm run dev 发现并没有拿到数据...

    RtyXmd

扫码关注云+社区

领取腾讯云代金券