前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyG 官方教程 | 开箱即用的图神经网络解释框架

PyG 官方教程 | 开箱即用的图神经网络解释框架

作者头像
Houye
发布2023-03-01 15:29:17
2.7K0
发布2023-03-01 15:29:17
举报
文章被收录于专栏:图与推荐

图神经网络 (GNN) 在处理图结构数据(例如社交网络、分子图和知识图)方面越来越受欢迎。然而,基于图的数据的复杂性和图中节点之间的非线性关系使得很难理解为什么 GNN 会做出特定的预测。随着图神经网络的普及,人们对解释其预测的兴趣也越来越大。

解释在实际机器学习应用程序中的重要性怎么强调都不为过。它们有助于在模型中建立信任和透明度,因为用户可以更好地理解预测是如何进行的以及影响它们的因素。它们改进了决策制定,让决策者有更多的理解,可以根据模型预测做出更明智的决策。解释还使从业者更容易调试和改进他们开发的模型的性能。在某些领域,例如金融和医疗保健,由于合规性和法规,甚至可能需要解释。

图机器学习中的解释在很大程度上是一项持续的研究工作,图的可解释性不如 ML 的其他子领域(如计算机视觉或 NLP)的可解释性成熟。此外,由于 GNN 操作的复杂关系数据,解释本身也不同:

  • 上下文理解:解释需要对图中节点之间的关系和实体进行上下文理解,这可能很复杂且难以理解
  • 动态关系:图中节点之间的关系会随着时间而变化,这使得为在不同时间点做出的预测提供解释变得具有挑战性
  • 异构数据:GML 往往涉及处理具有复杂特征的异构数据类型,因此很难提供统一的解释方法论
  • 解释粒度:解释必须解释预测的结构起源以及特征重要性。这意味着解释哪些节点、边或子图很重要,以及哪些节点或边特征对预测结果有很大贡献。

突出了图机器学习中解释的复杂性。左侧显示了用于在节点 v 处进行预测的 GNN 计算图。计算图中的一些边是重要的神经按摩传递路径(绿色),而其他边则不是(橙色)。然而,GNN 需要聚合重要和不重要的特征来进行预测,而解释方法的目标是识别一小组对预测至关重要的重要特征和路径。

撇开图机器学习的困难和复杂性不谈,最近该领域有很多统一的工作,旨在提供一个统一的框架来评估解释[1,2],并提供现有解释动物园的分类法可用的方法[3]。

在最近的一次社区冲刺中,PyG 社区已经实施了一个核心可解释性框架(https://pytorch-geometric.readthedocs.io/en/latest/modules/explain.html#id4)以及各种评估方法、基准数据集和可视化,这使得在 PyG 中开始使用图机器学习解释变得非常容易。此外,如果您只想开箱即用地使用通用图形解释器(如 GNNExplainer [4] 或 PGExplainer [5]),或者如果您想实施、测试和评估您自己的解释方法,该框架都非常有用。

在这篇博文中,我们将逐步介绍可解释性模块,阐明框架的每个组件如何工作以及它服务于什么目的。之后,我们将讨论各种解释评估方法和综合基准,它们协同工作以确保您为手头的任务提供最佳解释。我们将继续研究开箱即用的可视化方法。最后,我们将介绍在 PyG 中实现您自己的解释方法所需的步骤,并重点介绍高级用例的工作,例如异构图和链接预测解释。

框架

在设计可解释性框架时,我们的目标是设计一个易于使用的可解释性模块,它:

  • 可以扩展以满足许多 GNN 应用程序的要求
  • 可以适应各种类型的图表和解释设置
  • 可以提供解释输出以进行综合评估和可视化

该框架的核心实际上有四个概念:

  • Explainer 类:PyG 可解释性模块的包装器,用于实例级解释(https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/explain/explainer.py)
  • Explanation类:封装解释器输出的类(https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/explain/explanation.py)
  • ExplainerAlgorithm 类:Explainer 用于在给定训练实例的情况下生成解释的可解释性算法(https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/explain/algorithm/base.py)
  • metric包:使用解释输出和潜在的 GNN 模型/基本事实来评估 ExplainerAlgorithm 的评估指标(https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric/explain/metric)

要了解它们是如何组合在一起的,让我们看一下下图:

PyG 可解释性框架的高级概述

用户提供解释设置,以及需要解释的模型和数据。Explainer 类是一个 PyG 实例,它包装了一个解释器算法——一种特定的解释方法,为给定的模型和数据生成解释。解释封装在 Explanation 类中,可以进一步进行后处理、可视化和评估。现在让我们更深入地了解可用的各种解释设置

示例解释器

下面是一个 Explainer 设置示例,它使用 GNNExplainer 对 Cora (https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html#torch_geometric.datasets.Planetoid)数据集进行模型解释(参见 gnn_explainer.py (https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer.py) 示例)。

代码语言:javascript
复制
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

为所有属性设置节点级掩码,为对象设置边缘掩码。为了对模型的特定预测产生解释,我们简单地调用解释器:

代码语言:javascript
复制
node_index = 10 # which node index to explain
explanation = explainer(data.x, data.edge_index, index=node_index)

现在让我们来看看所有的细节,它们使 PyG 中的解释变得如此简单!

Explanation类

我们使用 Explanation 类表示解释,它是一个 Data 或 HeteroData 对象,包含节点、边、特征和数据的任何属性的掩码。在这个范例中,掩码充当各个节点/边/特征的解释属性。掩码值越大,对应的解释成分越重要(0 表示完全不重要)。Explanation 类包含用于获取诱导解释子图的方法,该子图由所有非零解释属性和对解释子图的补充组成。此外,它还包括用于解释的阈值和可视化方法。

Explainer 类和Explanation设置

Explainer 类旨在处理所有可解释性设置,这些设置既可以设置为 Explainer 的直接参数,也可以设置为 ModelConfig 或 ThresholdConfig 的配置。这个新界面提供了许多设置。让我们一个一个地检查可用的。

代码语言:javascript
复制
# Explainer Parameters
model: torch.nn.Module,
algorithm: ExplainerAlgorithm,
explanation_type: Union[ExplanationType, str],
model_config: Union[ModelConfig, Dict[str, Any]],
node_mask_type: Optional[Union[MaskType, str]] = None,
edge_mask_type: Optional[Union[MaskType, str]] = None,
threshold_config: Optional[ThresholdConfig] = None,

该模型可以是我们用来生成解释的任何 PyG 模型。其他模型设置在 ModelConfig 中指定,它指定模型的模式、task_level 和 return_type。模式描述了任务类型,例如 mode=multiclass-classification,task_level表示任务级别(node-、edge-或-graph级别的任务),return_type描述模型的预期返回类型(raw、probs或log_probs)。

有两种类型的解释,如 explanation_type 所指定(有关更深入的讨论,请参见 [1])

  • explanation_type="phenomenon" 旨在解释为什么针对特定输入做出特定决定。我们对数据中从输入到输出的现象感兴趣。在这种情况下,标签用作解释的目标。
  • explanation_type="model" 旨在为所提供的模型提供事后解释。在此设置中,我们试图打开黑匣子并解释其背后的逻辑。在这种情况下,模型预测被用作解释的目标。

解释的精确计算方式由算法参数指定,模块中提供了几个现成的:

  • GNNExplainer:来自“GNNExplainer: Generating Explanations for Graph Neural Networks (https://arxiv.org/abs/1903.03894) ”论文的 GNN-Explainer 模型。
  • PGExplainer:来自“Parameterized Explainer for Graph Neural Network (https://arxiv.org/abs/2011.04573) ”论文的 PGExplainer 模型。
  • AttentionExplainer:使用基于注意力的 GNN(例如 GATConv、GATv2Conv 或 TransformerConv)产生的注意力系数作为边解释的解释器
  • CaptumExplainer:基于 Captum (https://captum.ai/) 的解释器
  • GraphMaskExplainer:来自 Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking (https://arxiv.org/abs/2010.00577) 论文的 GraphMask-Explainer(目前是 torch_geometric.contrib 的一部分)
  • PGMExplainer:来自 PGMExplainer: Probabilistic Graphical Model Explanations for Graph Neural Networks (https://arxiv.org/abs/1903.03894)的 PGMExplainer 模型(目前是 torch_geometric.contrib 的一部分)

我们还支持许多不同类型的掩码,这些掩码由 node_mask_type 和 edge_mask_type 设置,可以是:

  • None 不会屏蔽任何节点/边
  • "object"将掩盖每个节点/边缘
  • “common_attributes”将屏蔽每个节点特征/边缘属性
  • "attributes"将跨所有节点/边缘分别屏蔽每个节点特征/边缘属性

Explainer 类中可用的不同节点掩码类型

最后,您还可以通过 ThresholdConfig 设置阈值行为。如果您不想对解释掩码设置阈值,您可以将其设置为 None,或者您可以在任何值应用hard阈值,或者您可以仅保留 top-k 值与 topk 或将 top-k 值设置为 1 与 topk_hard。

解释评价

生成解释绝不是可解释性工作流程的结束。解释的质量可以通过多种不同的方法来判断。PyG 支持一些开箱即用的解释评估指标,您可以在指标包(https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric/explain/metric)中找到它们。

也许最流行的评估指标是 Fidelity+/- (https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.explain.metric.fidelity.html#torch_geometric.explain.metric.fidelity)(有关详细信息,请参见 [1])。保真度评估生成的解释子图对初始预测的贡献,方法是仅将子图提供给模型(保真度-)或将其从整个图中移除(保真度+)。

现象和模型模式的保真度+/- 定义(来源 [1])

保真度分数反映了可解释模型再现自然现象或 GNN 模型逻辑的好坏程度。一旦我们做出了解释,我们就可以获得两种保真度:

代码语言:javascript
复制
from torch_geometric.explain.metric import fidelity
fid_pm = fidelity(explainer, explanation)

我们提供表征分数(https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.explain.metric.characterization_score.html#torch_geometric.explain.metric.characterization_score)作为将两种保真度组合成单个指标的方法 [1]。此外,如果我们有许多不同阈值(或熵)的保真度对解释,我们可以用保真度曲线 auc (https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.explain.metric.fidelity_curve_auc.html#torch_geometric.explain.metric.fidelity_curve_auc)计算保真度曲线下的面积。此外,我们提供了unfaithfulness指标,用于评估Explanation对底层 GNN 预测器 [6] 的忠实程度。

当没有可用的“基本事实”解释时,保真度分数和不忠实度等指标对于评估解释很有用,即我们没有一组预先确定的节点/特征/边来完全解释特定模型预测或现象。特别是在开发新的解释算法时,我们可能会对某些标准基准数据集 [1,2] 上的性能感兴趣。groundtruth_metrics 方法比较解释掩码并返回标准度量的选择((“准确性”,“召回”,“精度”,“f1_score”,“auroc”):

代码语言:javascript
复制
from torch_geometric.explain.metric import groundtruth_metrics
accuracy, auroc = groundtruth_metrics(pred_mask, 
                                      target_mask, 
                                      metrics=["accuracy", "auroc"])

当然,以这种方式评估解释器首先需要基准数据集,其中可以使用基本事实解释。

基准数据集

为了促进新图形解释器算法的开发和严格评估,PyG 现在提供了几个解释器数据集,如 BA2MotifDataset、BAMultiShapesDataset 和 InfectionDataset,以及创建合成基准数据集的简单方法。通过 ExplainerDataset 类提供支持,它创建来自 GraphGenerator 的合成图,并随机将 num_motifs 个来自 MotifGenerator 的图案附加到它上面。基于节点和边缘是否属于某个主题的一部分,给出了真实节点级和边缘级可解释性掩码。

目前支持的 GraphGenerator 有:

  • BAGraph:随机 Barabasi-Albert (BA) 图
  • ERGraph:随机 Erdos-Renyi (ER) 图
  • GridGraph:二维网格图

但是您可以通过子类化 GraphGenerator 类轻松地实现自己的。此外,对于我们支持的图案

  • HouseMotif:来自 [4] 的房屋结构图案
  • CycleMotif:来自 [4] 的循环主题
  • CustomMotif:基于来自 Data 对象或 networkx.Graph 对象(例如轮形)的自定义结构添加任何主题的简便方法

我们可以使用上述设置生成的数据集是 GNNExplainer [4]、PGExplainer [5]、SubgraphX [8]、PGMExplainer [9]、GraphFramEx [1] 等中使用的基准数据集的超类。

随机图形生成器和主题生成器

我们可以动态生成具有所需种子和大小的新数据集。例如,要生成基于 Barabasi-Albert 图的数据集,其中 80 个房屋图案用作地面实况解释标签,我们将使用:

代码语言:javascript
复制
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph
代码语言:javascript
复制
dataset = ExplainerDataset(
    graph_generator=BAGraph(num_nodes=300, num_edges=5),
    motif_generator='house',
    num_motifs=80,
)

BAMultiShapesDataset 是用于评估图分类可解释性算法的综合数据集 [10]。给定三个原子图案,即房屋 (H)、车轮 (W) 和网格 (G),BAMultiShapesDataset 包含 1,000 个 Barabasi-Albert 图,其标签取决于原子图案的附件,如下所示:

BAMultiShapesDataset 中的类取决于原子图案的存在

数据集是预先计算的,以便与官方实施相吻合。

另一个预先计算的数据集是 BA2MotifDataset[5]。它包含 1,000 个 Barabasi-Albert 图。一半的图形附有一个 HouseMotif,其余的附有一个五节点的 CycleMotif。根据附加图案的类型,这些图被分配到两个类别之一。要创建类似的数据集,您可以将 ExplainerDataset 与图形和主题生成器结合使用。

此外,我们提供 InfectionDataset [2] 生成器,其中节点预测它们与受感染节点(黄色)的距离,并使用到受感染节点的唯一路径作为解释。具有非唯一路径到受感染节点的节点被排除在外。不可到达的节点和距离至少为 max_path_length 的节点被合并为一类。

来自 [2] 的感染数据集

为了生成感染数据集,我们指定图形生成器、感染路径长度和感染节点数

代码语言:javascript
复制
# Generate Barabási-Albert base graph
graph_generator = BAGraph(num_nodes=300, num_edges=500)
代码语言:javascript
复制
# Create the InfectionDataset to the generated base graph
dataset = InfectionDataset(
    graph_generator=graph_generator,
    num_infected_nodes=50,
    max_path_length=3
)

我们的目标是在未来添加更多的解释数据集和图形生成器,敬请期待!

可解释性可视化

如前所述,Explanation 类通过两种方法 visualize_feature_importance() 和 visualize_graph() 提供基本的可视化功能。

为了可视化特征,我们可以使用 top_k 指定要绘制的顶级特征的数量,或者使用 feat_labels 指定特征标签的传递。

代码语言:javascript
复制
explanation.visualize_feature_importance(feature_importance.png, top_k= 10 )

输出存储到指定路径,这是上面 Cora 数据集解释器的示例输出:

Cora 上的特征重要性,详情参见gnn_explainer.py(https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer.py)示例

我们还可以很容易地可视化由解释得出的图形。visualize_graph() 的输出是根据重要性值(如果需要,通过配置的阈值)过滤掉边缘后解释子图的可视化。我们可以选择两个后端(graphviz 或 networkx):

代码语言:javascript
复制
explanation.visualize_graph( 'subgraph.png' , backend= "graphviz" )

我们得到有助于解释的节点和边的局部图,边不透明度对应于边重要性。

由gnn_explainer.py(https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer.py)示例中的解释导出的子图

实施您自己的 ExplainerAlgorithm

所有解释计算魔法都发生在传递给Explainer类的ExplainerAlgorithm中。多种流行的解释算法(GNNExplainer、PGExplainer等)已经实现,可以简单使用。但是,如果您发现自己需要一个未实现的ExplainerAlgorithm,请不要担心,只需将ExplainerAlgorithm接口子类化并实现两个必要的抽象方法即可。

前向方法计算解释,它具有以下签名

代码语言:javascript
复制
def forward(
  self,
  # the model used for explanations
  model: torch.nn.Module, 
  # the input node features
  x: Union[torch.Tensor, Dict[NodeType, torch.Tensor]], 
  # the input edge indices
  edge_index: Union[torch.Tensor, Dict[NodeType, torch.Tensor]], 
  # the target of the model (what we are explaining)
  target: torch.tensor, 
  # The index of the model output to explain. 
  # Can be a single index or a tensor of indices.
  index: Union[int, Tensor], optional, 
  # Additional keyword arguments passed to the model
  **kwargs: optional, 
) -> Union[Explanation, HeteroExplanation]

为了协助为不同的解释算法构建 forward() 方法,基类 ExplainerAlgorithm 提供了几个有用的辅助函数,如 _post_process_mask 用于后处理任何掩码以不包括消息传递期间未涉及的元素的任何属性,_get_hard_masks 返回硬节点和边缘掩码,仅 包括在消息传递期间访问的节点和边,_num_hops 以获取模型从中聚合信息的跳数,以及其他。

第二个需要实现的方法是supports()方法

代码语言:javascript
复制
supports(self) -> bool

supports()函数检查解释器是否支持self.explainer_config和 self.model_config中提供的用户定义设置,它检查是否为正在使用的特定解释设置定义了解释算法。

异构图的扩展

如上所述,Explanation可以简单地扩展到异构图和HeteroData。在这种情况下,解释也是一个掩码,但适用于所有节点和边缘特征(具有不同类型)。为此,我们实现了 HeteroExplanation类,它具有与Explanation几乎相同的接口。此外,为了促进未来在这个方向上的工作,我们添加了异构图支持CaptumExplainer,可以作为未来实现的模板。此外,大多数可解释性框架已经在这个方向上面向未来,许多参数被设置为异构案例的可选字典。

解释链接预测

对于那些想要为链接预测提供解释的人,我们添加了GNNExplainer链接解释支持。这个想法是通过索引到边缘张量而不是节点特征张量,将边缘解释视为一种新的目标索引方法。链接预测解释考虑了两个端点的k-hop-neighbourhoods的并集。

此实现与现有代码很好地集成以支持大多数解释配置。用于解释链接预测的示例设置如下所示

代码语言:javascript
复制
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)
代码语言:javascript
复制
# Explain model output for a single edge:
edge_label_index = val_data.edge_label_index[:, 0]explainer = Explainer(
    model=model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)
explanation = explainer(
    x=train_data.x,
    edge_index=train_data.edge_index,
    edge_label_index=edge_label_index,
)
print(f'Generated model explanations in {explanation.available_explanations}')

要查看完整示例,请查看gnn_explainer_link_pred.py(https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer_link_pred.py)。为了更容易上手实现任何任务级别的解释方法,我们还提供了所有任务级别(图、节点、边缘)的示例参数化测试,感兴趣的读者可以看一下test/explain。(https://github.com/pyg-team/pytorch_geometric/tree/master/test/explain)

这是 PyG中可解释性的旋风之旅。目前,PyG正在研究许多令人兴奋的事情,包括图形可解释性方面以及其他图形机器学习领域。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-02-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 图神经网络与推荐系统 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 我们可以使用上述设置生成的数据集是 GNNExplainer [4]、PGExplainer [5]、SubgraphX [8]、PGMExplainer [9]、GraphFramEx [1] 等中使用的基准数据集的超类。
  • 异构图的扩展
  • 解释链接预测
  • 对于那些想要为链接预测提供解释的人,我们添加了GNNExplainer链接解释支持。这个想法是通过索引到边缘张量而不是节点特征张量,将边缘解释视为一种新的目标索引方法。链接预测解释考虑了两个端点的k-hop-neighbourhoods的并集。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档