前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >图神经网络(GNN)TensorFlow实现

图神经网络(GNN)TensorFlow实现

作者头像
里克贝斯
发布2021-05-21 10:22:04
1.1K0
发布2021-05-21 10:22:04
举报
文章被收录于专栏:图灵技术域图灵技术域

图神经网络的研究与图嵌入或网络嵌入密切相关,图嵌入或网络嵌入是数据挖掘和机器学习界日益关注的另一个课题。图嵌入旨在通过保留图的网络拓扑结构和节点内容信息,将图中顶点表示为低维向量,以便使用简单的机器学习算法(例如,支持向量机分类)进行处理。许多图嵌入算法通常是无监督的算法,它们可以大致可以划分为三个类别,即矩阵分解、随机游走和深度学习方法。同时图嵌入的深度学习方法也属于图神经网络,包括基于图自动编码器的算法(如DNGR和SDNE)和无监督训练的图卷积神经网络(如GraphSage)。——https://zhuanlan.zhihu.com/p/75307407

下面给出一个图神经网络TensorFlow的实现,代码参考自:https://github.com/Ivan0131/gnn_demo

Python

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
import gnn.gnn_utils as gnn_utils

data_path = "./data"
set_name = "sub_15_7_200"

# 训练集
inp, arcnode, nodegraph, nodein, labels, _ = gnn_utils.set_load_general(data_path, "train", set_name=set_name)
inp = [a[:, 1:] for a in inp]

# 验证集
inp_val, arcnode_val, nodegraph_val, nodein_val, labels_val, _ = gnn_utils.set_load_general(data_path, "validation",
                                                                                            set_name=set_name)
inp_val = [a[:, 1:] for a in inp_val]
input_dim = len(inp[0][0])
state_dim = 10
output_dim = 2
state_threshold = 0.001
max_iter = 50

tf.compat.v1.disable_eager_execution()
tf.reset_default_graph()
comp_inp = tf.placeholder(tf.float32, shape=(None, input_dim), name="input")
y = tf.placeholder(tf.float32, shape=(None, output_dim), name="target")

state = tf.placeholder(tf.float32, shape=(None, state_dim), name="state")
state_old = tf.placeholder(tf.float32, shape=(None, state_dim), name="old_state")

ArcNode = tf.sparse_placeholder(tf.float32, name="ArcNode")


def f_w(inp):
    with tf.variable_scope('State_net'):
        layer1 = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)
        layer2 = tf.layers.dense(layer1, state_dim, activation=tf.nn.sigmoid)
        return layer2


def g_w(inp):
    with tf.variable_scope('Output_net'):
        layer1 = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)
        layer2 = tf.layers.dense(layer1, output_dim, activation=None)
        return layer2


def convergence(a, state, old_state, k):
    with tf.variable_scope('Convergence'):
        # assign current state to old state
        old_state = state

        # 获取子结点上一个时刻的状态
        # grub states of neighboring node
        gat = tf.gather(old_state, tf.cast(a[:, 0], tf.int32))
        print(a[:, 0])

        # 去除第一列,即子结点的id
        # slice to consider only label of the node and that of it's neighbor
        # sl = tf.slice(a, [0, 1], [tf.shape(a)[0], tf.shape(a)[1] - 1])
        # equivalent code
        sl = a[:, 1:]

        # 将子结点上一个时刻的状态放到最后一列
        # concat with retrieved state
        inp = tf.concat([sl, gat], axis=1)
        print('inp', inp)

        # evaluate next state and multiply by the arch-node conversion matrix to obtain per-node states
        # 计算子结点对父结点状态的贡献
        layer1 = f_w(inp)
        # 聚合子结点对父结点状态的贡献,得到当前时刻的父结点的状态
        print('ArcNode', ArcNode)
        state = tf.sparse_tensor_dense_matmul(ArcNode, layer1)

        # update the iteration counter
        k = k + 1
    return a, state, old_state, k


def condition(a, state, old_state, k):
    # evaluate condition on the convergence of the state
    with tf.variable_scope('condition'):
        # 检查当前状态和上一个时刻的状态的欧式距离是否小于阈值
        # evaluate distance by state(t) and state(t-1)
        outDistance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(state, old_state)), 1) + 1e-10)
        # vector showing item converged or not (given a certain threshold)
        checkDistanceVec = tf.greater(outDistance, state_threshold)
        print(outDistance)
        print(checkDistanceVec)
        c1 = tf.reduce_any(checkDistanceVec)
        print(c1)
        # 是否达到最大迭代次数
        c2 = tf.less(k, max_iter)
        print(c2)
    return tf.logical_and(c1, c2)


# 迭代计算,直到状态达到稳定状态
with tf.variable_scope('Loop'):
    k = tf.constant(0)
    res, st, old_st, num = tf.while_loop(condition, convergence, [comp_inp, state, state_old, k])
    out = g_w(st)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=out))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(y, 1)), dtype=tf.float32))
optimizer = tf.train.AdamOptimizer(0.001)
grads = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads, name='train_op')



# 模型训练
num_epoch = 5000
# 训练集placeholder输入
arcnode_train = tf.SparseTensorValue(indices=arcnode[0].indices, values=arcnode[0].values,
                                     dense_shape=arcnode[0].dense_shape)
fd_train = {comp_inp: inp[0], state: np.zeros((arcnode[0].dense_shape[0], state_dim)),
            state_old: np.ones((arcnode[0].dense_shape[0], state_dim)),
            ArcNode: arcnode_train, y: labels}
# 验证集placeholder输入
arcnode_valid = tf.SparseTensorValue(indices=arcnode_val[0].indices, values=arcnode_val[0].values,
                                     dense_shape=arcnode_val[0].dense_shape)
fd_valid = {comp_inp: inp_val[0], state: np.zeros((arcnode_val[0].dense_shape[0], state_dim)),
            state_old: np.ones((arcnode_val[0].dense_shape[0], state_dim)),
            ArcNode: arcnode_valid, y: labels_val}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    for i in range(0, num_epoch):
        _, loss_val, accuracy_val = sess.run(
            [train_op, loss, accuracy],
            feed_dict=fd_train)
        if i % 100 == 0:
            loss_valid_val, accuracy_valid_val = sess.run(
                [loss, accuracy],
                feed_dict=fd_valid)
            print(
                "iter %s\t training loss: %s,\t training accuracy: %s,\t validation loss: %s,\t validation accuracy: %s" %
                (i, loss_val, accuracy_val, loss_valid_val, accuracy_valid_val))

数据集和参考代码:https://github.com/Ivan0131/gnn_demo

相关文章

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-08-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 相关文章
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档