前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用通俗易懂的语言讲解 k-means

用通俗易懂的语言讲解 k-means

作者头像
不可言诉的深渊
发布2019-07-26 17:41:08
5330
发布2019-07-26 17:41:08
举报
文章被收录于专栏:Python机器学习算法说书人

在机器学习中,有无监督学习、监督学习和半监督学习三种学习方式。在无监督学习中,有一个叫聚类的玩意,对于聚类,有很多算法,其中最简单的就是今天要讲的 k-means 了。

概述

k-means 的执行过程非常简单,主要分为以下几步:

  1. 选取 k 个中心点(通常是随机选取)。
  2. 计算每一个数据点和 k 个点的距离,如果该数据点到第 i 个中心点距离最短,那么这个数据点就被归到了第 i 类。
  3. 计算每一个聚类的质心,并将其设置为中心点。
  4. 重复 2、3,直到聚类不变或者达到迭代次数(简单来说就是收敛)。

下面我以 iris 数据集为例手动实现一下 k-means 算法的过程,为了方便把数据和基于数据的操作封装在一起,这里我使用面向对象的方法来实现。

初始化

iris 数据集总共有 5 个字段(5 列),150 条数据(150 行)。这 5 个字段中的每一个字段分别表示花萼长度、花萼宽度、花瓣长度、花瓣宽度、类别,类别总共有 3 种可能——Iris Setosa、Iris Versicolour 和 Iris Virginica。所以在该算法中我们应该选取 k = 3,知道这些我们先写一个框架应该不是什么难事。

代码语言:javascript
复制
class KMeans:
    def __init__(self):
        self.sepal_length_in_cm = []
        self.sepal_width_in_cm = []
        self.petal_length_in_cm = []
        self.petal_width_in_cm = []
        self.k = 3
        # noinspection PyUnusedLocal
        self.classes = [[]for i in range(self.k)]
        self.centers = []
        self.load()
        self.choose_centers()
        for i in range(100):
            self.classify()
            self.modify()
        for i in range(len(self.classes)):
            print(f'{i+1}:{self.centers[i]}')
            print(self.classes[i])

载入数据集

载入数据集很简单,一行一行读取数据集文件,然后每一行按照逗号分割,如果分割后的长度为 5,那么我们就把前 4 个放到对应的字段列表中,也就是上面的 self.sepal_length_in_cm,self.sepal_width_in_cm,self.petal_length_in_cm,self.petal_width_in_cm。

代码语言:javascript
复制
    def load(self):
        dataset = open('iris.data').readlines()
        for data in dataset:
            data = data.split(',')
            if len(data) == 5:
                self.sepal_length_in_cm.append(float(data[0]))
                self.sepal_width_in_cm.append(float(data[1]))
                self.petal_length_in_cm.append(float(data[2]))
                self.petal_width_in_cm.append(float(data[3]))

选择中心点

中心点的选取是随机选取,但需要注意的是范围,确保所选择的点在数据集每一个特征分量对应的范围中,不可以超出这个范围。那么在选择中心点之前,我们需要计算每一个特征分量的最大值和最小值,然后在每一个分量范围中随机选一个值,接着把选取到的值组成一个元组(也就是一个点),并把该元组添加到列表 self.centers 中,重复选择步骤 self.k 次。

代码语言:javascript
复制
    def choose_centers(self):
        min_sepal_length, max_sepal_length = min(self.sepal_length_in_cm), max(self.sepal_length_in_cm)
        min_sepal_width, max_sepal_width = min(self.sepal_width_in_cm), max(self.sepal_width_in_cm)
        min_petal_length, max_petal_length = min(self.petal_length_in_cm), max(self.petal_length_in_cm)
        min_petal_width, max_petal_width = min(self.petal_width_in_cm), max(self.petal_width_in_cm)
        for i in range(self.k):
            sepal_length = min_sepal_length+(max_sepal_length-min_sepal_length)*random()
            sepal_width = min_sepal_width+(max_sepal_width-min_sepal_width)*random()
            petal_length = min_petal_length+(max_petal_length-min_petal_length)*random()
            petal_width = min_petal_width+(max_petal_width-min_petal_width)*random()
            self.centers.append((sepal_length, sepal_width, petal_length, petal_width))

计算距离

计算距离很简单,这里使用 Δxi 的平方和来表示距离的平方,因为距离的平方越小,距离也就越小,所以没有必要开根号。

代码语言:javascript
复制
    @staticmethod
    def calculate_distance(p1, p2):
        return (p2[0]-p1[0])**2+(p2[1]-p1[1])**2+(p2[2]-p1[2])**2+(p2[3]-p1[3])**2

分类

计算每一个数据点和 k 个点的距离,如果该数据点到第 i 个中心点距离最短,那么这个数据点就被归到了第 i 类。

代码语言:javascript
复制
    def classify(self):
        for i in range(len(self.sepal_length_in_cm)):
            p = (self.sepal_length_in_cm[i], self.sepal_width_in_cm[i], self.petal_length_in_cm[i],
                 self.petal_width_in_cm[i])
            distances = []
            for center in self.centers:
                distances.append((self.calculate_distance(p, center), self.centers.index(center)))
            min_distance = distances[0][0]
            min_distance_center = distances[0][1]
            for j in range(len(distances)):
                if min_distance > distances[j][0]:
                    min_distance = distances[j][0]
                    min_distance_center = distances[j][1]
            self.classes[min_distance_center].append(p)

修正

当分类完成之后,我们需要修正中心点的位置。在修正中心点位置之前,我们来想一个问题:之前选择的中心点会不会没有数据被归到,也就是代表它的数据完全没有,换句话说就是该点是无效的?真的有可能,这样就只能让它重新选择中心点并归类,然后判断,一直循环,直到每个中心点都可以有数据代表它。

修正中心点很简单,我们只要计算每一个聚类的质心即可,质心的计算很简单,就是每一个分量的平均值组成的元组(点)。

代码语言:javascript
复制
    def modify(self):
        while[]in self.classes:
            self.centers = []
            self.choose_centers()
            self.classify()
        for i in range(len(self.classes)):
            x1, x2, x3, x4 = [], [], [], []
            for c in self.classes[i]:
                x1.append(c[0])
                x2.append(c[1])
                x3.append(c[2])
                x4.append(c[3])
            x1, x2, x3, x4 = sum(x1)/len(x1), sum(x2)/len(x2), sum(x3)/len(x3), sum(x4)/len(x4)
            self.centers[i] = (x1, x2, x3, x4)

下面直接给出该程序的完整源码。

代码语言:javascript
复制
from random import random


class KMeans:
    def __init__(self):
        self.sepal_length_in_cm = []
        self.sepal_width_in_cm = []
        self.petal_length_in_cm = []
        self.petal_width_in_cm = []
        self.k = 3
        # noinspection PyUnusedLocal
        self.classes = [[]for i in range(self.k)]
        self.centers = []
        self.load()
        self.choose_centers()
        for i in range(100):
            self.classify()
            self.modify()
        for i in range(len(self.classes)):
            print(f'{i+1}:{self.centers[i]}')
            print(self.classes[i])

    def load(self):
        dataset = open('iris.data').readlines()
        for data in dataset:
            data = data.split(',')
            if len(data) == 5:
                self.sepal_length_in_cm.append(float(data[0]))
                self.sepal_width_in_cm.append(float(data[1]))
                self.petal_length_in_cm.append(float(data[2]))
                self.petal_width_in_cm.append(float(data[3]))

    def choose_centers(self):
        min_sepal_length, max_sepal_length = min(self.sepal_length_in_cm), max(self.sepal_length_in_cm)
        min_sepal_width, max_sepal_width = min(self.sepal_width_in_cm), max(self.sepal_width_in_cm)
        min_petal_length, max_petal_length = min(self.petal_length_in_cm), max(self.petal_length_in_cm)
        min_petal_width, max_petal_width = min(self.petal_width_in_cm), max(self.petal_width_in_cm)
        for i in range(self.k):
            sepal_length = min_sepal_length+(max_sepal_length-min_sepal_length)*random()
            sepal_width = min_sepal_width+(max_sepal_width-min_sepal_width)*random()
            petal_length = min_petal_length+(max_petal_length-min_petal_length)*random()
            petal_width = min_petal_width+(max_petal_width-min_petal_width)*random()
            self.centers.append((sepal_length, sepal_width, petal_length, petal_width))

    @staticmethod
    def calculate_distance(p1, p2):
        return (p2[0]-p1[0])**2+(p2[1]-p1[1])**2+(p2[2]-p1[2])**2+(p2[3]-p1[3])**2

    def classify(self):
        for i in range(len(self.sepal_length_in_cm)):
            p = (self.sepal_length_in_cm[i], self.sepal_width_in_cm[i], self.petal_length_in_cm[i],
                 self.petal_width_in_cm[i])
            distances = []
            for center in self.centers:
                distances.append((self.calculate_distance(p, center), self.centers.index(center)))
            min_distance = distances[0][0]
            min_distance_center = distances[0][1]
            for j in range(len(distances)):
                if min_distance > distances[j][0]:
                    min_distance = distances[j][0]
                    min_distance_center = distances[j][1]
            self.classes[min_distance_center].append(p)

    def modify(self):
        while[]in self.classes:
            self.centers = []
            self.choose_centers()
            self.classify()
        for i in range(len(self.classes)):
            x1, x2, x3, x4 = [], [], [], []
            for c in self.classes[i]:
                x1.append(c[0])
                x2.append(c[1])
                x3.append(c[2])
                x4.append(c[3])
            x1, x2, x3, x4 = sum(x1)/len(x1), sum(x2)/len(x2), sum(x3)/len(x3), sum(x4)/len(x4)
            self.centers[i] = (x1, x2, x3, x4)


if __name__ == '__main__':
    KMeans()

运行结果如图所示。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-05-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python机器学习算法说书人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档