专栏首页程序生活Spektral:使用TF2实现经典GNN的开源库

Spektral:使用TF2实现经典GNN的开源库

简介

Spektral工具还发表了论文: 《Graph Neural Networks in TensorFlow and Keras with Spektral》 https://arxiv.org/abs/2006.12138

github地址:https://github.com/danielegrattarola/spektral/

在本文中,我们介绍了 Spektral,这是一个开源 Python 库,用于使用 TensorFlow 和 Keras 应用程序编程接口构建图神经网络。Spektral 实现了大量的图深度学习方法,包括消息传递和池化运算符,以及用于处理图和加载流行基准数据集的实用程序。这个库的目的是为创建图神经网络提供基本的构建块,重点是 Keras 所基于的用户友好性和快速原型设计的指导原则。因此,Spektral 适合绝对的初学者和专业的深度学习从业者。

主要网络

Spektral 实现了一些主流的图深度学习层,包括:

安装

pip安装:

pip install spektral

源码安装:

git clone https://github.com/danielegrattarola/spektral.git
cd spektral
python setup.py install  # Or 'pip install .'

Spektral实现GCN

对于TF爱好者很友好:

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GCNConv
from spektral.models.gcn import GCN
from spektral.transforms import AdjToSpTensor, LayerPreprocess

learning_rate = 1e-2
seed = 0
epochs = 200
patience = 10
data = "cora"

tf.random.set_seed(seed=seed)  # make weight initialization reproducible

# Load data
dataset = Citation(
    data, normalize_x=True, transforms=[LayerPreprocess(GCNConv), AdjToSpTensor()]
)


# We convert the binary masks to sample weights so that we can compute the
# average loss over the nodes (following original implementation by
# Kipf & Welling)
def mask_to_weights(mask):
    return mask.astype(np.float32) / np.count_nonzero(mask)


weights_tr, weights_va, weights_te = (
    mask_to_weights(mask)
    for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)

model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
model.compile(
    optimizer=Adam(learning_rate),
    loss=CategoricalCrossentropy(reduction="sum"),
    weighted_metrics=["acc"],
)

# Train model
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
)

# Evaluate model
print("Evaluating model.")
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
print("Done.\n" "Test loss: {}\n" "Test accuracy: {}".format(*eval_results))

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 使用TF2与Keras实现经典GNN的开源库——Spektral

    Spektral 是一个基于 Keras API 和 TensorFlow 2,用于图深度学习的开源 Python 库。该项目的主要目的是提供一个简单但又不失灵...

    机器之心
  • 知识图谱与机器学习 | KG入门 -- Part1-b 图深度学习

    来源 | Medium 【磐创AI导读】:本系列文章为大家介绍了知识图谱与机器学习,这篇文章是上一篇文章:知识图谱与机器学习 | KG入门 -- Part1 D...

    磐创AI
  • 基于图卷积神经网络GCN的时间序列预测:图与递归结构相结合的库存品需求预测

    时间序列预测任务可以按照不同的方法执行。最经典的是基于统计和自回归的方法。更准确的是基于增强和集成的算法,我们必须使用滚动周期生成大量有用的手工特性。另一方面,...

    deephub
  • 图深度学习入门教程(二)——模型基础与实现框架

    深度学习还没学完,怎么图深度学习又来了?别怕,这里有份系统教程,可以将0基础的你直接送到图深度学习。还会定期更新哦。

    代码医生工作室
  • NLP简报(Issue#6)

    基于Transformer的模型已经被证实可以有效地处理从序列标记到问题解答等不同类型的NLP任务,其中一种称为BERT[1]的模型得到了广泛使用,但是像其他采...

    NewBeeNLP
  • 兼容性Up!Object Detection API 现已支持 TensorFlow 2

    作者 | Vivek Rathod 和 Jonathan Huang,Google Research

    磐创AI
  • 微软开源了一个用TF实现的GNN例程库

    2019年接近尾声,许多学术机构盘点本年度AI领域技术关键词总少不了图神经网络(GNN),业界渐成共识:CNN处理图像视频等矩阵数据、RNN处理序列数据,GNN...

    CV君
  • 有了TensorFlow2.0,我手里的1.x程序怎么办?

    导读: 自 2015 年开源以来,TensorFlow 凭借性能、易用、配套资源丰富,一举成为当今最炙手可热的 AI 框架之一,当前无数前沿技术、企业项目都基于...

    Datawhale
  • 都在关心TensorFlow2.0,那么我手里的1.x程序怎么办?

    TensorFlow 是谷歌在 2015 年开源的一个通用高性能计算库。从一开始,TensorFlow 的主要目的就是为构建神经网络(NN)提供高性能 API。...

    代码医生工作室
  • 掌握TensorFlow1与TensorFlow2共存的秘密,一篇文章就够了

    TensorFlow是Google推出的深度学习框架,也是使用最广泛的深度学习框架。目前最新的TensorFlow版本是2.1。可能有很多同学想跃跃欲试安装Te...

    蒙娜丽宁
  • ROS机器人TF基础(坐标相关概念和实践)

    机器人建模和控制必须掌握坐标系和坐标变换等基础知识。机器人在空间中运动主要有两种形式:

    zhangrelay
  • 库克最新采访:苹果平均2周收购一家公司,不care华为等竞争对手

    蒂姆·库克最新访谈中透露,苹果保持着对小公司的高频率收购,平均2-3周就会买下一家公司。

    量子位
  • 万字综述,GNN在NLP中的应用,建议收藏慢慢看

    今天为大家解读的是由京东硅谷研发中心首席科学家吴凌飞博士等研究者最新发表的GNN for NLP综述,几乎覆盖了围绕NLP任务的所有GNN相关技术,是迄今为止G...

    Houye
  • 【深度学习】2021 年了,TensorFlow 和 PyTorch 两个深度学习框架地位又有什么变化吗?

    链接:https://www.zhihu.com/question/452749603/answer/1826252757

    黄博的机器学习圈子
  • Go语言经典库使用分析(四)| Gorilla Handlers 源代码实现分析

    上一篇 Go语言经典库使用分析(三)| Gorilla Handlers 详细介绍 中介绍了Handlers常用中间件的使用,这一篇介绍下这些中间件实现的原理...

    飞雪无情
  • 带你入门机器学习与TensorFlow2.x

    本文主要介绍人工智能、机器学习和深度学习的区别,以及软硬件环境的搭建,包括Tensorflow1.x和Tensorflow2.x在同一台机器上如何共存。在后续的...

    蒙娜丽宁
  • [深度学习概念]·图神经网络综述:模型与应用

    近年来,图神经网络的研究成为深度学习领域的热点。近日,清华大学孙茂松组在 arXiv 上发布预印版综述文章 Graph Neural Networks: A R...

    小宋是呢
  • 清华大学图神经网络综述:模型与应用

    该文总结了近年来图神经网络领域的经典模型与典型应用,并提出了四个开放性问题。对于希望快速了解这一领域的读者,不妨先从这篇文章看起。

    机器之心
  • TF2下变分自编码的N种写法

    在开篇之前,请允许我吐槽几段文字,发泄一下TF的不便之处。如果对这部分内容不敢兴趣请直接看正文内容。

    代码医生工作室

扫码关注云+社区

领取腾讯云代金券