前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyG搭建GCN实现节点分类

PyG搭建GCN实现节点分类

作者头像
Cyril-KI
发布2022-11-09 15:05:05
1.2K0
发布2022-11-09 15:05:05
举报
文章被收录于专栏:KI的算法杂记KI的算法杂记

I. 前言

GCN原理可以参考:ICLR 2017 | GCN:基于图卷积网络的半监督分类

一开始是打算手写一下GCN,毕竟原理也不是很难,但想了想还是直接调包吧。在使用各种深度学习框架时我们首先需要知道的是框架内的数据结构,因此这篇文章分为两个部分:第一部分数据处理,主要讲解PyG中的数据结构,第二部分模型搭建。

PyG (PyTorch Geometric)是一个基于PyTorch构建的库,可轻松编写和训练图形神经网络 (GNN),用于与结构化数据相关的广泛应用。

II. PyG数据结构

原始论文中使用的数据集:

这里就以Citeseer网络为例。Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。

使用PyG加载数据集:

代码语言:javascript
复制
data = Planetoid(root='/data/CiteSeer', name='CiteSeer')
print(len(data))

输出为1,说明CiteSeer中只有一个网络,然后我们输出一下这个网络:

代码语言:javascript
复制
data = data[0]
print(data)
print(data.is_directed())
代码语言:javascript
复制
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
False

1. x=[3327, 3703]。表示一共有3327个节点,然后节点的特征维度为3703,这里实际上是去除停用词和在文档中出现频率小于10次的词,整理得到的3703个唯一词。

2. edge_index=[2, 9104],表示一共9104条edge。数据一共两行,每一行都表示节点编号。

3. 输出一下data.y:

代码语言:javascript
复制
tensor([3, 1, 5,  ..., 3, 1, 5])

data.y表示节点的标签编号,比如3表示该篇论文属于第3类。

4. 输出data.train_mask:

代码语言:javascript
复制
tensor([ True,  True,  True,  ..., False, False, False])

data.train_mask的长度和y的长度一致,如果某个位置为True就表示该样本为训练样本。val_mask和test_mask类似,分别表示验证集和训练集。

那么很显然,如果我们最终得到了预测值,我们就可以通过以下代码来计算分类的正确数:

代码语言:javascript
复制
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())

III. GCN

首先导入包:

代码语言:javascript
复制
from torch_geometric.nn import GCNConv

模型参数:

1. in_channels:输入通道,比如节点分类中表示每个节点的特征数。

2. out_channels:输出通道,最后一层GCNConv的输出通道为节点类别数(节点分类)。

3. improved:如果为True表示自环加强,也就是原始邻接矩阵基础上加上2I而不是I,默认为False。

4. cached:如果为True,GCNConv在第一次对邻接矩阵进行归一化时会进行缓存,以后将不再重复计算。

5. add_self_loops:如果为False不再强制添加自环,默认为True。

6. normalize:默认为True,表示对邻接矩阵进行归一化。

7. bias:默认添加偏置。

于是模型搭建如下:

代码语言:javascript
复制
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = F.softmax(x, dim=1)

        return x

1. 前向传播

查看官方文档中GCNConv的输入输出要求:

可以发现,GCNConv中需要输入的是节点特征矩阵x和邻接关系edge_index,还有一个可选项edge_weight。因此我们首先:

代码语言:javascript
复制
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)

此时我们不妨输出一下x及其size:

代码语言:javascript
复制
tensor([[0.0000, 0.1630, 0.0000,  ..., 0.0000, 0.0488, 0.0000],
        [0.0000, 0.2451, 0.1614,  ..., 0.0000, 0.0125, 0.0000],
        [0.1175, 0.0262, 0.2141,  ..., 0.2592, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.1825, 0.0000],
        [0.0000, 0.1024, 0.0000,  ..., 0.0498, 0.0000, 0.0000],
        [0.0000, 0.3263, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<FusedDropoutBackward0>)
代码语言:javascript
复制
torch.Size([3327, 16])

此时的x一共3327行,每一行表示一个节点经过第一层卷积更新后的状态向量。第二层卷积同理,即最终输出为:

代码语言:javascript
复制
torch.Size([3327, 6])

即每个节点的维度为6的状态向量。由于我们需要进行6分类,所以最后需要加上一个softmax:

代码语言:javascript
复制
x = F.softmax(x, dim=1)

dim=1表示对每一行进行运算,最终每一行之和加起来为1,也就表示了该节点为每一类的概率。

2. 反向传播

在训练时,我们首先利用前向传播计算出输出,然后算出损失函数:

代码语言:javascript
复制
out = model(data)
代码语言:javascript
复制
loss = loss_function(out[data.train_mask], data.y[data.train_mask])

然后计算梯度,反向更新!

3. 模型训练

代码语言:javascript
复制
def train():
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    model.train()
    for epoch in range(500):
        out = model(data)
        optimizer.zero_grad()
        loss = loss_function(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))

4. 模型测试

代码语言:javascript
复制
def test(model, data):
    model.eval()
    _, pred = model(data).max(dim=1)
    correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
    acc = correct / int(data.test_mask.sum())
    print('GCN Accuracy: {:.4f}'.format(acc))

IV. 完整代码

完整代码及数据:https://github.com/ki-ljl/PyG-GCN,点击阅读原文即可跳转至代码下载界面。

项目结构:

README文件:

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-03-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 KI的算法杂记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 前向传播
  • 2. 反向传播
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档