
Graph Neural Networks in Practice: Graph Attention Network (GAT) in TensorFlow 2.0

Author: YoungTimes
Published: 2022-04-28

The complete code and data for this article are available on GitHub: https://github.com/YoungTimes/GNN/tree/master/GAT

1. Further Shortcomings of GCN

GraphSAGE's neighbour-sampling strategy solves the problem that GCN can only be trained in full-batch mode. The GAT paper points out two further shortcomings of GCN:

  1. GCN cannot assign different weights to different neighbour nodes: every neighbour in the same-order neighbourhood receives exactly the same weight, which limits the model's ability to capture correlations in the spatial structure.
  2. The way GCN aggregates neighbour features is tightly coupled to the graph structure, which limits how well a trained model generalises to other graphs.

2. Introducing the Attention Mechanism

Graph Attention Network (GAT) uses an attention mechanism to compute a weighted sum over the features of neighbouring nodes, so that different neighbours receive different weights. These weights are determined entirely by the node features and are independent of the graph structure, which also gives the model better generalisation ability.

The core difference between Graph Attention Network (GAT) and Graph Convolutional Network (GCN) lies in how the information from first-order neighbours is aggregated.

In GCN, the information from first-order neighbours is aggregated as follows:
h_{i}^{(l+1)}=\sigma\left(\sum_{j \in \mathcal{N}(i)} \frac{1}{c_{i j}} W^{(l)} h_{j}^{(l)}\right)

In GAT, the information from first-order neighbours is aggregated as follows:

\begin{aligned} z_{i}^{(l)} &=W^{(l)} h_{i}^{(l)} \\ e_{i j}^{(l)} &=\operatorname{LeakyReLU}\left(\vec{a}^{T(l)}\left(z_{i}^{(l)} \| z_{j}^{(l)}\right)\right) \\ \alpha_{i j}^{(l)} &=\frac{\exp \left(e_{i j}^{(l)}\right)}{\sum_{k \in \mathcal{N}(i)} \exp \left(e_{i k}^{(l)}\right)} \\ h_{i}^{(l+1)} &=\sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{i j}^{(l)} z_{j}^{(l)}\right) \end{aligned}
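To make the difference concrete, below is a minimal NumPy sketch (separate from the TensorFlow implementation in the next section) that evaluates both aggregation rules on a toy 3-node graph. The graph, the weight matrix W and the attention vector a are made up purely for illustration, the usual symmetric normalisation is assumed for c_ij, and the outer nonlinearity σ is omitted.

```python
import numpy as np

np.random.seed(0)

# Toy graph: 3 nodes with self-loops; node 0 is connected to nodes 1 and 2.
A = np.array([[1, 1, 1],
              [1, 1, 0],
              [1, 0, 1]], dtype=np.float32)
H = np.random.rand(3, 4).astype(np.float32)   # input features h_i, dimension 4
W = np.random.rand(4, 2).astype(np.float32)   # shared weight matrix W^{(l)}
a = np.random.rand(4).astype(np.float32)      # attention vector applied to [z_i || z_j]

def leaky_relu(x, alpha=0.2):
    return np.where(x >= 0, x, alpha * x)

# GCN-style aggregation: each neighbour j gets a fixed weight 1 / c_ij = 1 / sqrt(d_i d_j).
deg = A.sum(axis=1)
gcn_out = np.zeros((3, 2), dtype=np.float32)
for i in range(3):
    for j in np.nonzero(A[i])[0]:
        gcn_out[i] += (H[j] @ W) / np.sqrt(deg[i] * deg[j])

# GAT-style aggregation: the weight alpha_ij is computed from the features themselves.
Z = H @ W                                      # z_i = W h_i
gat_out = np.zeros((3, 2), dtype=np.float32)
for i in range(3):
    nbrs = np.nonzero(A[i])[0]
    e = np.array([leaky_relu(a @ np.concatenate([Z[i], Z[j]])) for j in nbrs])
    alpha = np.exp(e) / np.exp(e).sum()        # softmax over the neighbours of i
    gat_out[i] = (alpha[:, None] * Z[nbrs]).sum(axis=0)

print(gcn_out)
print(gat_out)
```

Note how the GCN weights depend only on node degrees, whereas the GAT weights depend on the transformed features z_i and z_j.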

3. Graph Attention Model

The formulas above are now turned into code step by step.

1. Feature Augmentation

First, a linear transformation is applied to the input node features to lift them into a higher-dimensional space; this is a common form of feature augmentation.

z_{i}^{(l)} =W^{(l)} h_{i}^{(l)}
```python
import tensorflow as tf


class MultiHeadGATLayer(tf.keras.layers.Layer):
    def __init__(self, in_dim, out_dim,
                 attn_heads = 1,
                 # ....
                 kernel_initializer = 'glorot_uniform',
                 **kwargs):
        # Call the parent constructor first, before setting any attributes on the layer.
        super(MultiHeadGATLayer, self).__init__(**kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim

        self.attn_heads = attn_heads

        self.kernel_initializer = kernel_initializer

        # ...

        # One weight matrix W^{(l)} per attention head.
        self.kernels = []

    def build(self, input_shape):
        assert len(input_shape) >= 2

        for head in range(self.attn_heads):
            kernel = self.add_weight(shape=(self.in_dim, self.out_dim),
                                     initializer=self.kernel_initializer,
                                     regularizer=self.kernel_regularizer,  # set in the omitted part of __init__
                                     name='kernel_{}'.format(head))
            self.kernels.append(kernel)

            # ....

        self.built = True

    def call(self, inputs, training):
        X = inputs[0]

        outputs = []
        for head in range(self.attn_heads):

            kernel = self.kernels[head]

            # z_i^{(l)} = W^{(l)} h_i^{(l)}: lift the node features for this head.
            features = tf.matmul(X, kernel)

            # ... (the attention computation continues in the next section)
```

2. Computing the Attention Coefficients

e_{i j}^{(l)} =\operatorname{LeakyReLU}\left(\vec{a}^{T(l)}\left(z_{i}^{(l)} \| z_{j}^{(l)}\right)\right)

The transformed node features z_i and z_j are concatenated, and the vector a then maps this concatenated high-dimensional feature to a single real number. This is implemented as a single-layer feed-forward neural network with LeakyReLU as its activation function.

What is LeakyReLU? ReLU sets all negative values to zero, while LeakyReLU gives negative values a small non-zero slope. With the slope of 0.2 used in the paper, it can be written as:

y_{i}=\left\{\begin{array}{ll} x_{i} & \text { if } x_{i} \geq 0 \\ 0.2\, x_{i} & \text { if } x_{i}<0 \end{array}\right.
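In TensorFlow this activation is available out of the box; a quick check with an arbitrary input tensor:

```python
import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.0, 1.5])
print(tf.nn.leaky_relu(x, alpha=0.2).numpy())  # -> [-0.4, -0.1, 0.0, 1.5]
print(tf.nn.relu(x).numpy())                   # -> [ 0.0,  0.0, 0.0, 1.5]
```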

The correlation between node i and node j is therefore determined entirely by the learnable parameters W and a.

Putting everything together, the full attention coefficient is:

\alpha_{i j}=\frac{\exp \left(\operatorname{LeakyReLU}\left(\vec{a}^{T}\left[W \vec{h}_{i} \| W \vec{h}_{j}\right]\right)\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(\operatorname{LeakyReLU}\left(\vec{a}^{T}\left[W \vec{h}_{i} \| W \vec{h}_{k}\right]\right)\right)}

The authors of the GAT paper call this masked graph attention, because the attention mechanism also takes the graph structure into account: the attention coefficients are computed only over a node's neighbours.

```python
def call(self, inputs, training):
    X = inputs[0]   # node feature matrix, shape [N, in_dim]
    A = inputs[1]   # adjacency matrix, shape [N, N]

    N = X.shape[0]

    outputs = []
    for head in range(self.attn_heads):

        kernel = self.kernels[head]

        # z_i = W h_i
        features = tf.matmul(X, kernel)

        # Build every pair (z_i || z_j); the result has shape [N * N, 2 * out_dim].
        concat_features = tf.concat(
                [tf.reshape(tf.tile(features, [1, N]), [N * N, -1]),
                 tf.tile(features, [N, 1])], axis = 1)

        concat_features = tf.reshape(concat_features, [N, -1, 2 * self.out_dim])

        atten_kernel = self.atten_kernels[head]

        # e_ij = LeakyReLU(a^T (z_i || z_j))
        dense = tf.matmul(concat_features, atten_kernel)

        dense = tf.keras.layers.LeakyReLU(alpha=0.2)(dense)

        dense = tf.reshape(dense, [N, -1])

        # Masked attention: non-neighbours get a very negative score,
        # so their weight is effectively zero after the softmax.
        zero_vec = -9e15 * tf.ones_like(dense)
        attention = tf.where(A > 0, dense, zero_vec)

        # alpha_ij = softmax over the neighbours of node i
        dense = tf.keras.activations.softmax(attention, axis = -1)

        dropout_attn = tf.keras.layers.Dropout(self.dropout_rate)(dense, training = training)
        dropout_feat = tf.keras.layers.Dropout(self.dropout_rate)(features, training = training)

        # h_i' = sum_j alpha_ij z_j
        node_features = tf.matmul(dropout_attn, dropout_feat)

        if self.use_bias:
            node_features = tf.add(node_features, self.biases[head])

        if self.activation is not None:
            node_features = self.activation(node_features)

        # ...
```

3. Multi-head Attention

Just as a convolutional neural network (CNN) uses multiple filter kernels, the authors found that combining several attention heads, each of which can learn different spatial features, further improves the expressive power of the network.

The heads are combined either by concatenation (concat) or by averaging (avg):

h_{i}^{\prime}=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right)
\vec{h}_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right)
```python
def call(self, inputs, training):

    outputs = []
    for head in range(self.attn_heads):

        # attention... (per-head computation shown in the previous section)

        outputs.append(node_features)

    if self.attn_heads_reduction == 'concat':
        # Concatenate the K heads along the feature dimension.
        output = tf.concat(outputs, axis = -1)
    else:
        # Average over the K heads (axis 0 after stacking).
        output = tf.reduce_mean(tf.stack(outputs), axis = 0)

    return output
```
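For reference, a full two-layer model in the spirit of the original GAT paper (multi-head concatenation in the hidden layer, a single head producing the class scores) could be assembled from this layer roughly as follows. This is only a sketch: constructor arguments such as attn_heads_reduction, activation and the dropout settings live in the parts of MultiHeadGATLayer elided above, and the dimensions assume the Cora dataset (1433 input features, 7 classes).

```python
class GATModel(tf.keras.Model):
    """Two stacked GAT layers, roughly following the architecture of the GAT paper."""

    def __init__(self, in_dim=1433, hidden_dim=8, num_classes=7, num_heads=8):
        super(GATModel, self).__init__()
        # Hidden layer: num_heads heads concatenated -> hidden_dim * num_heads features.
        self.gat1 = MultiHeadGATLayer(in_dim, hidden_dim,
                                      attn_heads=num_heads,
                                      attn_heads_reduction='concat',
                                      activation=tf.nn.elu)
        # Output layer: a single head (averaging is a no-op here) producing class probabilities.
        self.gat2 = MultiHeadGATLayer(hidden_dim * num_heads, num_classes,
                                      attn_heads=1,
                                      attn_heads_reduction='average',
                                      activation=tf.keras.activations.softmax)

    def call(self, inputs, training=False):
        x, adj = inputs
        x = self.gat1([x, adj], training=training)
        return self.gat2([x, adj], training=training)
```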

4. Training the Network

The training data is again the Cora dataset, which is not described in detail here; the data-loading code is almost identical to that used for GraphSAGE. The one key difference is that the whole graph is cut into small subgraphs, one per batch, and the edges are cut along with the nodes so that each subgraph's edges exactly match its nodes.
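In the training loop below, the sub-adjacency matrix of a sampled batch is taken with np.ix_, which keeps exactly those edges whose two endpoints are both in the batch. A tiny illustrative example (the adjacency matrix here is made up):

```python
import numpy as np

adj = np.array([[0, 1, 0, 1],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [1, 0, 1, 0]])
batch = np.array([0, 1, 3])           # sampled node indices
sub_adj = adj[np.ix_(batch, batch)]   # adjacency of the induced subgraph
print(sub_adj)
# [[0 1 1]
#  [1 0 0]
#  [1 0 0]]  -> only the edges 0-1 and 0-3 survive
```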

```python
# Excerpt from train_batch.py: model, optimizer, loss_object, data, train_index,
# train_label and the hyper-parameter constants are defined at module level.
def train():
    for e in range(EPOCHS):
        for batch in range(NUM_BATCH_PER_EPOCH):
            # Sample a batch of training nodes and build the induced subgraph.
            batch_src_index = np.random.choice(train_index, size=(BTACH_SIZE,))
            batch_src_label = train_label[batch_src_index].astype(float)

            batch_sampling_x = data.x[batch_src_index]
            batch_adj = data.adj[np.ix_(batch_src_index, batch_src_index)]

            with tf.GradientTape() as tape:
                batch_train_logits = model([batch_sampling_x, batch_adj], training = True)
                loss = loss_object(batch_src_label, batch_train_logits)

            # Compute and apply the gradients outside the tape context.
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
```

Finally, the network is trained:

```python
from train_batch import train

train()
```

Output of the training run:

```
Process data ...
Loading cora dataset...
Epoch 000 train accuracy: 0.7599999904632568 val accuracy: 0.5028571486473083 test accuracy:0.40625
Epoch 001 train accuracy: 0.9266666769981384 val accuracy: 0.5400000214576721 test accuracy:0.4778079688549042
Epoch 002 train accuracy: 0.9666666388511658 val accuracy: 0.5571428537368774 test accuracy:0.5294383764266968
Epoch 003 train accuracy: 0.9800000190734863 val accuracy: 0.5857142806053162 test accuracy:0.554347813129425
Epoch 004 train accuracy: 0.9733333587646484 val accuracy: 0.5685714483261108 test accuracy:0.5036231875419617
Epoch 005 train accuracy: 0.9733333587646484 val accuracy: 0.5628571510314941 test accuracy:0.5335144996643066
Epoch 006 train accuracy: 0.9800000190734863 val accuracy: 0.545714259147644 test accuracy:0.5375905632972717
Epoch 007 train accuracy: 0.9800000190734863 val accuracy: 0.5600000023841858 test accuracy:0.5149456262588501
Epoch 008 train accuracy: 0.9800000190734863 val accuracy: 0.5771428346633911 test accuracy:0.5652173757553101
Epoch 009 train accuracy: 0.9933333396911621 val accuracy: 0.5428571701049805 test accuracy:0.5321558117866516
Epoch 010 train accuracy: 0.9933333396911621 val accuracy: 0.5542857050895691 test accuracy:0.5276268124580383
Epoch 011 train accuracy: 0.9866666793823242 val accuracy: 0.5485714077949524 test accuracy:0.5185688138008118
Epoch 012 train accuracy: 0.9866666793823242 val accuracy: 0.5799999833106995 test accuracy:0.5398550629615784
Epoch 013 train accuracy: 0.9866666793823242 val accuracy: 0.5657142996788025 test accuracy:0.5466485619544983
Epoch 014 train accuracy: 0.9866666793823242 val accuracy: 0.5542857050895691 test accuracy:0.508152186870575
Epoch 015 train accuracy: 0.9866666793823242 val accuracy: 0.5571428537368774 test accuracy:0.5335144996643066
Epoch 016 train accuracy: 0.9933333396911621 val accuracy: 0.5657142996788025 test accuracy:0.5457427501678467
Epoch 017 train accuracy: 0.9866666793823242 val accuracy: 0.5828571319580078 test accuracy:0.542119562625885
Epoch 018 train accuracy: 0.9933333396911621 val accuracy: 0.5771428346633911 test accuracy:0.5557065010070801
Epoch 019 train accuracy: 0.9866666793823242 val accuracy: 0.5771428346633911 test accuracy:0.5439311861991882
```

As the training log above shows, the final accuracy on the validation and test sets is considerably worse than what GraphSAGE achieved, and it fluctuates a lot; the exact cause still needs further investigation.

Originally published 2020-06-21 on the WeChat public account 半杯茶的小酒杯, and shared through the Tencent Cloud self-media programme.
