前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用通俗易懂的语言讲解 kNN

用通俗易懂的语言讲解 kNN

作者头像
不可言诉的深渊
发布2019-07-27 19:53:21
5290
发布2019-07-27 19:53:21
举报

最近有个同学问我 k-means 和 kNN 是不是差不多?其实差太多了,k-means 是在不知道类别的情况下进行分类的,而 kNN 是通过已经存在的已经分好类的数据集给新的数据集归类。上次讲了 k-means,这次就正好讲一下 kNN。

概述

kNN 算法的执行过程非常简单,主要步骤如下:

  1. 加载已知数据集和未知数据集。
  2. 未知数据集的每一个未知数据与已知数据集的每一个数据求距离,找到 k 个最小的距离(最小,第 2 小……第 k 小),这里要注意一下:是每一个未知数据都有 k 个最小值,不是所有的未知数据一共有 k 个最小值!
  3. 找到 k 个最小距离所对应的点大多数属于哪一类,就把该未知数据归到这个类中。

是不是很简单?下面依旧是以 iris 数据集为例手动实现一下该算法,为了把数据和基于数据的操作封装在一起,依旧使用面向对象来实现。

初始化

在这个算法中,我们需要存储已知数据集,未知数据集,最近的邻居,还需要设置 k 的值,初始化完成之后就是载入数据,找邻居,分类了。知道这些先写一个框架应该不难。

代码语言:javascript
复制
class KNN:
    def __init__(self):
        self.known_dataset = []
        self.unknown_dataset = []
        self.nearest_neighbours = []
        self.k = 6
        self.load()
        self.find_nearest_neighbours()
        self.classify()

载入数据集

载入数据集分为两个步骤:载入已知数据集和载入未知数据集。已知数据集在文件中,通过读取文件外加适当处理就能载入,未知数据集我选择在已知数据集所在的空间中随机生成。知道这样的一个过程,写出代码应该不难。

代码语言:javascript
复制
def load(self):
    dataset = open('iris.data').readlines()
    x1s, x2s, x3s, x4s = [], [], [], []
    for data in dataset:
        data = data.strip().split(',')
        if len(data) == 5:
            data = ((float(data[0]), float(data[1]), float(data[2]), float(data[3])), data[4])
            x1s.append(data[0][0])
            x2s.append(data[0][1])
            x3s.append(data[0][2])
            x4s.append(data[0][3])
            self.known_dataset.append(data)
    min_x1, max_x1 = min(x1s), max(x1s)
    min_x2, max_x2 = min(x2s), max(x2s)
    min_x3, max_x3 = min(x3s), max(x3s)
    min_x4, max_x4 = min(x4s), max(x4s)
    for i in range(10):
        x1 = min_x1+(max_x1-min_x1)*random()
        x2 = min_x2+(max_x2-min_x2)*random()
        x3 = min_x3+(max_x3-min_x3)*random()
        x4 = min_x4+(max_x4-min_x4)*random()
        self.unknown_dataset.append((x1, x2, x3, x4))

计算距离

计算距离很简单,就是 Δxi 的平方和开根号,我选择直接按使用距离的平方,毕竟距离越大距离的平方也越大。

代码语言:javascript
复制
@staticmethod
def calculate_distance(p1, p2):
    return sum([(p2[i]-p1[i])**2for i in range(len(p1))])

找最近的邻居

找最近的邻居是给每一个未知数据找到最近的邻居,我们只要给每一个未知数据找到 k 个最近的邻居,这一点也不难,使用选择排序的思想很简单(选择排序是指先找到最小的放在第一个,然后找到次小的放在第二个,以此类推)!

代码语言:javascript
复制
def find_nearest_neighbours(self):
    for unknown_data in self.unknown_dataset:
        points_and_distances = []
        for known_data in self.known_dataset:
            points_and_distances.append((known_data, self.calculate_distance(unknown_data, known_data[0])))
        for i in range(self.k):
            for j in range(i+1, len(points_and_distances)):
                if points_and_distances[i][1] > points_and_distances[j][1]:
                    points_and_distances[i], points_and_distances[j]\
                        = points_and_distances[j], points_and_distances[i]
        self.nearest_neighbours.append(points_and_distances[:self.k])

分类

分类很简单,k 个最近的邻居中大多数属于哪一类,那么这个未知数据就归到那一类中。这里有个问题,万一出现一样多的两个怎么办?这里可以选择进一步操作(比如计算距离平均值),找到最接近的,大家可以自己实现,我暂时选择出现这种情况就选择听天由命,随机到哪个就是哪个了。

代码语言:javascript
复制
def classify(self):
    for i in range(len(self.unknown_dataset)):
        d = {'Iris-setosa': 0, 'Iris-versicolor': 0, 'Iris-virginica': 0}
        for nearest_neighbour in self.nearest_neighbours[i]:
            d[nearest_neighbour[0][-1]] += 1
        max_v = 0
        max_k = None
        for k, v in d.items():
            if v > max_v:
                max_k = k
print(self.unknown_dataset[i], max_k)

下面直接给出完整源代码。

代码语言:javascript
复制
from random import random


class KNN:
    def __init__(self):
        self.known_dataset = []
        self.unknown_dataset = []
        self.nearest_neighbours = []
        self.k = 6
        self.load()
        self.find_nearest_neighbours()
        self.classify()

    def load(self):
        dataset = open('iris.data').readlines()
        x1s, x2s, x3s, x4s = [], [], [], []
        for data in dataset:
            data = data.strip().split(',')
            if len(data) == 5:
                data = ((float(data[0]), float(data[1]), float(data[2]), float(data[3])), data[4])
                x1s.append(data[0][0])
                x2s.append(data[0][1])
                x3s.append(data[0][2])
                x4s.append(data[0][3])
                self.known_dataset.append(data)
        min_x1, max_x1 = min(x1s), max(x1s)
        min_x2, max_x2 = min(x2s), max(x2s)
        min_x3, max_x3 = min(x3s), max(x3s)
        min_x4, max_x4 = min(x4s), max(x4s)
        for i in range(10):
            x1 = min_x1+(max_x1-min_x1)*random()
            x2 = min_x2+(max_x2-min_x2)*random()
            x3 = min_x3+(max_x3-min_x3)*random()
            x4 = min_x4+(max_x4-min_x4)*random()
            self.unknown_dataset.append((x1, x2, x3, x4))

    @staticmethod
    def calculate_distance(p1, p2):
        return sum([(p2[i]-p1[i])**2for i in range(len(p1))])

    def find_nearest_neighbours(self):
        for unknown_data in self.unknown_dataset:
            points_and_distances = []
            for known_data in self.known_dataset:
                points_and_distances.append((known_data, self.calculate_distance(unknown_data, known_data[0])))
            for i in range(self.k):
                for j in range(i+1, len(points_and_distances)):
                    if points_and_distances[i][1] > points_and_distances[j][1]:
                        points_and_distances[i], points_and_distances[j]\
                            = points_and_distances[j], points_and_distances[i]
            self.nearest_neighbours.append(points_and_distances[:self.k])

    def classify(self):
        for i in range(len(self.unknown_dataset)):
            d = {'Iris-setosa': 0, 'Iris-versicolor': 0, 'Iris-virginica': 0}
            for nearest_neighbour in self.nearest_neighbours[i]:
                d[nearest_neighbour[0][-1]] += 1
            max_v = 0
            max_k = None
            for k, v in d.items():
                if v > max_v:
                    max_k = k
print(self.unknown_dataset[i], max_k)


if __name__ == '__main__':
    KNN()

运行结果如图所示。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-06-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python机器学习算法说书人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档