首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >聚类模型--K 均值

聚类模型--K 均值

作者头像
黑泽君
发布2019-05-19 16:32:10
7530
发布2019-05-19 16:32:10
举报
文章被收录于专栏:黑泽君的专栏黑泽君的专栏

聚类模型--K 均值

0.引入依赖

import numpy as np
import matplotlib.pyplot as plt

# 这里直接 sklearn 里的数据集
from sklearn.datasets.samples_generator import make_blobs

1.数据的加载和预处理

x, y = make_blobs(n_samples=100, centers=6, random_state=1234, cluster_std=0.6)
# x # array([[-0.02708305,  5.0215929 ], ..., [-4.7583093 ,  5.85803377]])
# x.shape # (100, 2)
# y.shape # (100,)

plt.figure(figsize=(9, 9))
plt.scatter(x[:,0], x[:,1], c=y)
plt.show()

作图如下:

2.算法实现

# 引入 scipy 库中的距离函数,默认实现是欧式距离
from scipy.spatial.distance import cdist

class K_Means(object):
    # 初始化,参数 n_clusters(K)、max_iter(迭代次数)、centroids(初始质心)
    def __init__(self, n_clusters=6, max_iter=300, centroids=[]):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.centroids = np.array(centroids, dtype=np.float)

    # 定义训练模型方法,实现 K-means 聚类过程
    def fit(self, data):
        # 假如没有指定初始质心,就随机选取 data 中的点作为质心
        if (self.centroids.shape == (0,)):
            # 从 data 中随机生成 0 到 data 行数的 6 个整数,作为索引值
            self.centroids = data[np.random.randint(0, data.shape[0], self.n_clusters) ,:]

        # 开始迭代
        for i in range(self.max_iter):
            # 1.计算距离,得到的是一个 100x6 的矩阵,该矩阵的每一行含义是:一个样本点跟所有质心的距离
            distances = cdist(data, self.centroids)

            # 2.对距离按由近到远排序,选取最近的质心点的类别,作为当前点的分类
            c_index = np.argmin(distances, axis=1) # 得到 100x1 的矩阵

            # 3.对每一类数据进行均值计算,更新质心点的坐标
            for i in range(self.n_clusters): # 遍历每一类
                # 排除掉没有出现在 c_index 里的类别
                if i in c_index:
                    # 选择所有类别是 i 的点,取 data 里面坐标的均值,更新第 i 个质心
                    self.centroids[i] = np.mean(data[c_index==i], axis=0) # 得到一行数据,使用了布尔索引

    # 定义预测模型方法
    def predict(self, samples):
        # 跟上面一样,先计算距离,然后读取距离最近的那个质心的类别
        distances = cdist(samples, self.centroids)
        c_index = np.argmin(distances, axis=1) # 得到 100x1 的矩阵

        return c_index

# 测试
distances = np.array([[121, 221, 32, 43],
                      [121, 332, 166, 52],
                      [96, 411, 56, 158],
                      [45, 235, 542, 156],
                      [140, 54, 63, 255],
                     ], dtype=np.float)
c_index = np.argmin(distances, axis=1)
print(c_index)

x_new =x[0:5]
print(x_new)
print(c_index==2)
print(x_new[c_index==2])
np.mean(x_new[c_index==2], axis=0)

输出结果如下:

[2 3 2 0 1]
[[-0.02708305  5.0215929 ]
 [-5.49252256  6.27366991]
 [-5.37691608  1.51403209]
 [-5.37872006  2.16059225]
 [ 9.58333171  8.10916554]]
[ True False  True False False]
[[-0.02708305  5.0215929 ]
 [-5.37691608  1.51403209]]

array([-2.70199956,  3.26781249])

3.测试

# 定义一个绘制子图的函数
def plotKMeans(x, y, centroids, subplot, title):
    # 分配子图,121 表示 1 行 2 列的子图中的第一个
    plt.subplot(subplot)
    plt.scatter(x[:,0], x[:,1], c='r')
    # 画出质心点
    plt.scatter(centroids[:,0], centroids[:,1], c=np.array(range(6)), s=100)
    plt.title(title)

kmeans = K_Means(max_iter=300, centroids=np.array([[2, 1],[2, 2],[2, 3],[2, 4],[2, 5],[2, 6]]))

plt.figure(figsize=(18, 9))
plotKMeans(x, y, kmeans.centroids, 121, 'Initial State')

# 开始聚类
kmeans.fit(x)
plotKMeans(x, y, kmeans.centroids, 122, 'Final State')

# 开始预测
x_new = np.array([[0, 0],[10, 7]])
y_pred = kmeans.predict(x_new)

print(kmeans.centroids)
print(y_pred)

plt.scatter(x_new[:,0], x_new[:,1], c='black', s=100)

输出结果如下:

[[ 5.76444812 -4.67941789]
 [-2.89174024 -0.22808556]
 [-5.89115978  2.33887408]
 [-4.53406813  6.11523454]
 [-1.15698106  5.63230377]
 [ 9.20551979  7.56124841]]
[1 5]

<matplotlib.collections.PathCollection at 0x1543350c518>

作图如下:

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019-05-18 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 0.引入依赖
  • 1.数据的加载和预处理
  • 2.算法实现
  • 3.测试
相关产品与服务
腾讯云服务器利旧
云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档