前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >GraphGallery:几行代码玩转图神经网络

GraphGallery:几行代码玩转图神经网络

作者头像
Houye
发布2021-10-12 10:47:47
7750
发布2021-10-12 10:47:47
举报
文章被收录于专栏:图与推荐

TensorFlow or PyTorch, both!

本文介绍中山大学图学习团队开发的图神经网络基准模型库GraphGallery,支持多种深度学习框架(PyTorch与TensorFlow)以及两种图神经网络开发后端(PyG与DGL),能够帮助你快速训练和测试图神经网络模型。

1前言

图神经网络(Graph Neural Networks,GNN)是近几年兴起的新的研究热点,其借鉴了传统卷积神经网络等模型的思想,在图结构数据上定义了一种新的神经网络架构。如果作为初入该领域的科研人员,想要快速学习并验证自己的想法,需要花费一定的时间搜集数据集,定义模型的训练测试过程,寻找现有的模型进行比较测试,这无疑是繁琐且不必要的。GraphGallery 为科研人员提供了一个简单方便的框架,用于在一些常用的数据集上快速建立和测试自己的模型,并且与现有的基准模型进行比较。GraphGallery目前支持主流的两大机器学习框架:TensorFlow 和 PyTorch,以及两种图神经网络开发后端PyG与DGL,带你几行代码玩转图神经网络。

GraphGallery项目地址:https://github.com/EdisonLeeeee/GraphGallery

2GraphGallery项目概览

GraphGallery架构图

GraphGallery的架构主要包括输入数据流,模型构建,以及训练测试pipeline,用于对目前现有的GNN模型进行快速搭建。GraphGallery目前实现了节点分类任务主流的图神经网络模型(如GCN,GAT等),以及部分节点嵌入模型(如DeepWalk,Node2Vec等):

论文模型实现列表(截取部分)

3GraphGallery安装及使用

1安装

安装前需要用户自行安装所需版本的PyTorch,其余TensorFlow,PyTorch Geometric与DGL为可选安装项。

  • 直接从源码安装(推荐使用)
代码语言:javascript
复制
# Recommended
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
pip install -e . --verbose
  • 从 Pypi 安装(版本更新相对滞后)
代码语言:javascript
复制
# Maybe outdated
pip install -U graphgallery

2快速上手

Dataset

以领域内常用的固定划分基准数据集Planetoid为例:

代码语言:javascript
复制
from graphgallery.datasets import Planetoid
# set `verbose=False` to avoid informational messages 
data = Planetoid('cora', verbose=False)
graph = data.graph
splits = data.split_nodes() # 使用节点固定的划分
>>> graph
Graph(adj_matrix(2708, 2708),
      node_attr(2708, 1433),
      node_label(2708,),
      metadata=None, multiple=False)

目前包含 6 种数据集

代码语言:javascript
复制
>>> data.available_datasets()
Objects in BunchDict:
╒════════════╤═══════════════════════════╕
│ Names      │ Objects                   │
╞════════════╪═══════════════════════════╡
│ citeseer   │ citeseer citation dataset │
├────────────┼───────────────────────────┤
│ cora       │ cora citation dataset     │
├────────────┼───────────────────────────┤
│ pubmed     │ pubmed citation dataset   │
├────────────┼───────────────────────────┤
│ nell.0.1   │ NELL dataset              │
├────────────┼───────────────────────────┤
│ nell.0.01  │ NELL dataset              │
├────────────┼───────────────────────────┤
│ nell.0.001 │ NELL dataset              │
╘════════════╧═══════════════════════════╛

graphgallery.datasets模块还提供了相当多的数据集,具体可查看项目主页:

https://github.com/EdisonLeeeee/GraphGallery

Model Gallery

顾名思义,GraphGallery 是一个GNN模型的 Gallery

GraphGallery 实现了一系列的面向不同下游任务的GNN模型,以最常见的GCN模型与节点分类任务为例

代码语言:javascript
复制
from graphgallery.gallery.nodeclas import GCN
trainer = GCN()
trainer.setup_graph(graph)
trainer.build()
trainer.fit(splits.train_nodes, splits.val_nodes)
results = trainer.evaluate(splits.test_nodes)

训练过程如下:

代码语言:javascript
复制
Training...
100/100 [==============================] - Total: 6.46s - 64ms/step - loss: 0.081 - accuracy: 0.986 - val_loss: 0.699 - val_accuracy: 0.788
Testing...
1/1 [====================] - Total: 14.41ms - 14ms/step - loss: 1.119 - accuracy: 0.815

上述代码究竟做了哪些事情呢?

  • 第一步(初始化):trainer = GCN()初始化了一个GCN的训练模型,可以传入参数seeddevice设定随机数种子和运行设备
  • 第二步(数据处理):trainer.setup_graph(graph)对输入的图数据进行预处理,并转换为张量用于后续训练
  • 第三步(模型构建):train.build()实现了模型搭建的步骤,build方法可以指定包含隐藏层单元个数(层数),激活函数,学习率等参数
  • 第四步(训练):trainer.fit(splits.train_nodes, splits.val_nodes)实现了对训练集节点的拟合,并利用验证集节点存储模型最优参数
  • 第五步(测试):训练好后,调用trainer.evaluate(splits.test_nodes)在测试集节点上进行验证。result保存了模型测试结果,输出如下:
代码语言:javascript
复制
>>> result
Objects in BunchDict:
╒══════════╤═══════════╕
│ Names    │   Objects │
╞══════════╪═══════════╡
│ loss     │   1.11898 │
├──────────┼───────────┤
│ accuracy │   0.815   │
╘══════════╧═══════════╛

至此,只需要几行代码即可完成对一个模型的调用和训练测试,并且当你切换不同的后端,调用的是不同后端实现的模型(甚至不需要更改上述调用代码),例如:

代码语言:javascript
复制
import graphgallery
# 修改为TensorFlow后端(需要提前安装好 TensorFlow)
>>> graphgallery.set_backend('tf')
# 修改为PyG后端(需要提前安装好 PyG)
>>> graphgallery.set_backend('pyg')
# 修改为DGL后端(需要提前安装好 DGL)
>>> graphgallery.set_backend('dgl')

当你切换不同的后端,GraphGallery后台会帮你切换模型对应的框架实现(如果有存在模型实现的话),并且不需要修改原先代码,上述的训练代码仍然可以无需修改直接使用:

代码语言:javascript
复制
from graphgallery.gallery.nodeclas import GCN
trainer = GCN()
# 预处理,模型构建,训练,测试代码都不需要改变

如果不清楚当前后端及任务所实现的模型列表,可以调用如下API查看(以节点分类任务为例):

代码语言:javascript
复制
>>> graphgallery.gallery.nodeclas.models()
Registry of PyTorch-Gallery (Node Classification):
╒════════════╤════════════════════════════════════════════════════════════════════════════╕
│ Names      │ Objects                                                                    │
╞════════════╪════════════════════════════════════════════════════════════════════════════╡
│ GCN        │ <class 'graphgallery.gallery.nodeclas.pytorch.gcn.GCN'>                    │
├────────────┼────────────────────────────────────────────────────────────────────────────┤
│ DenseGCN   │ <class 'graphgallery.gallery.nodeclas.pytorch.gcn.DenseGCN'>               │
├────────────┼────────────────────────────────────────────────────────────────────────────┤
│ GAT        │ <class 'graphgallery.gallery.nodeclas.pytorch.gat.GAT'>                    │
├────────────┼────────────────────────────────────────────────────────────────────────────┤

如上所示,输出的是节点分类任务以及PyTorch后端实现的模型(部分输出结果)。

其它模型

除了主流的基于不同框架实现的图神经网络,GraphGallery还实现了一些常用的无监督节点嵌入模型,如DeepWalk,Node2Vec等。GraphGallery使用Scipy+Numpy实现,并采用Numba进行加速,在保证模型性能与原论文相近的同时,大大提高了该方法的速度:

代码语言:javascript
复制
from graphgallery.gallery.embedding import DeepWalk
model = DeepWalk()
model.fit(graph.adj_matrix)
embedding = model.get_embedding()

其中,graph.adj_matrix是输入的邻接矩阵(以Scipy.sparse.csr_matrix方式存储)。如上所示,只需几行代码就可以得到最终的结点嵌入。

4后续工作

在实现上,GraphGallery借鉴了许多优秀的开源项目,如:Pytorch Geometric, Stellargraph 和 DGL等。当前, GraphGallery 仍然处于开发阶段,还有许多工作需要完成:

  • 实现更多的 GNN 模型(多种后端)
  • 支持更多的任务(目前主要支持半监督的节点分类任务),未来会加入更多链路预测,图分类等下游任务
  • 支持更多样的图数据结构(目前主要支持单一无向同构图),未来会考虑异构图,动态图等
  • 为项目提供更好的项目文档和注释(完善中...)

最后,附上项目地址及论文:

[1] GraphGallery 项目主页:https://github.com/EdisonLeeeee/GraphGallery [2] Jintang Li, Kun Xu, Liang Chen*, Zibin Zheng and Xiao Liu, “GraphGallery: A Platform for Fast Benchmarking and Easy Development of Graph Neural Networks Based Intelligent Software”, 2021 IEEE/ACM 43rd International Conference on Software Engineering: Companion Proceedings (ICSE-Companion). IEEE, 2021: 13-16

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-10-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 图神经网络与推荐系统 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1前言
  • 2GraphGallery项目概览
  • 3GraphGallery安装及使用
    • 1安装
      • 2快速上手
        • Dataset
        • Model Gallery
        • 其它模型
    • 4后续工作
    相关产品与服务
    图数据库 KonisGraph
    图数据库 KonisGraph(TencentDB for KonisGraph)是一种云端图数据库服务,基于腾讯在海量图数据上的实践经验,提供一站式海量图数据存储、管理、实时查询、计算、可视化分析能力;KonisGraph 支持属性图模型和 TinkerPop Gremlin 查询语言,能够帮助用户快速完成对图数据的建模、查询和可视化分析。
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档