The complete code and data for this article are available on GitHub: https://github.com/YoungTimes/GNN/tree/master/GAT
GraphSAGE's neighbor-sampling strategy solves the problem that GCN can only be trained full-batch. The GAT paper points out two further shortcomings of GCN, and addresses both:
Graph Attention Network (GAT) uses an attention mechanism to compute a weighted sum over the features of neighboring nodes, so different neighbors receive different weights; and these weights depend entirely on the node features, independent of the graph structure, which also gives the model better generalization.
The core difference between Graph Attention Network (GAT) and Graph Convolutional Network (GCN) is how information from first-order neighbors is aggregated.
In GAT, first-order neighbor information is aggregated as follows:
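From the GAT paper, the new representation of node $i$ is a weighted sum of its transformed neighbor features:

$$\vec{h}'_i = \sigma\Big(\sum_{j \in \mathcal{N}_i} \alpha_{ij} W \vec{h}_j\Big)$$

where $\mathcal{N}_i$ is the neighborhood of node $i$, $W$ is a learnable weight matrix, and $\alpha_{ij}$ is the attention coefficient derived below.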
Below, this formula is translated into code step by step.
First, a linear transformation is applied to the input node features to raise their dimensionality, a common form of feature augmentation.
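In the paper's notation, a shared weight matrix $W \in \mathbb{R}^{F' \times F}$ is applied to every node feature $\vec{h}_i \in \mathbb{R}^{F}$:

$$\vec{h}_i \mapsto W\vec{h}_i$$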
class MultiHeadGATLayer(tf.keras.layers.Layer):
    def __init__(self, in_dim, out_dim,
                 attn_heads=1,
                 # ....
                 kernel_initializer='glorot_uniform'):
        # super().__init__() must run before any attribute is set on a Keras Layer
        super(MultiHeadGATLayer, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.attn_heads = attn_heads
        self.kernel_initializer = kernel_initializer
        # ...
        self.kernels = []

    def build(self, input_shape):
        assert len(input_shape) >= 2

        # One weight matrix W (in_dim x out_dim) per attention head
        for head in range(self.attn_heads):
            kernel = self.add_weight(shape=(self.in_dim, self.out_dim),
                                     initializer=self.kernel_initializer,
                                     regularizer=self.kernel_regularizer,
                                     name='kernel_{}'.format(head))
            self.kernels.append(kernel)

        # ....
        self.built = True

    def call(self, inputs, training):
        X = inputs[0]

        outputs = []
        for head in range(self.attn_heads):
            kernel = self.kernels[head]

            # Linear transformation of the node features: XW
            features = tf.matmul(X, kernel)
The transformed features of each node pair are concatenated, and a learnable vector a then maps the concatenated high-dimensional feature to a single real number. This is implemented as a single-layer feedforward neural network whose activation function is LeakyReLU.
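In the paper this reads:

$$e_{ij} = \text{LeakyReLU}\big(\vec{a}^{\,T} [\, W\vec{h}_i \,\Vert\, W\vec{h}_j \,]\big)$$

where $\Vert$ denotes concatenation and $\vec{a} \in \mathbb{R}^{2F'}$ is the learnable attention vector.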
What is the LeakyReLU function? ReLU sets all negative values to zero, whereas LeakyReLU gives negative values a small non-zero slope. Mathematically, it can be written as:
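$$\text{LeakyReLU}(x) = \begin{cases} x, & x \ge 0 \\ \alpha x, & x < 0 \end{cases}$$

The GAT paper (and the code below) uses a negative slope of $\alpha = 0.2$.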
Clearly, the correlation between node $i$ and node $j$ is determined by the learnable parameters $W$ and $a(\cdot)$.
The complete attention formula, with softmax normalization over each node's neighborhood, is as follows:
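$$\alpha_{ij} = \mathrm{softmax}_j(e_{ij}) = \frac{\exp\big(\text{LeakyReLU}\big(\vec{a}^{\,T}[\, W\vec{h}_i \,\Vert\, W\vec{h}_j \,]\big)\big)}{\sum_{k \in \mathcal{N}_i}\exp\big(\text{LeakyReLU}\big(\vec{a}^{\,T}[\, W\vec{h}_i \,\Vert\, W\vec{h}_k \,]\big)\big)}$$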
The effect is illustrated in the figure below:
The authors of the paper call this masked graph attention, because the attention mechanism also takes the graph structure into account: attention coefficients are computed only over neighboring nodes.
    def call(self, inputs, training):
        X = inputs[0]   # node features, shape [N, in_dim]
        A = inputs[1]   # adjacency matrix, shape [N, N]

        N = X.shape[0]

        outputs = []
        for head in range(self.attn_heads):
            kernel = self.kernels[head]

            # Linear transformation: Wh, shape [N, out_dim]
            features = tf.matmul(X, kernel)

            # Build every ordered pair [Wh_i || Wh_j], shape [N, N, 2 * out_dim]
            concat_features = tf.concat(
                [tf.reshape(tf.tile(features, [1, N]), [N * N, -1]),
                 tf.tile(features, [N, 1])], axis=1)
            concat_features = tf.reshape(concat_features, [N, -1, 2 * self.out_dim])

            # e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
            atten_kernel = self.atten_kernels[head]
            dense = tf.matmul(concat_features, atten_kernel)
            dense = tf.keras.layers.LeakyReLU(alpha=0.2)(dense)
            dense = tf.reshape(dense, [N, -1])

            # Masked attention: keep e_ij only where an edge exists; a large
            # negative value elsewhere so that softmax maps it to ~0
            zero_vec = -9e15 * tf.ones_like(dense)
            attention = tf.where(A > 0, dense, zero_vec)

            # Normalize over each node's neighborhood
            dense = tf.keras.activations.softmax(attention, axis=-1)

            # Dropout on the attention coefficients and the transformed features
            dropout_attn = tf.keras.layers.Dropout(self.dropout_rate)(dense, training=training)
            dropout_feat = tf.keras.layers.Dropout(self.dropout_rate)(features, training=training)

            # Weighted sum of neighbor features: h'_i = sum_j alpha_ij * Wh_j
            node_features = tf.matmul(dropout_attn, dropout_feat)

            if self.use_bias:
                node_features = tf.add(node_features, self.biases[head])

            if self.activation is not None:
                node_features = self.activation(node_features)

            # ...
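The tile/reshape trick above is worth unpacking: the first tensor repeats each row N times in place, the second repeats the whole feature matrix N times, so concatenating them row by row enumerates every ordered pair (i, j). A minimal check with N = 2 (illustrative values, not from the repo):

import tensorflow as tf

features = tf.constant([[1., 1.], [2., 2.]])  # N = 2 nodes, out_dim = 2
N = 2

left = tf.reshape(tf.tile(features, [1, N]), [N * N, -1])  # each row repeated N times
right = tf.tile(features, [N, 1])                          # whole matrix repeated N times
pairs = tf.reshape(tf.concat([left, right], axis=1), [N, N, 2 * 2])

# pairs[i][j] == [h_i || h_j] for every ordered pair, e.g.
# pairs[0][1] == [1., 1., 2., 2.]  (h_0 concatenated with h_1)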
Just like the filter kernels in a convolutional neural network (CNN), the authors found that stacking multiple attention heads, each of which can learn different spatial features, further improves the expressive power of the network.
The heads are combined either by concatenation (concat) or by averaging (avg):
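In the paper's notation, for $K$ heads:

$$\vec{h}'_i = \mathop{\Big\Vert}_{k=1}^{K} \sigma\Big(\sum_{j \in \mathcal{N}_i} \alpha^{k}_{ij} W^{k} \vec{h}_j\Big) \qquad \text{(concat)}$$

$$\vec{h}'_i = \sigma\Big(\frac{1}{K}\sum_{k=1}^{K}\sum_{j \in \mathcal{N}_i} \alpha^{k}_{ij} W^{k} \vec{h}_j\Big) \qquad \text{(avg, used for the output layer)}$$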
    def call(self, inputs, training):
        outputs = []
        for head in range(self.attn_heads):
            # attention...
            outputs.append(node_features)

        if self.attn_heads_reduction == 'concat':
            # Concatenate head outputs along the feature axis: [N, K * out_dim]
            output = tf.concat(outputs, axis=-1)
        else:
            # Average head outputs: stack to [K, N, out_dim], then mean over heads (axis 0)
            output = tf.reduce_mean(tf.stack(outputs), axis=0)

        return output
The training data is again the Cora dataset, so it is not described in detail here; the data-loading code is almost identical to that of GraphSAGE. One key difference is that I cut the whole graph into small subgraphs, slicing the graph's edges as well so that each subgraph's edges exactly match its nodes.
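The matching sub-adjacency is cut out with np.ix_, which selects the given indices along both rows and columns at once. A minimal sketch (illustrative values, not from the repo):

import numpy as np

adj = np.arange(16).reshape(4, 4)   # toy 4-node adjacency-like matrix
idx = np.array([0, 2])              # sampled node indices

sub = adj[np.ix_(idx, idx)]         # rows AND columns 0 and 2 -> shape (2, 2)
# sub == [[ 0,  2],
#         [ 8, 10]]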
def train():
    for e in range(EPOCHS):
        for batch in range(NUM_BATCH_PER_EPOCH):
            # Sample a mini-batch of training nodes
            batch_src_index = np.random.choice(train_index, size=(BATCH_SIZE,))
            batch_src_label = train_label[batch_src_index].astype(float)
            batch_sampling_x = data.x[batch_src_index]

            # Cut the matching sub-adjacency for the sampled nodes
            batch_adj = data.adj[np.ix_(batch_src_index, batch_src_index)]

            loss = 0.0
            with tf.GradientTape() as tape:
                batch_train_logits = model([batch_sampling_x, batch_adj], training=True)
                loss = loss_object(batch_src_label, batch_train_logits)

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
Finally, train the network.
from train_batch import train
train()
Process data ...
Loading cora dataset...
Epoch 000 train accuracy: 0.7599999904632568 val accuracy: 0.5028571486473083 test accuracy:0.40625
Epoch 001 train accuracy: 0.9266666769981384 val accuracy: 0.5400000214576721 test accuracy:0.4778079688549042
Epoch 002 train accuracy: 0.9666666388511658 val accuracy: 0.5571428537368774 test accuracy:0.5294383764266968
Epoch 003 train accuracy: 0.9800000190734863 val accuracy: 0.5857142806053162 test accuracy:0.554347813129425
Epoch 004 train accuracy: 0.9733333587646484 val accuracy: 0.5685714483261108 test accuracy:0.5036231875419617
Epoch 005 train accuracy: 0.9733333587646484 val accuracy: 0.5628571510314941 test accuracy:0.5335144996643066
Epoch 006 train accuracy: 0.9800000190734863 val accuracy: 0.545714259147644 test accuracy:0.5375905632972717
Epoch 007 train accuracy: 0.9800000190734863 val accuracy: 0.5600000023841858 test accuracy:0.5149456262588501
Epoch 008 train accuracy: 0.9800000190734863 val accuracy: 0.5771428346633911 test accuracy:0.5652173757553101
Epoch 009 train accuracy: 0.9933333396911621 val accuracy: 0.5428571701049805 test accuracy:0.5321558117866516
Epoch 010 train accuracy: 0.9933333396911621 val accuracy: 0.5542857050895691 test accuracy:0.5276268124580383
Epoch 011 train accuracy: 0.9866666793823242 val accuracy: 0.5485714077949524 test accuracy:0.5185688138008118
Epoch 012 train accuracy: 0.9866666793823242 val accuracy: 0.5799999833106995 test accuracy:0.5398550629615784
Epoch 013 train accuracy: 0.9866666793823242 val accuracy: 0.5657142996788025 test accuracy:0.5466485619544983
Epoch 014 train accuracy: 0.9866666793823242 val accuracy: 0.5542857050895691 test accuracy:0.508152186870575
Epoch 015 train accuracy: 0.9866666793823242 val accuracy: 0.5571428537368774 test accuracy:0.5335144996643066
Epoch 016 train accuracy: 0.9933333396911621 val accuracy: 0.5657142996788025 test accuracy:0.5457427501678467
Epoch 017 train accuracy: 0.9866666793823242 val accuracy: 0.5828571319580078 test accuracy:0.542119562625885
Epoch 018 train accuracy: 0.9933333396911621 val accuracy: 0.5771428346633911 test accuracy:0.5557065010070801
Epoch 019 train accuracy: 0.9866666793823242 val accuracy: 0.5771428346633911 test accuracy:0.5439311861991882
As the logs above show, the results on the validation and test sets fall well short of GraphSAGE and fluctuate considerably; the exact cause still needs further investigation!