首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >K最近邻算法(KNN)介绍及实现

K最近邻算法(KNN)介绍及实现

作者头像
生信编程日常
发布2020-11-12 11:46:03
发布2020-11-12 11:46:03
1.3K00
代码可运行
举报
运行总次数:0
代码可运行

KNN,即K nearest neighbor,K近邻算法。KNN的思想非常简单,所需的数学知识较少。比如下图,星星是一个新的样本,要判断星星是属于蓝色的还是黄色的样本分类,就要看它周围的邻居是什么分类。假设K=3,就是看周围三个点的分类,如图,周围有两个红点,一个黄点,应该归类为红色类别。

导入数据:

代码语言:javascript
代码运行次数:0
运行
复制
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
data = cancer.data
target = cancer.target

import numpy as np

X = np.array([20, 30])
plt.scatter(data[target==1,0], data[target==1,1], alpha = 0.5, color = 'green', label = '1')
plt.scatter(data[target==0,0], data[target==0,1], alpha = 0.5, color = 'orange', label = '0')
plt.plot(X[0], X[1], color = 'purple', marker = '*', ms = 20)
plt.legend(loc = 'upper right')
plt.show()

KNN的计算中,先计算距离,比较常见的是欧拉距离:

也就是两个点(或者多个点)对应的横纵坐标差的平方和,然后开平方。

根据欧拉距离写一个KNN的实现:

代码语言:javascript
代码运行次数:0
运行
复制
def KNN_test(X_train, y_train, test, K):
    distance = []

    for t in X_train:
        d = sqrt(np.sum((t - test)**2))
        distance.append(d)  
    
    ind = np.argsort(distance)
    topK_target = y_train[ind[:K]]
    c = Counter(topK_target)
    return c.most_common(1)[0][0]

判断刚才的点是属于哪一类:

代码语言:javascript
代码运行次数:0
运行
复制
KNN_test(data[:, :2], target, X, 6)

sklearn中的实现:

代码语言:javascript
代码运行次数:0
运行
复制
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=6)
knn.fit(data[:, :2], target)
knn.predict(X.reshape(1, -1))
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档