Python实现:KNN分类算法

1、KNN分类算法

KNN分类算法(K-Nearest-Neighbors Classification),又叫K近邻算法,是一个概念极其简单,而分类效果又很优秀的分类算法。

他的核心思想就是,要确定测试样本属于哪一类,就寻找所有训练样本中与该测试样本“距离”最近的前K个样本,然后看这K个样本大部分属于哪一类,那么就认为这个测试样本也属于哪一类。简单的说就是让最相似的K个样本来投票决定。

这里所说的距离,一般最常用的就是多维空间的欧式距离。这里的维度指特征维度,即样本有几个特征就属于几维。

KNN示意图如下所示。

上图中要确定测试样本绿色属于蓝色还是红色。

显然,当K=3时,将以1:2的投票结果分类于红色;而K=5时,将以3:2的投票结果分类于蓝色。

KNN算法简单有效,但没有优化的暴力法效率容易达到瓶颈。如样本个数为N,特征维度为D的时候,该算法时间复杂度呈O(DN)增长。

所以通常KNN的实现会把训练数据构建成K-D Tree(K-dimensional tree),构建过程很快,甚至不用计算D维欧氏距离,而搜索速度高达O(D*log(N))。

不过当D维度过高,会产生所谓的”维度灾难“,最终效率会降低到与暴力法一样。

因此通常D>20以后,最好使用更高效率的Ball-Tree,其时间复杂度为O(D*log(N))。

人们经过长期的实践发现KNN算法虽然简单,但能处理大规模的数据分类,尤其适用于样本分类边界不规则的情况。最重要的是该算法是很多高级机器学习算法的基础。

当然,KNN算法也存在一切问题。比如如果训练数据大部分都属于某一类,投票算法就有很大问题了。这时候就需要考虑设计每个投票者票的权重了。

2、测试数据

测试数据的格式仍然和前面使用的身高体重数据一致。不过数据增加了一些:

[plain] view plaincopy

  1. 1.5 40 thin
  2. 1.5 50 fat
  3. 1.5 60 fat
  4. 1.6 40 thin
  5. 1.6 50 thin
  6. 1.6 60 fat
  7. 1.6 70 fat
  8. 1.7 50 thin
  9. 1.7 60 thin
  10. 1.7 70 fat
  11. 1.7 80 fat
  12. 1.8 60 thin
  13. 1.8 70 thin
  14. 1.8 80 fat
  15. 1.8 90 fat
  16. 1.9 80 thin
  17. 1.9 90 fat

3、Python代码

scikit-learn提供了优秀的KNN算法支持。使用Python代码如下:

[python] view plaincopy

  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. from sklearn import neighbors
  4. from sklearn.metrics import precision_recall_curve
  5. from sklearn.metrics import classification_report
  6. from sklearn.cross_validation import train_test_split
  7. import matplotlib.pyplot as plt
  8. ''''' 数据读入 '''
  9. data = []
  10. labels = []
  11. with open("data\\1.txt") as ifile:
  12. for line in ifile:
  13. tokens = line.strip().split(' ')
  14. data.append([float(tk) for tk in tokens[:-1]])
  15. labels.append(tokens[-1])
  16. x = np.array(data)
  17. labels = np.array(labels)
  18. y = np.zeros(labels.shape)
  19. ''''' 标签转换为0/1 '''
  20. y[labels=='fat']=1
  21. ''''' 拆分训练数据与测试数据 '''
  22. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
  23. ''''' 创建网格以方便绘制 '''
  24. h = .01
  25. x_min, x_max = x[:, 0].min() - 0.1, x[:, 0].max() + 0.1
  26. y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
  27. xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
  28. np.arange(y_min, y_max, h))
  29. ''''' 训练KNN分类器 '''
  30. clf = neighbors.KNeighborsClassifier(algorithm='kd_tree')
  31. clf.fit(x_train, y_train)
  32. '''''测试结果的打印'''
  33. answer = clf.predict(x)
  34. print(x)
  35. print(answer)
  36. print(y)
  37. print(np.mean( answer == y))
  38. '''''准确率与召回率'''
  39. precision, recall, thresholds = precision_recall_curve(y_train, clf.predict(x_train))
  40. answer = clf.predict_proba(x)[:,1]
  41. print(classification_report(y, answer, target_names = ['thin', 'fat']))
  42. ''''' 将整个测试空间的分类结果用不同颜色区分开'''
  43. answer = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:,1]
  44. z = answer.reshape(xx.shape)
  45. plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
  46. ''''' 绘制训练样本 '''
  47. plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=plt.cm.Paired)
  48. plt.xlabel(u'身高')
  49. plt.ylabel(u'体重')
  50. plt.show()

4、结果分析

其输出结果如下:

[ 0. 0. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.] [ 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.] 准确率=0.94, score=0.94 precision recall f1-score support thin 0.89 1.00 0.94 8 fat 1.00 0.89 0.94 9 avg / total 0.95 0.94 0.94 17

KNN分类器在众多分类算法中属于最简单的之一,需要注意的地方不多。有这几点要说明: 1、KNeighborsClassifier可以设置3种算法:‘brute’,‘kd_tree’,‘ball_tree’。如果不知道用哪个好,设置‘auto’让KNeighborsClassifier自己根据输入去决定。

2、注意统计准确率时,分类器的score返回的是计算正确的比例,而不是R2。R2一般应用于回归问题。

3、本例先根据样本中身高体重的最大最小值,生成了一个密集网格(步长h=0.01),然后将网格中的每一个点都当成测试样本去测试,最后使用contourf函数,使用不同的颜色标注出了胖、廋两类。

容易看到,本例的分类边界,属于相对复杂,但却又与距离呈现明显规则的锯齿形。

这种边界线性函数是难以处理的。而KNN算法处理此类边界问题具有天生的优势。我们在后续的系列中会看到,这个数据集达到准确率=0.94算是很优秀的结果了。

原文发布于微信公众号 - 大数据挖掘DT数据分析(datadw)

原文发表时间:2015-11-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

基于 Keras 对深度学习模型进行微调的全面指南 Part 2

本部分属该两部系列中的第二部分,该系列涵盖了基于 Keras 对深度学习模型的微调。第一部分阐述微调背后的动机和原理,并简要介绍常用的做法和技巧。本部分将详细地...

15330
来自专栏ATYUN订阅号

Python中的白噪声时间训练

白噪声是时间序列预测中的一个重要概念。如果一个时间序列是白噪声,它是一个随机数序列,不能预测。如果预测误差不是白噪声,它暗示了预测模型仍有改进空间。 在本教程中...

89460
来自专栏LhWorld哥陪你聊算法

【神经网络篇】--基于数据集cifa10的经典模型实例

最终,在cifar-10数据集上,通过一个短时间小迭代的训练,可以达到大致73%的准确率,持续增加max_steps,可以期望准确率逐渐增加 如果max_ste...

27710
来自专栏机器学习算法与Python学习

RBF神经网络及Python实现(附源码)

作者:ACdreamers http://blog.csdn.net/acdreamers/article/details/46327761 RBF网络能够逼近...

1.2K60
来自专栏机器之心

学界 | 取代学习率衰减的新方法:谷歌大脑提出增加Batch Size

345110
来自专栏专知

【最新TensorFlow1.4.0教程03】利用Eager Execution构建和训练卷积神经网络(CNN)

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视...

42050
来自专栏Petrichor的专栏

深度学习: Regularization (正则化)

正则化,regularization,也即 约束 。是防止过拟合的诸多手段之一,很常用。

1.3K40
来自专栏Petrichor的专栏

思考: 神经网络层 的 功能识别

卷积操作 即为 一系列 滤波操作 之和 。 有关卷积的概念,具体可参加我的这篇文章:深度学习: convolution (卷积) 。

11750
来自专栏一心无二用,本人只专注于基础图像算法的实现与优化。

优化的对比度增强算法用于有雾图像的清晰化处理(算法效果是我目前看到最为稳定的,且对天空具有天然的免疫力,极力推荐有需要的朋友研究)。

在未谈及具体的算法流程前,先贴几幅用该算法处理的效果。 ? ? ? ?   不知道各位对这个算法的效果第一印象如何。       这个算法的原理来自于文章...

28380
来自专栏人工智能头条

深度学习目标检测指南:如何过滤不感兴趣的分类及添加新分类?

19730

扫码关注云+社区

领取腾讯云代金券