Python实现:KNN分类算法

1、KNN分类算法

KNN分类算法(K-Nearest-Neighbors Classification),又叫K近邻算法,是一个概念极其简单,而分类效果又很优秀的分类算法。

他的核心思想就是,要确定测试样本属于哪一类,就寻找所有训练样本中与该测试样本“距离”最近的前K个样本,然后看这K个样本大部分属于哪一类,那么就认为这个测试样本也属于哪一类。简单的说就是让最相似的K个样本来投票决定。

这里所说的距离,一般最常用的就是多维空间的欧式距离。这里的维度指特征维度,即样本有几个特征就属于几维。

KNN示意图如下所示。

上图中要确定测试样本绿色属于蓝色还是红色。

显然,当K=3时,将以1:2的投票结果分类于红色;而K=5时,将以3:2的投票结果分类于蓝色。

KNN算法简单有效,但没有优化的暴力法效率容易达到瓶颈。如样本个数为N,特征维度为D的时候,该算法时间复杂度呈O(DN)增长。

所以通常KNN的实现会把训练数据构建成K-D Tree(K-dimensional tree),构建过程很快,甚至不用计算D维欧氏距离,而搜索速度高达O(D*log(N))。

不过当D维度过高,会产生所谓的”维度灾难“,最终效率会降低到与暴力法一样。

因此通常D>20以后,最好使用更高效率的Ball-Tree,其时间复杂度为O(D*log(N))。

人们经过长期的实践发现KNN算法虽然简单,但能处理大规模的数据分类,尤其适用于样本分类边界不规则的情况。最重要的是该算法是很多高级机器学习算法的基础。

当然,KNN算法也存在一切问题。比如如果训练数据大部分都属于某一类,投票算法就有很大问题了。这时候就需要考虑设计每个投票者票的权重了。

2、测试数据

测试数据的格式仍然和前面使用的身高体重数据一致。不过数据增加了一些:

[plain] view plaincopy

  1. 1.5 40 thin
  2. 1.5 50 fat
  3. 1.5 60 fat
  4. 1.6 40 thin
  5. 1.6 50 thin
  6. 1.6 60 fat
  7. 1.6 70 fat
  8. 1.7 50 thin
  9. 1.7 60 thin
  10. 1.7 70 fat
  11. 1.7 80 fat
  12. 1.8 60 thin
  13. 1.8 70 thin
  14. 1.8 80 fat
  15. 1.8 90 fat
  16. 1.9 80 thin
  17. 1.9 90 fat

3、Python代码

scikit-learn提供了优秀的KNN算法支持。使用Python代码如下:

[python] view plaincopy

  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. from sklearn import neighbors
  4. from sklearn.metrics import precision_recall_curve
  5. from sklearn.metrics import classification_report
  6. from sklearn.cross_validation import train_test_split
  7. import matplotlib.pyplot as plt
  8. ''''' 数据读入 '''
  9. data = []
  10. labels = []
  11. with open("data\\1.txt") as ifile:
  12. for line in ifile:
  13. tokens = line.strip().split(' ')
  14. data.append([float(tk) for tk in tokens[:-1]])
  15. labels.append(tokens[-1])
  16. x = np.array(data)
  17. labels = np.array(labels)
  18. y = np.zeros(labels.shape)
  19. ''''' 标签转换为0/1 '''
  20. y[labels=='fat']=1
  21. ''''' 拆分训练数据与测试数据 '''
  22. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
  23. ''''' 创建网格以方便绘制 '''
  24. h = .01
  25. x_min, x_max = x[:, 0].min() - 0.1, x[:, 0].max() + 0.1
  26. y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
  27. xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
  28. np.arange(y_min, y_max, h))
  29. ''''' 训练KNN分类器 '''
  30. clf = neighbors.KNeighborsClassifier(algorithm='kd_tree')
  31. clf.fit(x_train, y_train)
  32. '''''测试结果的打印'''
  33. answer = clf.predict(x)
  34. print(x)
  35. print(answer)
  36. print(y)
  37. print(np.mean( answer == y))
  38. '''''准确率与召回率'''
  39. precision, recall, thresholds = precision_recall_curve(y_train, clf.predict(x_train))
  40. answer = clf.predict_proba(x)[:,1]
  41. print(classification_report(y, answer, target_names = ['thin', 'fat']))
  42. ''''' 将整个测试空间的分类结果用不同颜色区分开'''
  43. answer = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:,1]
  44. z = answer.reshape(xx.shape)
  45. plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
  46. ''''' 绘制训练样本 '''
  47. plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=plt.cm.Paired)
  48. plt.xlabel(u'身高')
  49. plt.ylabel(u'体重')
  50. plt.show()

4、结果分析

其输出结果如下:

[ 0. 0. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.] [ 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.] 准确率=0.94, score=0.94 precision recall f1-score support thin 0.89 1.00 0.94 8 fat 1.00 0.89 0.94 9 avg / total 0.95 0.94 0.94 17

KNN分类器在众多分类算法中属于最简单的之一,需要注意的地方不多。有这几点要说明: 1、KNeighborsClassifier可以设置3种算法:‘brute’,‘kd_tree’,‘ball_tree’。如果不知道用哪个好,设置‘auto’让KNeighborsClassifier自己根据输入去决定。

2、注意统计准确率时,分类器的score返回的是计算正确的比例,而不是R2。R2一般应用于回归问题。

3、本例先根据样本中身高体重的最大最小值,生成了一个密集网格(步长h=0.01),然后将网格中的每一个点都当成测试样本去测试,最后使用contourf函数,使用不同的颜色标注出了胖、廋两类。

容易看到,本例的分类边界,属于相对复杂,但却又与距离呈现明显规则的锯齿形。

这种边界线性函数是难以处理的。而KNN算法处理此类边界问题具有天生的优势。我们在后续的系列中会看到,这个数据集达到准确率=0.94算是很优秀的结果了。

原文发布于微信公众号 - 大数据挖掘DT数据分析(datadw)

原文发表时间:2015-11-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能LeadAI

零基础入门深度学习 | 第六章:长短时记忆网络(LSTM)

无论即将到来的是大数据时代还是人工智能时代,亦或是传统行业使用人工智能在云上处理大数据的时代,作为一个有理想有追求的程序员,不懂深度学习(Deep Learni...

5235
来自专栏机器之心

入门 | 一文简述循环神经网络

什么是循环神经网络(RNN)?它们如何运行?可以用在哪里呢?本文试图回答上述这些问题,还展示了一个 RNN 实现 demo,你可以根据自己的需要进行扩展。

953
来自专栏SIGAI学习与实践平台

理解牛顿法

牛顿法是数值优化算法中的大家族,她和她的改进型在很多实际问题中得到了应用。在机器学习中,牛顿法是和梯度下降法地位相当的的主要优化算法。在本文中,SIGAI将为大...

1032
来自专栏机器之心

CVPR 2018 | Poster论文:处理多种退化类型的卷积超分辨率

1445
来自专栏人工智能LeadAI

卷积神经网络 | 深度学习笔记1

一、初识卷积网络结构 01 概览 前面已经做过手写体识别的任务。之前我们只是使用了两层的全连接层,或者根据喜好多加几层。我更喜欢把这样的结构叫做多层感知机,即使...

2787
来自专栏数据分析

[机器学习]-[数据预处理]-中心化 缩放 KNN(二)

上次我们使用精度评估得到的成绩是 61%,成绩并不理想,再使 recall 和 f1 看下成绩如何? 首先我们先了解一下 召回率和 f1。 真实结果 ...

3266
来自专栏CDA数据分析师

入门 | 一文简述循环神经网络

什么是循环神经网络(RNN)?它们如何运行?可以用在哪里呢?本文试图回答上述这些问题,还展示了一个 RNN 实现 demo,你可以根据自己的需要进行扩展。

813
来自专栏ACM算法日常

第一篇:《机器学习之入门初探》

这是一篇帮助你了解和理解机器学习、神经网络、深度学习相关概念的文章,如果你对智能领域感兴趣并且想要未来投身AI方向的话,希望你可以耐下心来理解这些概念,以及我...

751
来自专栏机器学习算法与理论

《白话深度学习与Tensorflow》学习笔记(4)Deep Residual Networks

深度残差网络:主要应用于计算机视觉——图像分类、语义分割(semantic segmentation)、目标检测(object detection),其主要是使...

3886
来自专栏ATYUN订阅号

谷歌大脑团队新尝试 用奇异向量典型相关分析解释深度神经网络

深度神经网络(DNNs)在视觉、语言理解和语音识别等领域取得了前所未有的进展。但这些成功也带来了新的挑战。与许多以前的机器学习方法不同的是,深度神经网络可以很容...

3488

扫描关注云+社区