前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >李航《统计学习方法》K近邻学习算法实现

李航《统计学习方法》K近邻学习算法实现

作者头像
Coggle数据科学
发布2019-09-12 17:50:47
4220
发布2019-09-12 17:50:47
举报
文章被收录于专栏:Coggle数据科学Coggle数据科学

K近邻学习算法的初始形式

数据集的选取

此次我们选取lris数据集进行实验

Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

数据集特征:多变量 记录数:150 领域:生活

属性特征:实数 属性数目:4 捐赠日期1988-07-01

相关应用:分类 缺失值?无 网站点击数:1507217

这里用到一个更方便的方法,使用scikit-learn自带lris数据集

代码语言:javascript
复制
>>> from sklearn.datasets import load_iris
>>> iris = load_iris()
>>> iris.target[[10, 25, 50]]
array([0, 0, 1])
>>> list(data.target_names)
['setosa', 'versicolor', 'virginica']

Python代码

代码语言:javascript
复制
import numpy as np
import random
import heapq

#对数据进行预处理,分成训练集和测试集
def Preprocessing():
    from sklearn import datasets

    testset = []       #测试集
    test_labels = []   #测试集标签
    trainset = []      #训练集
    train_labels = []  #训练集标签
    
    #加载scikit-learn中自带的Iris数据集
    iris = datasets.load_iris()
    datas = iris.data[:, 0:3] #获取所有数据,不含标签
    label = iris.target       #获取所有标签
    
    #获取0到149中任意不重复的10的数
    test_list = np.sort([random.randint(0,149) for _ in range(10)])
    #对150个数据进行划分
    for i in range(150):
        if i in test_list:
            testset.append(datas[i])
            test_labels.append(label[i])
        else:
            trainset.append(datas[i])
            train_labels.append(label[i])

    testset = np.mat(testset) #转换为矩阵类型
    test_labels = np.mat(test_labels).transpose() #进行矩阵的转置
    trainset = np.mat(trainset)
    train_labels = np.mat(train_labels).transpose()
    return testset, test_labels, trainset, train_labels

def Predict(testset, trainset, train_labels, n):
    k = n
    predict = []
    #遍历测试数据集中的测试用例坐标
    for  test_vec in testset:
        dist_list = []      #当前测试数据与训练数据的距离
        knn_list = []       #当前k个最邻近点
        
        #遍历训练数据集中的训练用例坐标及标签
        for i in range(len(train_labels)):
            label = train_labels[i]
            train_vec = trainset[i]
            
            dist = np.linalg.norm(train_vec - test_vec)   #计算两个坐标的欧式距离
            dist_list.append((dist, label))
        
        #通过python中堆结构获取k个最近邻
        knn_list = heapq.nsmallest(k, dist_list, key=lambda x: x[0])
        
        # 统计选票
        class_total = 3
        class_count = [0 for i in range(class_total)]
        for dist, label in knn_list:
            label = label.tolist()  #将矩阵类型转换为列表类型
            class_count[label[0][0]] += 1
        
        # 找出最大选票
        mmax = max(class_count)

        # 找出最大选票标签
        for i in range(class_total):
            if mmax == class_count[i]:
                predict.append(i)
                break
    
    return predict

def main():
    k = 10
    testset, test_labels, trainset, train_labels = Preprocessing()
    predict = Predict(testset, trainset, train_labels, k)
    test_labels = test_labels.tolist()

    error_count = 0
    for i in range(10):
        print("testset:",test_labels[i][0], "predict:", predict[i])
        if test_labels[i][0] != predict[i]:
            error_count += 1
    print("error_count:",error_count)
   
if __name__ == '__main__':
    main()

最后输出结果:

代码语言:javascript
复制
testset: 1 predict: 1
testset: 1 predict: 1
testset: 1 predict: 1
testset: 1 predict: 1
testset: 2 predict: 2
testset: 2 predict: 2
testset: 2 predict: 2
testset: 2 predict: 1
testset: 2 predict: 2
testset: 2 predict: 2
error_count: 1

也许有小伙伴要问了,这个k如何来确定。我们先在看张图

有两类不同的样本数据,分别用蓝色的小正方形和红色的小三角形表示,而图正中间的那个绿色的圆所标示的数据则是待分类的数据。也就是说,现在,我们不知道中间那个绿色的数据是从属于哪一类(蓝色小正方形or红色小三角形),下面,我们就要解决这个问题:给这个绿色的圆分类。 我们常说,物以类聚,人以群分,判别一个人是一个什么样品质特征的人,常常可以从他/她身边的朋友入手,所谓观其友,而识其人。我们不是要判别上图中那个绿色的圆是属于哪一类数据么,所以就从它的邻居下手。但一次性看多少个邻居呢(也就是如何确定k)?从上图中,你还能看到:

  • 如果K=3,绿色圆点的最近的3个邻居是2个红色小三角形和1个蓝色小正方形,少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于红色的三角形一类。
  • 如果K=5,绿色圆点的最近的5个邻居是2个红色三角形和3个蓝色的正方形,还是少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于蓝色的正方形一类。

于此我们看到,不同的k取值可能得到不同的分类,而在应用中,k值一般取一个比较小的值,通常采用交叉验证法来选取最优的k值。July_的博客

写在最后:

KNN并不需要训练,但需要遍历整个训练集,所以预测比较慢。书中提到使用KD树进行优化,来提高k近邻搜索的效率。

在后面的文章中将会对KD树展开实现。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • K近邻学习算法的初始形式
  • 数据集的选取
  • Python代码
  • 最后输出结果:
  • 写在最后:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档