关注我们,一起学习~
title:A Counterfactual Modeling Framework for Churn Prediction
link:https://dl.acm.org/doi/pdf/10.1145/3488560.3498468
from:WSDM 2022
code:https://github.com/tsinghua-fib-lab/CFChurn
1. 导读
本文针对用户流失预测提出结合因果推断的方法CFChurn。结合反事实推理,捕获社会影响的信息从而对流失进行预测。
首先构建两个embedding,分别表示用户的内在意图和外在社会影响。 然后进行反事实数据增强,通过提供部分标记的反事实数据为模型引入因果信息。 最后通过三头反事实预测框架来引到模型学习因果信息,对用户流失进行预测。 主要思想 :将用户流失的原因归结为用户自己意图和外界影响两方面,然后通过反事实数据增广从原始观察数据中得到反事实数据,最后对原始数据和反事实数据进行预测,并结合一些约束。具体见下文。
2. 定义
流失预测的目标是预测用户在未来一段时间内是否会停止使用服务或平台。模型包括三个输入,记录了用户信息和历史行为的用户特征矩阵
X_v \in \mathbb{R}^{n_{v0} \times N} (N个用户),记录了用户之间交互的交互特征矩阵
X_e \in \mathbb{R}^{n_{e0} \times K} (K种交互),和社会网络G。输出为在未来一段时间内是否会流失,公式如下,
y=F(G,X_v,X_e) 社交网络G以用户为节点,用户交互为边,
G=(\mathcal{V,E}) ,邻接矩阵
A \in \mathbb{R}^{N \times N} ,用户特征即为节点特征,交互特征即为边特征。
3. 方法
一般来说,用户流失的原因主要有两种 。一个 源于人的内生意图,例如,一个人对服务失去兴趣。另一个 来自一个人的社会关系的社会影响,这是外生的。例如,一个人可能会停止使用某项服务,因为当他/她的大多数朋友停止使用某项服务时,他/她会感到压力。使用因果图来说明,构建如图 1(a) 来反映两个原因和客户流失之间的因果关系。
基于这个因果图构建模型 CFChurn,其架构如图 1(b) 所示。具体来说,CFChurn结合骨干网络利用两个独立的embedding来模拟用户的内生流失意图和外生社会影响。以两个embedding作为输入,所提出的反事实建模框架包含两个模块:反事实数据增强模块和三头反事实预测模块。
3.1 骨干网络 外生的社会影响和内生的用户意图是用户流失的两大原因。通过两个独立的embedding对它们进行建模,以捕获不同的信息源。作者设计了SGAT来模拟用户之间的关系,以便它学习的embedding可以潜在地捕获社会影响信息。
3.1.1 特征的embedding 模型首先将用户节点特征𝑿𝒗和边缘特征𝑿𝒆作为输入,并将它们转换为embedding。由于将用户作为节点,用节点embedding𝑯𝒗包含所有用户信息。通过两个全连接层得到特征的embedding,然后使用两个 GCN 层来自动建模用户的社会关系。这个过程可以表述如下,求中W和b都表示可学习参数,第一个公式表示两个全连接层得到embedding,第二个公式是常规的GCN,
\hat{A}=D^{-1/2}AD^{-1/2}+I 。第三个公式将交互信息和用户自身信息结合,||表示拼接。
\begin{aligned}
H_{v_{0}} &=\sigma\left(W_{v}^{1} \sigma\left(W_{v}^{0} X_{v}+b_{v}^{0}\right)+b_{v}^{1}\right) \\
H_{g} &=\sigma\left(\hat{A} H_{v_{0}} W_{g}\right) \\
H_{v} &=H_{g} \| H_{v_{0}}
\end{aligned} 令边embedding
H_e 记录用户的交互信息。用两层全连接层得到边的embedding,公式如下,
H_{e} =\sigma\left(W_{e}^{1} \sigma\left(W_{e}^{0} X_{e}+b_{e}^{0}\right)+b_{e}^{1}\right) 3.1.2 学习用户意图和社会影响的embedding 用户自己的意图可以从三类信息中推断出来,包括他们是谁、他们做了什么、他们有什么朋友。以上所有信息都包含在节点嵌入中,因此只需使用两个全连接层来学习用户意图embedding
H_{UI} 。
本文设计SGAT来挖掘用户之间的交互,从而有效挖掘社会影响。SGAT是在GAT上改进得到的,具体如图2所示,对于每一对相连的节点,可以得到节点embedding
h_{v_i} ,
h_{v_j} 和边embedding
h_{e_{ij}} 。将他们作为输入可以得到表征向量,公式如下,可以发现ij和ji的向量是不一样的,即相互影响是不平衡的。
\epsilon_{ij}=\sigma(W_{\epsilon}(h_{v_i}||h_{v_j}||h_{e_{ij}})+b_{\epsilon})
然后,模型将学习到的向量和原始节点embedding作为SGAT的输入,以更新第 l 层中节点 𝑣𝑖 的隐藏状态
h_{v_i}^l 。这一步通过注意力机制捕捉了来自不同朋友的影响的潜在不同影响。公式如下,在文中通过两层SGAT得到社会影响的表征
H_{SI} 。
\begin{array}{l}
\boldsymbol{h}_{v_{i}}^{l+1}=\alpha_{i i}^{l} \boldsymbol{W}_{s}^{l} \boldsymbol{h}_{v_{i}}^{l}+\sum_{j \in \mathcal{N}(i)} \alpha_{i j}^{l} \boldsymbol{W}_{s}^{l}\left(\boldsymbol{\epsilon}_{i j}^{l} \odot \boldsymbol{h}_{v_{j}}^{l}\right) \\
\alpha_{i j}^{l}=\operatorname{softmax}\left(\sigma\left(\boldsymbol{a}^{T}\left[\boldsymbol{W}_{s}^{l} \boldsymbol{h}_{v_{i}}^{l} \| \boldsymbol{W}_{s}^{l}\left(\boldsymbol{\epsilon}_{i j}^{l} \odot \boldsymbol{h}_{v_{j}}^{l}\right)\right]\right)\right)
\end{array} 3.2 反事实模型框架 将干预 𝑡 定义为用户是否流失了朋友,将潜在结果 𝑦 定义为用户是否会在未来一段时间内流失。如果用户有流失了的朋友,并且他/她在下一个时期流失了,那么他/她的流失可能归因于他/她的朋友的社会影响,通过模型建模来区分他/她是否受到社会因素影响。如果没有干预,则t=0,潜在结果表示为
y_{t_0} ,反之为
y_{t_1} 。由于只能观察到已经发生在干预下的结果,因此可以将观察到的事实表示为
y_f ,未观察到的作为反事实结果
y_{cf} 。公式如下,
\begin{array}{l}
y_{f}=t \times y_{t_{1}}+(1-t) \times y_{t_{0}} \\
y_{c f}=t \times y_{t_{0}}+(1-t) \times y_{t_{1}}
\end{array} 3.2.1 反事实数据增广 通过反事实数据增强来缓解反事实数据缺乏的问题,首先在用户流失场景下做了如下假设,
假设 :用户有流失的好友时的流失概率不小于不控制其他条件相同时的流失概率,即
P(y_{t_1})>=P(y_{t_0}) 得到两个推论(推论比较好理解):
y_f=1 ,那么当他的朋友有流失时,他也会流失
y_{cf}=1 。
y_f=0, 那么它朋友不流失时,他也不会流失
y_{cf}=0 。
这两个推论提供了基于关于客户流失的先验因果知识的反事实标签,可以基于推论构造数据集表示为以下等式,其中
O_{cf1} 和
O_{cf2} 表示基于两个推论构造的反事实数据。
O_f 表示观察到的数据。通过这种方式,可以将反事实预测问题转化为具有部分标记数据的监督学习问题。
\begin{array}{l}
O_{c f 1}=\left\{y_{c f}^{i}=y_{t_{1}}^{i}=1 \mid y_{f}^{i}=y_{t_{0}}^{i}=1, i \in O_{f}\right\}, \\
O_{c f 2}=\left\{y_{c f}^{i}=y_{t_{0}}^{i}=0 \mid y_{f}^{i}=y_{t_{1}}^{i}=0, i \in O_{f}\right\}, \\
O_{c f}=O_{c f 1} \cup O_{c f 2},
\end{array} 3.2.2 三头反事实预测 结合反事实增强的数据集,本节设计一个方法用于同时预测事实结果和反事实结果,以促进模型学习因果信息并提供可解释的预测。由于是二元预测问题,使用二元交叉熵损失作为预测损失,公式如下,其中N,M为
O_f 和
O_{cf} 的样本数。
\begin{array}{c}
\mathcal{L}_{f}\left(y_{f}, \hat{y}_{f}\right)=\frac{1}{N} \sum_{i \in O_{f}} y_{f}^{i} \log \left(\hat{y}_{f}^{i}\right)+\left(1-y_{f}^{i}\right) \log \left(1-\hat{y}_{f}^{i}\right), \\
\mathcal{L}_{c f}\left(y_{c f}, \hat{y}_{c f}\right)=\frac{1}{M} \sum_{i \in O_{c f}} y_{c f}^{i} \log \left(\hat{y}_{c f}^{i}\right)+\left(1-y_{c f}^{i}\right) \log \left(1-\hat{y}_{c f}^{i}\right),
\end{array} 为了促进因果信息学习过程,进一步设计了两个组件 。首先 ,基于因果假设引入一个因果正则化器。强制在有干预的分支中的模型预测不小于没有干预的分支中的预测,这可以表述为正则化损失如下:
\mathcal{L}_{c}\left(\hat{y}_{t_{0}}, \hat{y}_{t_{1}}\right)=\frac{1}{M} \sum_{i \in O_{f} \cup O_{c f}} \max \left(0, \hat{y}_{t_{0}}^{i}-\hat{y}_{t_{1}}^{i}\right)
其次 ,通过使用社会影响embedding来预测干预,以促进模型学习因果信息,如图 1(b) 所示。损失函数如下:
\mathcal{L}_{t}(t, \hat{t})=\frac{1}{N+M} \sum_{i \in O_{f} \cup O_{c f}} t^{i} \log \left(\hat{t}^{i}\right)+\left(1-t^{i}\right) \log \left(1-\hat{t}^{i}\right)
上述任务,包括事实结果预测任务、反事实结果预测任务和干预预测任务,以及因果正则化器构成了客户流失预测反事实预测框架的核心。
3.3 输出和训练 3.3.1 输出 为了将社会影响embedding和用户意图embedding转化为客户流失预测,首先将它们拼接起来并使用自注意力层将它们融合在一起。然后,使用两个全连接层来计算预测,可以表示如下,其中p,w,b均为可学习参数,反事实结果和事实结果都采用这个方式预测,并且利用单层全连接层结合社会影响embedding预测干预。
\hat{\boldsymbol{y}}=\operatorname{sigmoid}\left(\boldsymbol{p}^{T} \sigma\left(\boldsymbol{W}_{f c} \text { attention }\left(\boldsymbol{H}_{S I} \| \boldsymbol{H}_{U I}\right)+\boldsymbol{b}_{f c}\right)\right)
3.3.2 训练 客户流失的训练数据通常偏向于负样本,因为流失的用户少于留存的用户(不平衡)。可以通过加权
\alpha_d 来缓解不平衡性,使模型更加关注少数类。因此,目标函数可以定义如下,
\alpha_{cf} ,
\alpha_t ,
\alpha_c 为超参数。
\mathcal{L}=\left(1+y_{f} \times \alpha_{d}\right) \times\left(\mathcal{L}_{f}+\alpha_{c f} \mathcal{L}_{c f}\right)+\alpha_{t} \mathcal{L}_{t}+\alpha_{c} \mathcal{L}_{c}
4. 结果