
Node Clustering Analysis: DeepWalk + K-means

Cyril-KI
Published 2022-09-19 14:19:18
Column: KI的算法杂记 (KI's Algorithm Notes)

1. DeepWalk Algorithm Analysis

As with node2vec, we first generate r random walks starting from every node, each of length l (self.r and self.l in the code).

DeepWalk's main contribution is bringing the Skip-Gram model from word2vec into graph embedding. Its random-walk strategy is much simpler than node2vec's.
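For reference, this is the Skip-Gram objective as it is usually written for DeepWalk. The notation is ours, not this article's: Φ(v) is the embedding of node v and φ′(u) is the context vector of node u. For one walk v_1, …, v_l and window size k, the model maximizes the log-probability of each context node given the center node:

```latex
\max_{\Phi,\,\phi'} \sum_{i=1}^{l} \sum_{\substack{-k \le j \le k,\ j \ne 0 \\ 1 \le i+j \le l}} \log \Pr\bigl(v_{i+j} \mid v_i\bigr),
\qquad
\Pr(u \mid v) = \frac{\exp\bigl(\phi'(u)^{\top} \Phi(v)\bigr)}{\sum_{w \in V} \exp\bigl(\phi'(w)^{\top} \Phi(v)\bigr)}
```

The full objective sums this over all r·|V| walks. The denominator runs over every node, so in practice it is approximated. The DeepWalk paper uses hierarchical softmax. gensim's Word2Vec, as called in this article, defaults to negative sampling (hs=0, negative=5).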

Random walk in DeepWalk: at every step, move to a neighbor of the current node chosen uniformly at random.
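In other words, from node v the next node is each neighbor with probability 1/|N(v)|. The following is a minimal sketch of that rule and of building the r·|V| walk corpus. It makes several assumptions that are not part of the article's code:

- A plain adjacency dict stands in for a networkx graph.
- A seeded stdlib random.Random is used for reproducibility.
- The function names random_walk and build_corpus and the toy graph are hypothetical.

```python
import random


def random_walk(adj, start, length, rng):
    """Uniform random walk: each step picks a neighbor of the current node with equal probability."""
    walk = [start]
    while len(walk) < length:
        neighbors = adj[walk[-1]]
        if not neighbors:  # dead end (isolated node): stop early, as the article's code does
            break
        walk.append(rng.choice(neighbors))
    return walk


def build_corpus(adj, r, length, seed=0):
    """r passes over all nodes in shuffled order, one walk per node per pass -> r * |V| walks."""
    rng = random.Random(seed)
    nodes = list(adj)
    walks = []
    for _ in range(r):
        rng.shuffle(nodes)
        for node in nodes:
            walks.append(random_walk(adj, node, length, rng))
    return walks


# Hypothetical toy undirected graph; "e" is isolated
adj = {"a": ["b", "c"], "b": ["a", "c"], "c": ["a", "b", "d"], "d": ["c"], "e": []}
walks = build_corpus(adj, r=3, length=5)
print(len(walks))  # 15 = 3 passes * 5 nodes
```

Every node starts exactly r walks. Walks from connected nodes reach the full length, while the isolated node "e" yields a single-node walk.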

2. DeepWalk Implementation

2.1 RandomWalk

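```python
def random_walk(self, u):
    """Generate one random walk of at most self.l nodes starting from node u."""
    g = self.G
    walk = [u]
    while len(walk) < self.l:
        curr = walk[-1]
        v_curr = list(g.neighbors(curr))  # neighbors of the current node
        if len(v_curr) > 0:
            # DeepWalk: pick the next node uniformly at random among the neighbors
            walk.append(random.choice(v_curr))
        else:
            break  # dead end: stop the walk early

    return walk
```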

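Code analysis:

```python
walk.append(random.choice(v_curr))
```

At each step, one neighbor of the current node is picked at random as the next node. If the current node has no neighbors, the loop breaks, so a walk can be shorter than l.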

2.2 Skip-Gram

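Once all walks for all nodes have been generated, we train on them with gensim's Word2Vec:

```python
def learning_features(self):
    g = self.G
    walks = []
    nodes = list(g.nodes())
    # r passes; in each pass, start one walk from every node in shuffled order
    for t in range(self.r):
        random.shuffle(nodes)
        for node in nodes:
            walk = self.random_walk(node)
            walks.append(walk)
    # embedding: Word2Vec expects sentences of string tokens
    walks = [list(map(str, walk)) for walk in walks]
    # sg=1 selects Skip-Gram; window=self.k is the context size; min_count=0 keeps every node
    # (vector_size is the gensim 4.x name; gensim 3.x called it size)
    model = Word2Vec(sentences=walks, vector_size=self.d, window=self.k, min_count=0, sg=1, workers=3)
    f = model.wv  # KeyedVectors: node name (str) -> d-dimensional vector
    return f
```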
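With sg=1, Word2Vec trains Skip-Gram (sg=0 would be CBOW), and window=self.k controls how far the context extends on each side of a node. The sketch below enumerates the (center, context) pairs one walk contributes for a fixed window k. It is an illustration only: skipgram_pairs is a hypothetical helper, and gensim's internal training differs in detail (negative sampling, and by default a randomly shrunk effective window per position).

```python
def skipgram_pairs(walk, k):
    """(center, context) pairs a Skip-Gram model sees for one walk with a fixed window of size k."""
    pairs = []
    for i, center in enumerate(walk):
        lo, hi = max(0, i - k), min(len(walk), i + k + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, walk[j]))
    return pairs


# Word2Vec expects string tokens, which is why learning_features maps every node with str()
walk = [str(n) for n in [1, 2, 3, 4]]
print(skipgram_pairs(walk, k=1))
# [('1', '2'), ('2', '1'), ('2', '3'), ('3', '2'), ('3', '4'), ('4', '3')]
```

Nodes that co-occur within k steps on many walks end up with similar vectors. That is why structurally close nodes cluster together later.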

Usage:

```python
if __name__ == '__main__':
    d, r, l, k = 128, 10, 80, 10  # embedding dim, walks per node, walk length, window size
    G = nx.les_miserables_graph()
    dw = deep_walk(G, d, r, l, k)  # avoid shadowing the class name deep_walk
    model = dw.learning_features()
```
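The hyperparameters d, r, l, k = 128, 10, 80, 10 fix the size of the training corpus in advance. Here is a back-of-the-envelope sketch, using a hypothetical helper corpus_size. It rests on two assumptions:

- networkx's les_miserables_graph() has 77 nodes. Check this with G.number_of_nodes().
- Every walk reaches the full length l. This holds when no node is isolated.

d does not enter the calculation, since it only sets the embedding size.

```python
def corpus_size(n_nodes, r, l, k):
    """Walk/token/pair counts for DeepWalk, assuming every walk reaches full length l."""
    n_walks = r * n_nodes
    n_tokens = n_walks * l
    # Position i has min(i, k) context nodes on the left and min(l - 1 - i, k) on the right
    pairs_per_walk = sum(min(i, k) + min(l - 1 - i, k) for i in range(l))
    return n_walks, n_tokens, n_walks * pairs_per_walk


# r, l, k as in the article; 77 nodes assumed for les_miserables_graph()
print(corpus_size(77, r=10, l=80, k=10))  # (770, 61600, 1147300)
```

The pair count assumes a fixed full window on both sides. gensim by default samples a smaller effective window per position, so read it as an upper bound.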

3. Node Clustering Analysis

With vector representations for every node, we simply cluster the nodes with k-means:

```python
def k_means(self, m, K, n_iter=50):
    """
    :param m: DeepWalk training result (KeyedVectors, keyed by node name as str)
    :param K: number of clusters
    :param n_iter: number of iterations (50 in the original)
    :return: clustering result, dict cluster index -> list of nodes
    """
    nodes = list(self.G.nodes)
    # Pick K distinct nodes at random as the initial cluster centers
    idx = np.random.choice(len(nodes), K, replace=False)
    centers = [np.asarray(m[str(nodes[i])]) for i in idx]  # each center is a vector

    res = {i: [] for i in range(K)}
    for _ in range(n_iter):
        # clear the previous assignment
        for i in range(K):
            res[i].clear()
        # Assign every node to its nearest center; class labels are 0..K-1
        for node in nodes:
            dists = [self.get_dis(m[str(node)], center) for center in centers]
            res[dists.index(min(dists))].append(node)

        # Update each center to the mean of its members (keep the old center if the cluster is empty)
        for i in range(K):
            if res[i]:
                centers[i] = np.mean([m[str(node)] for node in res[i]], axis=0)

    return res
```
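The k_means method above is Lloyd's algorithm. It picks random initial centers, then alternates nearest-center assignment and mean updates for a fixed number of iterations.

Below is a self-contained sketch of the same two steps on toy 2-D points, using only the standard library. The function name kmeans and the data are hypothetical. It adds one thing the original does not have: the loop stops as soon as the assignments stop changing. That stop condition also shows whether a fixed 50 iterations are needed at all.

```python
import math
import random


def kmeans(points, K, max_iter=100, seed=0):
    """Lloyd's k-means: alternate nearest-center assignment and mean update until labels stop changing."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(points, K)]  # K distinct data points as initial centers
    labels = None
    for it in range(1, max_iter + 1):
        # Assignment step: index of the nearest center by Euclidean distance
        new_labels = [min(range(K), key=lambda c: math.dist(p, centers[c])) for p in points]
        if new_labels == labels:  # converged: this pass changed no assignment
            return labels, it
        labels = new_labels
        # Update step: move each center to the mean of its members (keep it if the cluster is empty)
        for c in range(K):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centers[c] = [sum(xs) / len(members) for xs in zip(*members)]
    return labels, max_iter


pts = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
labels, n_iter = kmeans(pts, K=2)
print(labels, n_iter)
```

On well-separated data like this, it converges in a handful of passes. In such cases, running 50 or 100 iterations from the same initial centers gives the same partition.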

Plotting:

```python
def plot(self, m, K):
    """
    :param m: node vector representations
    :param K: number of clusters
    :return: None
    """
    g = self.G
    pos = nx.spring_layout(g)  # layout of the original graph
    nodes = list(g.nodes)

    res = self.k_means(m, K)

    # One color per cluster; colors are reused cyclically if K > 6
    colors = ['#DCBB8A', '#98BBEF', 'navy', 'indigo', 'orange', 'blue']
    color_map = []
    for node in nodes:
        for i in range(len(res)):
            if node in res[i]:
                color_map.append(colors[i % len(colors)])
                break
    # draw
    nx.draw(g, node_color=color_map, pos=pos, with_labels=True, node_size=1000)
    plt.show()
```
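For every node, the coloring loop in plot scans the clusters' member lists, and it indexes a six-entry color list. A hedged alternative is to invert the clustering into a dict once and wrap around the palette. The helper name node_colors is hypothetical, and the returned list is meant to be passed as node_color to nx.draw.

```python
def node_colors(nodes, clusters, palette):
    """Color list aligned with `nodes`; palette colors are reused cyclically for large cluster ids."""
    cluster_of = {node: c for c, members in clusters.items() for node in members}  # node -> cluster id
    return [palette[cluster_of[node] % len(palette)] for node in nodes]


palette = ['#DCBB8A', '#98BBEF', 'navy', 'indigo', 'orange', 'blue']
clusters = {0: ["a", "c"], 1: ["b"], 6: ["d"]}  # hypothetical clustering; id 6 wraps to palette[0]
print(node_colors(["a", "b", "c", "d"], clusters, palette))
# ['#DCBB8A', '#98BBEF', '#DCBB8A', '#DCBB8A']
```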

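In the KDD 2016 paper node2vec: Scalable Feature Learning for Networks, the authors clustered the nodes of the Les Misérables character network into six clusters. Their result is as follows:

With DeepWalk (k-means run for 50 iterations), the result is:

With 100 iterations: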
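Each call to k_means starts from freshly drawn random centers. The 50- and 100-iteration figures therefore also differ because of initialization, not only because of the iteration count.

One hedged way to compare such runs numerically is the k-means objective: the within-cluster sum of squared distances, where lower means tighter clusters. The sketch below computes it for the dict format k_means returns. The helper name inertia and the toy vectors are hypothetical; with the real model, `vectors` would be the KeyedVectors `m`, indexed by node name.

```python
def inertia(clusters, vectors):
    """Within-cluster sum of squared distances for a dict cluster_id -> list of nodes."""
    total = 0.0
    for members in clusters.values():
        if not members:
            continue  # empty cluster contributes nothing
        dim = len(vectors[members[0]])
        centroid = [sum(vectors[n][j] for n in members) / len(members) for j in range(dim)]
        total += sum(sum((vectors[n][j] - centroid[j]) ** 2 for j in range(dim)) for n in members)
    return total


vectors = {"a": [0, 0], "b": [0, 2], "c": [10, 0]}
good = {0: ["a", "b"], 1: ["c"]}
bad = {0: ["a", "c"], 1: ["b"]}
print(inertia(good, vectors), inertia(bad, vectors))  # 2.0 50.0
```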

4. Complete Code

```python
# -*- coding: utf-8 -*-
"""
@Time : 2021/12/18 20:34
@Author :KI 
@File :deepwalk.py
@Motto:Hungry And Humble

"""
import random

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from gensim.models import Word2Vec


class deep_walk:
    def __init__(self, G, d, r, l, k):
        self.G = G
        self.d = d  # embedding dimension
        self.r = r  # number of walks per node
        self.l = l  # walk length
        self.k = k  # context window size

    def random_walk(self, u):
        g = self.G
        walk = [u]
        while len(walk) < self.l:
            curr = walk[-1]
            v_curr = list(g.neighbors(curr))
            if len(v_curr) > 0:
                walk.append(random.choice(v_curr))  # uniform choice among neighbors
            else:
                break  # dead end: stop early

        return walk

    def learning_features(self):
        g = self.G
        walks = []
        nodes = list(g.nodes())
        for t in range(self.r):
            random.shuffle(nodes)
            for node in nodes:
                walk = self.random_walk(node)
                walks.append(walk)
        # embedding: Word2Vec needs string tokens; sg=1 selects Skip-Gram
        walks = [list(map(str, walk)) for walk in walks]
        model = Word2Vec(sentences=walks, vector_size=self.d, window=self.k, min_count=0, sg=1, workers=3)
        f = model.wv
        return f

    def plot(self, m, K):
        """
        :param m: node vector representations
        :param K: number of clusters
        :return: None
        """
        g = self.G
        pos = nx.spring_layout(g)  # same layout for both figures
        nodes = list(g.nodes)
        colors = ['#DCBB8A', '#98BBEF', 'navy', 'indigo', 'orange', 'blue']

        # One figure with 50 k-means iterations, one with 100
        for n_iter in (50, 100):
            res = self.k_means(m, K, n_iter)
            color_map = []
            for node in nodes:
                for i in range(len(res)):
                    if node in res[i]:
                        color_map.append(colors[i % len(colors)])
                        break
            nx.draw(g, node_color=color_map, pos=pos, with_labels=False, node_size=2000)
            plt.show()

    def get_dis(self, x, y):
        # Euclidean distance between two vectors
        return np.sqrt(np.sum((np.asarray(x) - np.asarray(y)) ** 2))

    def k_means(self, m, K, n_iter=50):
        """
        :param m: node vector representations (keyed by node name as str)
        :param K: number of clusters
        :param n_iter: number of iterations
        :return: clustering result, dict cluster index -> list of nodes
        """
        nodes = list(self.G.nodes)
        # Pick K distinct nodes at random as the initial cluster centers
        idx = np.random.choice(len(nodes), K, replace=False)
        centers = [np.asarray(m[str(nodes[i])]) for i in idx]  # each center is a d-dim vector

        res = {i: [] for i in range(K)}
        for _ in range(n_iter):
            for i in range(K):
                res[i].clear()
            # Assign every node to its nearest center (labels 0..K-1)
            for node in nodes:
                dists = [self.get_dis(m[str(node)], c) for c in centers]
                res[dists.index(min(dists))].append(node)
            # Update each center to the mean of its members; keep the old center if the cluster is empty
            for i in range(K):
                if res[i]:
                    centers[i] = np.mean([m[str(node)] for node in res[i]], axis=0)

        return res


if __name__ == '__main__':
    d, r, l, k = 128, 10, 80, 10
    G = nx.les_miserables_graph()
    dw = deep_walk(G, d, r, l, k)
    model = dw.learning_features()
    dw.plot(model, 6)
```
Originally published on 2021-12-19 on the WeChat official account KI的算法杂记.
