关注我们,一起学习~
标题:Learning Causal Effects on Hypergraphs 链接:https://dl.acm.org/doi/pdf/10.1145/3534678.3539299 会议:KDD 2022 公司,学校:微软,弗吉尼亚大学
1. 导读
本文结合超图和因果推断,专注于超图上的个干预效果 (ITE) 估计问题,旨在估计干预(例如,戴面罩)会对结果(例如,COVID-19 感染)产生多大的因果影响。现有的 ITE 估计工作要么假设一个人的结果不应该受到其他人的干预分配的影响(即没有干扰),或者假设干扰只存在于普通图中的连接个体之间。这些假设在现实世界的超图上可能是不现实的,其中高阶干扰会由于组交互的存在而影响最终的 ITE 估计。本文研究了高阶干扰建模,并提出了一种由超图神经网络驱动的新因果学习框架。
2. 问题定义
表示节点集合,
表示超边集合,它们一起形成具有 n 节点和 m 条超边的超图H={V, E},其中每个超边可以连接任意数量的节点。该超图上的观测数据可以表示为 {X, H, T, Y},其中
,
和
分别表示节点特征、干预和观察结果。
表示超图关联的矩阵。h=1表示节点 i 在超边 e 中。
实例 i 的潜在结果
,
分别表示干预t=1和t=0。通过函数
得到输出,
,-i表示出了i节点之外的所有其他节点。
对于超图 H 上的每个节点 i,ITE由对应于 t = 1 和 t = 0 的潜在结果之间的差定义,公式如下,文中的ITE实际上是以条件平均干预效果(CATE)的形式定义的。同时,引入了溢出效应(spillover effect )的概念来评估对超图的干扰程度。
节点 i 在其干预 t 和其他节点的干预
下,在超图 H 上的溢出效应定义为下式,即在高阶交互中考虑其他干预对当前节点的影响。
3. 方法
该框架包含三个组件:混杂表征学习、干扰建模和结果预测。从整体上讲,目标是学习一种表达转换来总结高阶干扰,然后采用干扰表征、混杂表征以及干预来估计预期的潜在结果。HyperSCI的图示如图2所示。
首先通过多层感知器 (MLP) 模块将节点特征 x 编码到潜在空间中,即
。从而产生一组表征
期望以此来捕获混杂因此,模型可以通过控制学习表征 z 来减轻混杂偏差。
表征平衡。干预组和对照组中混杂表征 Z 的分布之间可能存在差异,从而导致因果效应的估计有偏,为了最小化这种差异,通过向损失函数添加差异惩罚来利用表征平衡技术,其中可以使用任何分布距离度量来计算这种差异惩罚。在实现中,在干预组和对照组的表征分布之间使用 Wasserstein-1 距离。
在此干扰建模组件中,将混杂表征(Z)、干预 (T) 和超图上的关系信息 (H) 作为输入,以捕获高阶干扰。通过超图模块学习变换函数
,以生成每个节点 i 的干扰表征 (
),即
。如图 3 所示,该模块由超图卷积网络和超图注意机制实现,其中卷积算子计算得到超边的干扰,注意力算子来增强这种机制。
学习干扰表征。为了学习对超图中的每个节点的干扰,使用超图卷积层传播处理干预和混杂表征。首先为超图 H 引入一个普通的拉普拉斯矩阵,公式如下,D为度矩阵,B也是对角矩阵,每个元素表示超边的大小,超图卷积公式如下,其中
表示来自超图模块中第 l 层的表征。
,其中 ∗ 表示逐元素乘法,W为可学习参数。
建模具有不同意义的干扰。虽然上述卷积层可以通过超边传递干扰,但它并没有提供太多的灵活性来解释不同节点通过不同超边干扰的重要性。为了更好地捕捉超图上节点和超边之间的内在关系,利用超图注意机制来学习每个节点的注意力权重。通过聚合其关联节点 (
) 来计算每个超边 (e) 的表征:
。Agg可以是任何聚合函数(例如,平均聚合)。对于每个节点 i 及其关联的超边 e,节点 i 和超边 e 之间的注意力分数计算方式如下,σ为非线性激活函数,
表示和节点i相关的超边的集合,a为可学习权重。
然后使用注意力分数来模拟具有不同意义的干扰。替换公式中的原始关联矩阵 H为增强后的矩阵
其中
。这样,同一超边上不同节点的干扰可以被赋予不同的重要性权重,表明对干扰建模的不同程度的贡献。将最后一个卷积层的最终表征表示为
,并期望它捕获每个节点的高阶干扰。
表征平衡。与混杂表征学习模块类似,计算差异惩罚以反映干预组和对照组中干扰表征分布之间的差异。将这两个差异惩罚加在一起来计算表征平衡损失
。
计算得到混杂表征
和干扰表征
后,建模输出为下式,f是可学习的函数,用于预测潜在结果。使用两个 MLP 模块实现f。然后构建损失函数如下,
4. 结果