Python实现:KNN分类算法

1、KNN分类算法

KNN分类算法(K-Nearest-Neighbors Classification),又叫K近邻算法,是一个概念极其简单,而分类效果又很优秀的分类算法。

他的核心思想就是,要确定测试样本属于哪一类,就寻找所有训练样本中与该测试样本“距离”最近的前K个样本,然后看这K个样本大部分属于哪一类,那么就认为这个测试样本也属于哪一类。简单的说就是让最相似的K个样本来投票决定。

这里所说的距离,一般最常用的就是多维空间的欧式距离。这里的维度指特征维度,即样本有几个特征就属于几维。

KNN示意图如下所示。

上图中要确定测试样本绿色属于蓝色还是红色。

显然,当K=3时,将以1:2的投票结果分类于红色;而K=5时,将以3:2的投票结果分类于蓝色。

KNN算法简单有效,但没有优化的暴力法效率容易达到瓶颈。如样本个数为N,特征维度为D的时候,该算法时间复杂度呈O(DN)增长。

所以通常KNN的实现会把训练数据构建成K-D Tree(K-dimensional tree),构建过程很快,甚至不用计算D维欧氏距离,而搜索速度高达O(D*log(N))。

不过当D维度过高,会产生所谓的”维度灾难“,最终效率会降低到与暴力法一样。

因此通常D>20以后,最好使用更高效率的Ball-Tree,其时间复杂度为O(D*log(N))。

人们经过长期的实践发现KNN算法虽然简单,但能处理大规模的数据分类,尤其适用于样本分类边界不规则的情况。最重要的是该算法是很多高级机器学习算法的基础。

当然,KNN算法也存在一切问题。比如如果训练数据大部分都属于某一类,投票算法就有很大问题了。这时候就需要考虑设计每个投票者票的权重了。

2、测试数据

测试数据的格式仍然和前面使用的身高体重数据一致。不过数据增加了一些:

[plain] view plaincopy

  1. 1.5 40 thin
  2. 1.5 50 fat
  3. 1.5 60 fat
  4. 1.6 40 thin
  5. 1.6 50 thin
  6. 1.6 60 fat
  7. 1.6 70 fat
  8. 1.7 50 thin
  9. 1.7 60 thin
  10. 1.7 70 fat
  11. 1.7 80 fat
  12. 1.8 60 thin
  13. 1.8 70 thin
  14. 1.8 80 fat
  15. 1.8 90 fat
  16. 1.9 80 thin
  17. 1.9 90 fat

3、Python代码

scikit-learn提供了优秀的KNN算法支持。使用Python代码如下:

[python] view plaincopy

  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. from sklearn import neighbors
  4. from sklearn.metrics import precision_recall_curve
  5. from sklearn.metrics import classification_report
  6. from sklearn.cross_validation import train_test_split
  7. import matplotlib.pyplot as plt
  8. ''''' 数据读入 '''
  9. data = []
  10. labels = []
  11. with open("data\\1.txt") as ifile:
  12. for line in ifile:
  13. tokens = line.strip().split(' ')
  14. data.append([float(tk) for tk in tokens[:-1]])
  15. labels.append(tokens[-1])
  16. x = np.array(data)
  17. labels = np.array(labels)
  18. y = np.zeros(labels.shape)
  19. ''''' 标签转换为0/1 '''
  20. y[labels=='fat']=1
  21. ''''' 拆分训练数据与测试数据 '''
  22. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
  23. ''''' 创建网格以方便绘制 '''
  24. h = .01
  25. x_min, x_max = x[:, 0].min() - 0.1, x[:, 0].max() + 0.1
  26. y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
  27. xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
  28. np.arange(y_min, y_max, h))
  29. ''''' 训练KNN分类器 '''
  30. clf = neighbors.KNeighborsClassifier(algorithm='kd_tree')
  31. clf.fit(x_train, y_train)
  32. '''''测试结果的打印'''
  33. answer = clf.predict(x)
  34. print(x)
  35. print(answer)
  36. print(y)
  37. print(np.mean( answer == y))
  38. '''''准确率与召回率'''
  39. precision, recall, thresholds = precision_recall_curve(y_train, clf.predict(x_train))
  40. answer = clf.predict_proba(x)[:,1]
  41. print(classification_report(y, answer, target_names = ['thin', 'fat']))
  42. ''''' 将整个测试空间的分类结果用不同颜色区分开'''
  43. answer = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:,1]
  44. z = answer.reshape(xx.shape)
  45. plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
  46. ''''' 绘制训练样本 '''
  47. plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=plt.cm.Paired)
  48. plt.xlabel(u'身高')
  49. plt.ylabel(u'体重')
  50. plt.show()

4、结果分析

其输出结果如下:

[ 0. 0. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.] [ 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.] 准确率=0.94, score=0.94 precision recall f1-score support thin 0.89 1.00 0.94 8 fat 1.00 0.89 0.94 9 avg / total 0.95 0.94 0.94 17

KNN分类器在众多分类算法中属于最简单的之一,需要注意的地方不多。有这几点要说明: 1、KNeighborsClassifier可以设置3种算法:‘brute’,‘kd_tree’,‘ball_tree’。如果不知道用哪个好,设置‘auto’让KNeighborsClassifier自己根据输入去决定。

2、注意统计准确率时,分类器的score返回的是计算正确的比例,而不是R2。R2一般应用于回归问题。

3、本例先根据样本中身高体重的最大最小值,生成了一个密集网格(步长h=0.01),然后将网格中的每一个点都当成测试样本去测试,最后使用contourf函数,使用不同的颜色标注出了胖、廋两类。

容易看到,本例的分类边界,属于相对复杂,但却又与距离呈现明显规则的锯齿形。

这种边界线性函数是难以处理的。而KNN算法处理此类边界问题具有天生的优势。我们在后续的系列中会看到,这个数据集达到准确率=0.94算是很优秀的结果了。

原文发布于微信公众号 - 大数据挖掘DT数据分析(datadw)

原文发表时间:2015-11-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Bingo的深度学习杂货店

《机器学习实战》总结篇

前些天完成了《机器学习实战》这本书的学习,也利用 Python3 实现了各个章节的代码,对传统的机器学习方法有了更进一步的了解,这里做一个总结。 代码传送门: ...

3544
来自专栏机器之心

业界 | 苹果发文:全局语义信息能否改进神经语言模型?

在 iPhone 上输入文本、发掘用户可能感兴趣的新闻、查明用户遇到问题的答案,以及其他语言相关的任务都取决于稳健的自然语言处理(NLP)模型。词嵌入是一类 N...

792
来自专栏技术与生活

机器学习之线性回归(1)

在开始理解回归分析之前,先有个二手房房价的例子,假设这里的房价只和面积,居室情况两个因素有关,那么我们就有下列的数据形式 面积 居室 房价[万] 80 2 10...

1223
来自专栏机器之心

教程 | 如何使用深度学习为照片自动生成文本描述?

42911
来自专栏用户2442861的专栏

循环神经网络教程第一部分-RNN简介

作者:徐志强 链接:https://zhuanlan.zhihu.com/p/22266022 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业...

701
来自专栏SIGAI学习与实践平台

文本表示简介

文本分类是自然语言处理中研究最为广泛的任务之一,通过构建模型实现对文本内容进行自动分类,有很多应用场景,比如新闻文章主题分类,产品评论情感分类,检索中用户查询的...

902
来自专栏机器之心

深度 | 可视化线性修正网络:看Fisher-Rao范数与泛化之间的关系

33711
来自专栏机器之心

盘点 | 对比图像分类五大方法:KNN、SVM、BPNN、CNN和迁移学习

选自Medium 机器之心编译 参与:蒋思源、黄小天、吴攀 图像分类是人工智能领域的基本研究主题之一,研究者也已经开发了大量用于图像分类的算法。近日,Shiyu...

5368
来自专栏人工智能头条

用反卷积(Deconvnet)可视化和理解卷积网络

1463
来自专栏人工智能头条

深度卷积对抗生成网络(DCGAN)

1773

扫码关注云+社区