K近邻算法的Python实现

作者:黄耀鹏人工智能爱好者,一名数据科学研究者及科普写作者

知乎专栏:https://zhuanlan.zhihu.com/data-science-meditation

作为『十大机器学习算法』之一的K-近邻(K-Nearest Neighbors)算法是思想简单、易于理解的一种分类和回归算法。今天,我们来一起学习KNN算法的基本原理,并用Python实现该算法,最后,通过一个案例阐述其应用价值。

KNN算法的直观理解

(添加一个直观的图)

它基于这样的简单假设:彼此靠近的点更有可能属于同一个类别。用大俗话来说就是『臭味相投』,或者说『近朱者赤,近墨者黑』。

它并未试图建立一个显示的预测模型,而是直接通过预测点的临近训练集点来确定其所属类别。

K近邻算法的实现主要基于三大基本要素:

K的选择;

距离度量方法的确定;

分类决策规则。

下面,即围绕这三大基本要素,探究它的分类实现原理。

KNN算法的原理

算法步骤

K近邻算法的实施步骤如下:

根据给定的距离度量,在训练集TT中寻找出与xx最近邻的kk个点,涵盖这kk个点的xx的邻域记作Nk(x)Nk(x);

在Nk(x)Nk(x)中根据分类决策规则决定样本的所属类别yy:

y=arg maxcj∑xi∈Nk(x)I(yi=cj),i=1,2,⋯,N;j=1,2,⋯,K.y=arg maxcj∑xi∈Nk(x)I(yi=cj),i=1,2,⋯,N;j=1,2,⋯,K.

K的选择

K近邻算法对K的选择非常敏感。K值越小意味着模型复杂度越高,从而容易产生过拟合;K值越大则意味着整体的模型变得简单,学习的近似近似误差会增大。

在实际的应用中,一般采用一个比较小的K值。并采用交叉验证的方法,选取一个最优的K值。

距离度量

距离度量一般采用欧式距离。也可以根据需要采用LpLp距离或明氏距离。

分类决策规则

K近邻算法中的分类决策多采用多数表决的方法进行。它等价于寻求经验风险最小化。

但这个规则存在一个潜在的问题:有可能多个类别的投票数同为最高。这个时候,究竟应该判为哪一个类别?

可以通过以下几个途径解决该问题:

从投票数相同的最高类别中随机地选择一个;

通过距离来进一步给票数加权;

减少K的个数,直到找到一个唯一的最高票数标签。

KNN算法的优缺点

优点

精度高

对异常值不敏感

没有对数据的分布假设

缺点

计算复杂度高

在高维情况下,会遇到『维数诅咒』的问题

.

KNN算法的算法实现

importosos.chdir('D:\\my_python_workfile\\Project\\Writting')os.getcwd()

'D:\\my_python_workfile\\Project\\Writting'

from__future__importdivisionfromcollectionsimportCounter#from linear_algebra import distance#from statistics import meanimportmath,randomimportmatplotlib.pyplotasplt# 定义投票函数defraw_majority_vote(labels):votes=Counter(labels)winner,_=votes.most_common(1)[]returnwinner

以上的投票函数存在潜在的问题:有可能多个类别的投票数同为最高。

下面的函数则实现了解决方案中的第三种分类决策方法。

#defmajority_vote(labels):"""assumes that labels are ordered from nearest to farthest """vote_counts=Counter(labels)winner,winner_count=vote_counts.most_common(1)[]num_winners=len([countforcountinvote_counts.values()ifcount==winner_count])ifnum_winners==1:returnwinnerelse:returnmajority_vote(labels[:-1])# try again wthout the farthest

# define distance functionimportmath#### 减法定义defvector_substract(v,w):"""substracts coresponding elements"""return[v_i-w_iforv_i,w_iinzip(v,w)]defsquared_distance(v,w):""""""returnsum_of_squares(vector_substract(v,w))defdistance(v,w):returnmath.sqrt(squared_distance(v,w))########################################### define sum_of_squares### 向量的点乘defdot(v,w):returnsum(v_i*w_iforv_i,w_iinzip(v,w))### 向量的平房和defsum_of_squares(v):"""v_1*v_1+v_2*v_2+...+v_n*v_n"""returndot(v,v)

# classifierdefknn_classify(k,labeled_points,new_point):"""each labeled point should be a pair (point,label)"""# order the labeled points from nearest to farthestby_distance=sorted(labeled_points,key=lambda(point,_):distance(point,new_point))# find the labels for the k cloestk_nearest_labels=[labelfor_,labelinby_distance[:k]]# and let them votereturnmajority_vote(k_nearest_labels)

KNN算法的应用:案例分析

#cities=[(-86.75,33.5666666666667,'Python'),(-88.25,30.6833333333333,'Python'),(-112.016666666667,33.4333333333333,'Java'),(-110.933333333333,32.1166666666667,'Java'),(-92.2333333333333,34.7333333333333,'R'),(-121.95,37.7,'R'),(-118.15,33.8166666666667,'Python'),(-118.233333333333,34.05,'Java'),(-122.316666666667,37.8166666666667,'R'),(-117.6,34.05,'Python'),(-116.533333333333,33.8166666666667,'Python'),(-121.5,38.5166666666667,'R'),(-117.166666666667,32.7333333333333,'R'),(-122.383333333333,37.6166666666667,'R'),(-121.933333333333,37.3666666666667,'R'),(-122.016666666667,36.9833333333333,'Python'),(-104.716666666667,38.8166666666667,'Python'),(-104.866666666667,39.75,'Python'),(-72.65,41.7333333333333,'R'),(-75.6,39.6666666666667,'Python'),(-77.0333333333333,38.85,'Python'),(-80.2666666666667,25.8,'Java'),(-81.3833333333333,28.55,'Java'),(-82.5333333333333,27.9666666666667,'Java'),(-84.4333333333333,33.65,'Python'),(-116.216666666667,43.5666666666667,'Python'),(-87.75,41.7833333333333,'Java'),(-86.2833333333333,39.7333333333333,'Java'),(-93.65,41.5333333333333,'Java'),(-97.4166666666667,37.65,'Java'),(-85.7333333333333,38.1833333333333,'Python'),(-90.25,29.9833333333333,'Java'),(-70.3166666666667,43.65,'R'),(-76.6666666666667,39.1833333333333,'R'),(-71.0333333333333,42.3666666666667,'R'),(-72.5333333333333,42.2,'R'),(-83.0166666666667,42.4166666666667,'Python'),(-84.6,42.7833333333333,'Python'),(-93.2166666666667,44.8833333333333,'Python'),(-90.0833333333333,32.3166666666667,'Java'),(-94.5833333333333,39.1166666666667,'Java'),(-90.3833333333333,38.75,'Python'),(-108.533333333333,45.8,'Python'),(-115.166666666667,36.0833333333333,'Java'),(-71.4333333333333,42.9333333333333,'R'),(-74.1666666666667,40.7,'R'),(-106.616666666667,35.05,'Python'),(-78.7333333333333,42.9333333333333,'R'),(-73.9666666666667,40.7833333333333,'R'),(-80.9333333333333,35.2166666666667,'Python'),(-78.7833333333333,35.8666666666667,'Python'),(-100.75,46.7666666666667,'Java'),(-84.5166666666667,39.15,'Java'),(-81.85,41.4,'Java'),(-82.8833333333333,40,'Java'),(-97.6,35.4,'Python'),(-122.666666666667,45.5333333333333,'Python'),(-75.25,39.8833333333333,'Python'),(-80.2166666666667,40.5,'Python'),(-71.4333333333333,41.7333333333333,'R'),(-81.1166666666667,33.95,'R'),(-96.7333333333333,43.5666666666667,'Python'),(-90,35.05,'R'),(-86.6833333333333,36.1166666666667,'R'),(-97.7,30.3,'Python'),(-96.85,32.85,'Java'),(-98.4666666666667,29.5333333333333,'Java'),(-111.966666666667,40.7666666666667,'Python'),(-73.15,44.4666666666667,'R'),(-77.3333333333333,37.5,'Python'),(-122.3,47.5333333333333,'Python'),(-95.9,41.3,'Python'),(-95.35,29.9666666666667,'Java'),(-89.3333333333333,43.1333333333333,'R'),(-104.816666666667,41.15,'Java')]cities=[([longitude,latitude],language)forlongitude,latitude,languageincities]

# plot_state_bordersimportresegments=[]points=[]lat_long_regex=r"

# key is language, value is pairplots={"Java":([],[]),"Python":([],[]),"R":([],[])}#mark and colormarkers={"Java":"o","Python":"s","R":"^"}colors={"Java":"r","Python":"b","R":"g"}for(logitude,latitude),languageincities:plots[language][].append(logitude)plots[language][1].append(latitude)# create a scatter series for each languageforlanguage,(x,y)inplots.iteritems():plt.scatter(x,y,color=colors[language],marker=markers[language],label=language,zorder=10)plot_state_borders(plt)plt.legend(loc=)plt.axis([-130,-60,20,55])plt.title("Favorite Programming Languages")plt.show()

# try several different values for kforkin[1,3,5,7]:num_correct=forcityincities:location,actual_language=cityother_cities=[other_cityforother_cityincitiesifother_city!=city]predicted_language=knn_classify(k,other_cities,location)ifpredicted_language==actual_language:num_correct+=1printk,"neighbor[s]:",num_correct,"correct out of ",len(cities)

1neighbor[s]:40correct outof753neighbor[s]:44correct outof755neighbor[s]:41correct outof757neighbor[s]:35correct outof75

参考文献

Grus J. Data Science from Scratch: First Principles with Python[M]. “ O’Reilly Media, Inc.”, 2015.

本文来自企鹅号 - 天善智能媒体

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏算法channel

机器学习|聚类算法之DBSCAN

DBSCAN,全称:Density-Based Spatial Clustering of Applications with Noise,是一个比较有代表性的...

4159
来自专栏机器之心

你可能不再需要Attention:这是一个贼简单的神经机器翻译架构

自从编码器解码器架构崛起以来,主流的神经机器翻译(NMT)模型都使用这种架构,因为它允许原文序列长度和译文序列长度不一样。而自 Bahdanau 等研究者在 1...

833
来自专栏数据科学与人工智能

【数据分析】异常值检测

什么是异常(outlier)?Hawkins(1980)给出了异常的本质性的定义:异常是在数据集中与众不同的数据,使人怀疑这些数据并非随机偏差,而是产生于完全不...

4386
来自专栏机器之心

教程 | 如何使用贪婪搜索和束搜索解码算法进行自然语言处理

2885

如何实现自然语言处理的集束搜索解码器

自然语言处理任务(例如字幕生成和机器翻译)涉及生成单词序列。

3878
来自专栏杨熹的专栏

RNN与机器翻译

---- CS224d-Day 9: GRUs and LSTMs -- for machine translation 视频链接 课件链接 ---- 本...

3415
来自专栏量化投资与机器学习

从Seq2seq到Attention模型到Self Attention(二)

系列一介绍了Seq2seq和 Attention model。这篇文章将重点摆在Google於2017年发表论文“Attention is all you ne...

3885
来自专栏专知

【干货】NLP中“词袋”模型和词嵌入模型的比较(附代码)

【导读】词袋模型和词向量表示是自然语言处理中最常用的特征表示方法,但这两种方法各适用于哪些不同的任务,有什么区别,作者Edward Ma详细讲解了这两类使用技巧...

941
来自专栏Python小屋

Python+sklearn使用DBSCAN聚类算法案例一则

DBSCAN聚类算法概述: DBSCAN属于密度聚类算法,把类定义为密度相连对象的最大集合,通过在样本空间中不断搜索最大集合完成聚类。 DBSCAN能够在带有噪...

5124
来自专栏Brian

熵的理解

---- 熵 熵在信息论中代表随机变量不确定度的度量。一个离散型随机变量X的熵H(X)定义为: image.png 明确定义的科学名词且与内容无关,而且不随信息...

2996

扫码关注云+社区