前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >SIGIR'21「华为」双塔模型中的CBNS负采样方法

SIGIR'21「华为」双塔模型中的CBNS负采样方法

作者头像
秋枫学习笔记
发布2022-09-19 11:48:08
1K0
发布2022-09-19 11:48:08
举报
文章被收录于专栏:秋枫学习笔记

Cross-Batch Negative Sampling for Training Two-Tower Recommenders https://dl.acm.org/doi/pdf/10.1145/3404835.3463032

1.背景

本文是SIGIR'21上的一篇短文,主要是对召回阶段的双塔模型中的负采样方法的改进。通常用的表多的是batch内(in-batch)负采样,但是直接使用in-batch负采样,需要较大的batch size,而如果batch size太大,GPU就会承受不住,因此负样本的多少会受到GPU的限制。本文主要利用网络模型训练到一定程度后是相对稳定的,从而得到的embedding相差不大。在此基础上将之前batch的embedding存储之后用于后续batch的训练,从而提出了Cross Batch Negative Sampling (CBNS)。

2.方法

双塔模型在这里不做详细介绍,主要就是对user和item的embedding求相似度从而得到推荐分数,最终进行召回。得分或者说概率计算方式如下,这里用到的是sampled softmax。其中N表示负采样的样本,而u,v分别表示user和item的embedding。最后采用交叉熵损失函数计算损失。

p(I \mid U ; \Theta)=\frac{e^{u^{\top} v}}{e^{u^{\top} v}+\sum_{I^{-} \in \mathcal{N}} e^{u^{\top} v^{-}}}

负采样用的比较多的是in-batch负采样方式,如图1.a所示。即,除了当前的正样本,把同一batch中的其他样本作为负样本。其中负样本的分布符合一元模型分布(unigram distribution),即和样本出现的频率有关,频率越高的越有可能被选为负样本。采用in-batch的负采样方式,并根据sampled softmax的含义,可以将其上式改写为下式,可以发现不同点在于

log(q(I))

,log用于矫正采样偏差。in-batch的采样方式使得负采样的数目和batch size是呈线性关系的,会受到batch size的限制。而batch size太大GPU的内存就会承受不住。

p_{\text {In-batch }}(I \mid U ; \Theta)=\frac{e^{s^{\prime}(U, I ; q)}}{e^{s^{\prime}(U, I ; q)}+\sum_{I^{-} \in \mathcal{B} \backslash\{I\}} e^{s^{\prime}\left(U, I^{-} ; q\right)}},
s^{\prime}(U, I ; q)=s(U, I)-\log q(I)=\boldsymbol{u}^{\top} v-\log q(I)
S(U,I)=u^Tv

2.1CBNS

2.1.1Embedding Stability

在训练模型的时候,我们通常只考虑当前batch的信息,而忽略了前面batch的信息。本文所提的CBNS方法就是利用之前batch中的信息来帮助训练。文中用下式来衡量特征偏移,其中gv表示将item编码的函数,

\theta

表示其参数。

t,\Delta t

表示训练轮次和轮次差。

D(\mathcal{I}, t ; \Delta t) \triangleq \sum_{I \in I}\left\|g_{v}\left(I ; \theta_{g}^{t}\right)-g_{v}\left(I ; \theta_{g}^{t-\Delta t}\right)\right\|_{2}

如下图所示,通过youtube DNN在in-batch上的实验,作者发现在训练前期,特征偏移是非常大的,也就是说特征的在不同轮次中的变化是很大的,但是随着lr的降低,在训练了

4 \times 10^4

轮后,embedding就相对稳定了。因此作者希望采用“embedding stability”这一现象来提升采样效率。但是直接使用之前batch的embedding会给梯度带来误差,作者在文中证明了误差的影响很小,详细证明可以看文中的3.3.1。

2.1.2FIFO Memory Bank

正如2.1.1中所说的embedding stability的稳定是需要在一定的轮次之后,因此在前期作者依旧是使用in-batch,而是在相对稳定之后采用CBNS。存储之前batch的embedding,文中采用先进先出的大小为M的队列组

\mathcal{M}=\{(v_i,q(I_i))\}_{i=1}^M

,其中

q(I_i)

表示item

I_i

在一元模型分布下的的采样概率。具体结构如图1.b。最后sampled softmax改写为:

p_{\mathrm{CBNS}}(I \mid U ; \Theta)=\frac{e^{s^{\prime}(U, I ; q)}}{e^{s^{\prime}(U, I ; q)}+\sum_{I^{-} \in \mathcal{M} \cup \mathcal{B} \backslash\{I\}} e^{s^{\prime}\left(U, I^{-} ; q\right)}}

在每次迭代后,把当前batch的embedding和采样概率存入队列中,并将最早的embedding出队。在计算sampled softmax的时候可以用到batch内的和队列中的负样本。

3.结果

在不同召回模型上的实验结果表明所提采样方法能够使Recall和NDCG明显提升,这也说明利用之前的embedding的信息能够进一步促进模型的性能。可以发现由于多了队列,在时间耗费上有所上升,不过上升的不多。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-10-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 秋枫学习笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 2.1CBNS
    • 2.1.1Embedding Stability
      • 2.1.2FIFO Memory Bank
      相关产品与服务
      批量计算
      批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档