前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用 PyG 进行图神经网络训练

使用 PyG 进行图神经网络训练

作者头像
EmoryHuang
发布2022-11-30 11:58:06
1.4K0
发布2022-11-30 11:58:06
举报
文章被收录于专栏:EmoryHuang's Blog

使用 PyG 进行图神经网络训练

前言

最近一直在想创新点,搭模型,想尝试一下图神经网络,想着自己实现一个,但是之前也没有尝试过写 GNN 模型,对其中的实现细节也没有实际尝试过,最后找到了 PyG ,尝试一下之后发现还是挺简单的,也比较好拿到现有模型里面,于是开始挖坑。

PyG (PyTorch Geometric) 是一个基于 PyTorch 的库,可轻松编写和训练图形神经网络 (GNN),用于与结构化数据相关的广泛应用。

目前网上对 PyG 的相关文档并不多,大本部也都是比较重复的内容,因此我主要参考的还是官方文档。具体安装方式参考 PyG Installation

图结构

建图

首先,我们需要根据数据集进行建图,在 PyG 中,一个 Graph 的通过torch_geometric.data.Data进行实例化,它包括下面两个最主要的属性:

  • data.x: 节点的特征矩阵,形状为[num_nodes, num_node_features];
  • data.edge_index: 图的边索引,用 COO 稀疏矩阵格式保存。形状为[2, num_edges];

稀疏矩阵是数值计算中普遍存在的一类矩阵,主要特点是绝大部分的矩阵元为零。COO (Coordinate) 格式是将矩阵中的非零元素存储,每一个元素用一个三元组来表示,分别是(行号,列号,数值)

上面两个属性也就是整个图中最重要的部分。同时你也可以定义下面的一些其他信息:

  • data.edge_attr: 边的特征矩阵,形状为[num_edges, num_edge_features];
  • data.y: 计算损失所需的目标数据,比如节点级别的形状为[num_nodes, *],或者图级别的形状为[1, *];

现在,我们来简单创建一张图:

代码语言:javascript
复制
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])

需要注意的是:

  1. 第一行 edge_index[0] 表示起点,第二行 edge_index[1] 表示终点;
  2. 虽然只有两条边,但在 PyG 中处理无向图时实际上是互为头尾节点;
  3. 矩阵中的值表示索引,指代 x 中的节点。

实际上,你不需要找对专门的属性去存取某些信息,你完全可以「自定义属性名」,并放入你想要的内容,事实上就像一个字典。例如,你想要为图再添加一个 freq 属性以表示访问频率:

代码语言:javascript
复制
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
data.freq = torch.tensor([2, 5, 1], dtype=torch.long)
>>> Data(x=[3, 1], edge_index=[2, 4], freq=[3])

此外,这里再列举一些个人觉得比较常用的操作,感觉官方文档有些地方不清不楚,完整的类说明可以看这里

  • coalesce(): 对 edge_index 中的边排序并去重
  • clone(): 创建副本
  • to(cuda:0): 把数据放到 GPU
  • num_edges: 返回边数量
  • num_nodes: 返回节点数量

关于 train_mask

在官方的数据集中,划分「训练集」和「测试集」的方式是创建一张大图,然后指定训练节点以及测试节点,通过 train_masktest_mask 来实现。

代码语言:javascript
复制
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')

data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708],
         train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

data.is_undirected()
>>> True

data.train_mask.sum().item()
>>> 140

data.val_mask.sum().item()
>>> 500

data.test_mask.sum().item()
>>> 1000

上面是官方提供的数据集中的一个例子,但个人觉得并不是所有的数据集都适合这种方法,我仍习惯于训练集和测试集分开,也就是创建两张图,因为大多数情况下都需要自定义一个 MyDataset,你可以把建图的操作放在里面,所以其实感觉这样还更方便。

代码语言:javascript
复制
data_train, data_test = split_train_test(data)

# Create dataset
dataset_train = MyDataset(data_train)
dataset_test = MyDataset(data_test)

当然,这还是要看具体做什么任务的,对某些任务来说可能只需要一张大图,对我来说可能就需要分成两张。

关于 Embedding

最开始的时候我有想过,为什么要在一开始创建 x 的时候就让我把节点的维度给定下来,这不应该是我后面模型里面 Embedding 的时候再做的事情吗,难不成建图的时候就要 Embedding?当然,对于有些情况,建图的时候完全就可以直接 Embedding。

事实上,其实就是我想多了,其实把它当做一个字典就好。

代码语言:javascript
复制
import torch
from torch import nn
from torch_geometric.data import Data

# create graph
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([1, 0, 2], dtype=torch.long)

data = Data(x=x, edge_index=edge_index)
data.freq = torch.tensor([2, 5, 1], dtype=torch.long)

...

# model.py
# init embedding layer
emb_layer = nn.Embedding(3, 8)

data.x = emb_layer(data.x)
# data.x_emb = emb_layer(data.x)

print(data)
>>> Data(x=[3, 8], edge_index=[2, 4], freq=[3])

就像上面展示的这样,我的做法是先把节点的 ID 填进去,接着在模型里面进行 Embedding,当然你可以直接使用 data.x = emb_layer(data.x) 把原来的 ID 给替换掉;也可能你需要保留 ID,那么就可以把它放到一个新的属性中,比如 data.x_emb。总的来说其实就是不需要把 data 的属性看得那么死,完全可以按照你的需要来。

关于 Batch 和 DataLoader

Batch

在 PyG 中,你可以通过手动的方式进行批量化操作:

代码语言:javascript
复制
import torch
from torch import nn
from torch_geometric.data import Data, Batch

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([1, 0, 2], dtype=torch.long)
g1 = Data(x=x, edge_index=edge_index)
g2 = Data(x=x, edge_index=edge_index)
g3 = Data(x=x, edge_index=edge_index)
data = [g1, g2, g3]
>>> Data(x=[3], edge_index=[2, 4])

batch = Batch().from_data_list(data)
>>> DataDataBatch(x=[9], edge_index=[2, 12], batch=[9], ptr=[4])

在上面的例子中,为了方便我创建了 3 张相同的图。现在我们来看看 batch 具体做了些什么:

代码语言:javascript
复制
print(batch.x)
>>> tensor([1, 0, 2, 1, 0, 2, 1, 0, 2])

print(batch.edge_index)
>>> tensor([[0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8],
            [1, 0, 2, 1, 4, 3, 5, 4, 7, 6, 8, 7]])

print(batch.batch)
>>> tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

print(batch.ptr)
>>> tensor([0, 3, 6, 9])

实际上,Batch 的本质就是将多个图组合起来,创建一张大图。可以看出,首先在 x 中将所有节点拼接在一起,batch 指明了每个节点所属的批次,ptr 指明了每个 batch 的节点的起始索引号,然后在 edge_index 中对边进行拼接,同时为了避免冲突对索引加上了 ptr

用公式来表示也就是:

DataLoader

PyTorch 原生的 DataLoader 实际上对 Data 并不支持,虽然可以创建成功,但在遍历取数据的时候,你会发现如下错误:

default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class ‘torch_geometric.data.data.Data’>

PyG 有一个自己的 DataLoader,实际上只需要用它替换 PyTorch 原生的 DataLoader 就可以,个人觉得使用体验上和 PyTorch 差别不大。

代码语言:javascript
复制
from torch.utils.data import Dataset
# from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        super(MyDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sub = self.data[idx]
        return sub

dataset = MyDataset(data)
dataloader = DataLoader(dataset)

图神经网络

讲完了图结构,以及数据集之后,现在正式进入到了模型训练阶段

Convolutional Layers

PyG 其实定义了非常多可供直接使用的 Convolutional Layers,具体你可以看这里

… 待施工 …

参考资料

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-11-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 使用 PyG 进行图神经网络训练
    • 前言
      • 图结构
        • 建图
        • 关于 train_mask
        • 关于 Embedding
      • 关于 Batch 和 DataLoader
        • Batch
        • DataLoader
      • 图神经网络
        • Convolutional Layers
      • 参考资料
      相关产品与服务
      批量计算
      批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档