首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >利用skl实现KNN

利用skl实现KNN

作者头像
皮大大
发布2021-03-01 16:19:54
发布2021-03-01 16:19:54
32800
代码可运行
举报
运行总次数:0
代码可运行

利用SKL实现KNN算法

本文通过SKL的中KNN分类器,实现了对iris数据的的分类预测,主要涉及的内容包含:

  • KNN算法原理
  • KNN算法优缺点
  • K值选取
  • SKL实现KNN

KNN 算法原理

找个K个和新数据最近的样本,取样本中最多的一个类别作为新数据的类别

KNN优点

  • 算法简单易实现

把全部的数据当做模型本身

  • 对边界不规则的数据效果好

KNN缺点

  • 只适合小数据集
  • 数据不平衡效果不好
  • 数据必须标准化
  • 不适合特征维度太多的数据

K 值选取

  • k越小越容易过拟合
  • k越小大越容易欠拟合

k值一般选择是奇数,偶数可能难以抉择

SKL实现KNN

代码语言:javascript
代码运行次数:0
运行
复制
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
代码语言:javascript
代码运行次数:0
运行
复制
np.random.seed(0)  # 保证每次运行的结果相同;不设置的话,默认按照时间作为参数的
导入数据
代码语言:javascript
代码运行次数:0
运行
复制
iris = datasets.load_iris()
iris
代码语言:javascript
代码运行次数:0
运行
复制
X = iris.data
X[:5]
代码语言:javascript
代码运行次数:0
运行
复制
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])
代码语言:javascript
代码运行次数:0
运行
复制
y = iris.target
y
代码语言:javascript
代码运行次数:0
运行
复制
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
分割数据

将数据分成训练集和测试集,总共是150条,分成140的训练集和10条的测试集

注意是随机分割,使用的是permutation参数

代码语言:javascript
代码运行次数:0
运行
复制
# 产生该长度内的随机乱序的一维数据

random_array = np.random.permutation(len(X))
random_array
代码语言:javascript
代码运行次数:0
运行
复制
# 训练集

X_train = X[random_array[:-10]]  # 取出前140个数据
y_train = y[random_array[:-10]]
代码语言:javascript
代码运行次数:0
运行
复制
X_train
代码语言:javascript
代码运行次数:0
运行
复制
array([[5.8, 2.8, 5.1, 2.4],
       [6. , 2.2, 4. , 1. ],
       [5.5, 4.2, 1.4, 0.2],
       [7.3, 2.9, 6.3, 1.8],
       [5. , 3.4, 1.5, 0.2],
       [6.3, 3.3, 6. , 2.5],
       [5. , 3.5, 1.3, 0.3],
       [6.7, 3.1, 4.7, 1.5],
       [6.8, 2.8, 4.8, 1.4],
       [6.1, 2.8, 4. , 1.3],
       [6.1, 2.6, 5.6, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.5, 2.8, 4.6, 1.5],
       [6.1, 2.9, 4.7, 1.4],
       [4.9, 3.6, 1.4, 0.1],
       [6. , 2.9, 4.5, 1.5],
       [5.5, 2.6, 4.4, 1.2],
       [4.8, 3. , 1.4, 0.3],
       [5.4, 3.9, 1.3, 0.4],
       [5.6, 2.8, 4.9, 2. ],
       [5.6, 3. , 4.5, 1.5],
       [4.8, 3.4, 1.9, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [6.2, 2.8, 4.8, 1.8],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.8, 1.9, 0.4],
       [6.2, 2.9, 4.3, 1.3],
       [5. , 2.3, 3.3, 1. ],
       [5. , 3.4, 1.6, 0.4],
       [6.4, 3.1, 5.5, 1.8],
       [5.4, 3. , 4.5, 1.5],
       [5.2, 3.5, 1.5, 0.2],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.2],
       [5.2, 2.7, 3.9, 1.4],
       [5.7, 3.8, 1.7, 0.3],
       [6. , 2.7, 5.1, 1.6],
       [5.9, 3. , 4.2, 1.5],
       [5.8, 2.6, 4. , 1.2],
       [6.8, 3. , 5.5, 2.1],
       [4.7, 3.2, 1.3, 0.2],
       [6.9, 3.1, 5.1, 2.3],
       [5. , 3.5, 1.6, 0.6],
       [5.4, 3.7, 1.5, 0.2],
       [5. , 2. , 3.5, 1. ],
       [6.5, 3. , 5.5, 1.8],
       [6.7, 3.3, 5.7, 2.5],
       [6. , 2.2, 5. , 1.5],
       [6.7, 2.5, 5.8, 1.8],
       [5.6, 2.5, 3.9, 1.1],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.3, 4.7, 1.6],
       [5.5, 2.4, 3.8, 1.1],
       [6.3, 2.7, 4.9, 1.8],
       [6.3, 2.8, 5.1, 1.5],
       [4.9, 2.5, 4.5, 1.7],
       [6.3, 2.5, 5. , 1.9],
       [7. , 3.2, 4.7, 1.4],
       [6.5, 3. , 5.2, 2. ],
       [6. , 3.4, 4.5, 1.6],
       [4.8, 3.1, 1.6, 0.2],
       [5.8, 2.7, 5.1, 1.9],
       [5.6, 2.7, 4.2, 1.3],
       [5.6, 2.9, 3.6, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [6.1, 3. , 4.6, 1.4],
       [7.2, 3.2, 6. , 1.8],
       [5.3, 3.7, 1.5, 0.2],
       [4.3, 3. , 1.1, 0.1],
       [6.4, 2.7, 5.3, 1.9],
       [5.7, 3. , 4.2, 1.2],
       [5.4, 3.4, 1.7, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [6.9, 3.1, 4.9, 1.5],
       [4.6, 3.1, 1.5, 0.2],
       [5.9, 3. , 5.1, 1.8],
       [5.1, 2.5, 3. , 1.1],
       [4.6, 3.4, 1.4, 0.3],
       [6.2, 2.2, 4.5, 1.5],
       [7.2, 3.6, 6.1, 2.5],
       [5.7, 2.9, 4.2, 1.3],
       [4.8, 3. , 1.4, 0.1],
       [7.1, 3. , 5.9, 2.1],
       [6.9, 3.2, 5.7, 2.3],
       [6.5, 3. , 5.8, 2.2],
       [6.4, 2.8, 5.6, 2.1],
       [5.1, 3.8, 1.6, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [6.5, 3.2, 5.1, 2. ],
       [6.7, 3.3, 5.7, 2.1],
       [4.5, 2.3, 1.3, 0.3],
       [6.2, 3.4, 5.4, 2.3],
       [4.9, 3. , 1.4, 0.2],
       [5.7, 2.5, 5. , 2. ],
       [6.9, 3.1, 5.4, 2.1],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [7.2, 3. , 5.8, 1.6],
       [5.1, 3.5, 1.4, 0.3],
       [4.4, 3. , 1.3, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [5.5, 2.3, 4. , 1.3],
       [6.8, 3.2, 5.9, 2.3],
       [7.6, 3. , 6.6, 2.1],
       [5.1, 3.5, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [5.7, 2.8, 4.5, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [5. , 3.2, 1.2, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [6.4, 2.9, 4.3, 1.3],
       [5.4, 3.4, 1.5, 0.4],
       [7.7, 2.6, 6.9, 2.3],
       [4.9, 2.4, 3.3, 1. ],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3.1, 4.4, 1.4],
       [5.2, 4.1, 1.5, 0.1],
       [6. , 3. , 4.8, 1.8],
       [5.8, 4. , 1.2, 0.2],
       [7.7, 2.8, 6.7, 2. ],
       [5.1, 3.8, 1.5, 0.3],
       [4.7, 3.2, 1.6, 0.2],
       [7.4, 2.8, 6.1, 1.9],
       [5. , 3.3, 1.4, 0.2],
       [6.3, 3.4, 5.6, 2.4],
       [5.7, 2.8, 4.1, 1.3],
       [5.8, 2.7, 3.9, 1.2],
       [5.7, 2.6, 3.5, 1. ],
       [6.4, 3.2, 5.3, 2.3],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.7, 3. , 5. , 1.7],
       [5. , 3. , 1.6, 0.2],
       [5.5, 2.4, 3.7, 1. ],
       [6.7, 3.1, 5.6, 2.4],
       [5.8, 2.7, 5.1, 1.9],
       [5.1, 3.4, 1.5, 0.2],
       [6.6, 2.9, 4.6, 1.3]])
代码语言:javascript
代码运行次数:0
运行
复制
y_train
代码语言:javascript
代码运行次数:0
运行
复制
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
       0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 1, 1, 1, 2, 0, 2, 0,
       0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2, 1, 0, 2, 1, 1, 1,
       1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 0,
       0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2, 2, 0, 0, 0, 1, 1,
       0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0, 2, 1, 1, 1, 2, 2,
       1, 1, 0, 1, 2, 2, 0, 1])
代码语言:javascript
代码运行次数:0
运行
复制
# 测试集

X_test = X[random_array[-10:]]  # 后面的10个数据作为测试集
y_test = y[random_array[-10:]]
代码语言:javascript
代码运行次数:0
运行
复制
X_test
代码语言:javascript
代码运行次数:0
运行
复制
array([[5.6, 3. , 4.1, 1.3],
       [5.9, 3.2, 4.8, 1.8],
       [6.3, 2.3, 4.4, 1.3],
       [5.5, 3.5, 1.3, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.9, 3.1, 1.5, 0.1],
       [6.3, 2.9, 5.6, 1.8],
       [5.8, 2.7, 4.1, 1. ],
       [7.7, 3.8, 6.7, 2.2],
       [4.6, 3.2, 1.4, 0.2]])
代码语言:javascript
代码运行次数:0
运行
复制
y_test
代码语言:javascript
代码运行次数:0
运行
复制
array([1, 1, 1, 0, 0, 0, 2, 1, 2, 0])
KNN训练器
代码语言:javascript
代码运行次数:0
运行
复制
# 定义一个knn分类器对象
knn = KNeighborsClassifier()
# 调用对象的训练方法,传入两个参数:训练集数据极其标签
knn.fit(X_train,y_train)
代码语言:javascript
代码运行次数:0
运行
复制
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=5, p=2,
                     weights='uniform')
预测
代码语言:javascript
代码运行次数:0
运行
复制
y_predict = knn.predict(X_test)  # 传入预测的数据
y_predict
代码语言:javascript
代码运行次数:0
运行
复制
array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
代码语言:javascript
代码运行次数:0
运行
复制
# 计算各个测试样本预测的概率值

probility = knn.predict_proba(X_test)
probility
代码语言:javascript
代码运行次数:0
运行
复制
array([[0. , 1. , 0. ],
       [0. , 0.4, 0.6],
       [0. , 1. , 0. ],
       [1. , 0. , 0. ],
       [1. , 0. , 0. ],
       [1. , 0. , 0. ],
       [0. , 0. , 1. ],
       [0. , 1. , 0. ],
       [0. , 0. , 1. ],
       [1. , 0. , 0. ]])
代码语言:javascript
代码运行次数:0
运行
复制
# 计算和最后一个测试样本距离最近的5个点,返回的是这些样本的序号组成的数组

neighborpoint = knn.kneighbors([X_test[-1]],5)
neighborpoint
代码语言:javascript
代码运行次数:0
运行
复制
(array([[0.14142136, 0.14142136, 0.2236068 , 0.2236068 , 0.2236068 ]]),
 array([[ 75,  41,  96,  78, 123]]))
代码语言:javascript
代码运行次数:0
运行
复制
# 调用对象的打分方法,计算出准确率

score = knn.score(X_test,y_test,sample_weight=None)
score
代码语言:javascript
代码运行次数:0
运行
复制
0.9
输出测试结果

对比knn的预测值和实际的值,准确率是90%

代码语言:javascript
代码运行次数:0
运行
复制
# knn的预测值
y_predict
代码语言:javascript
代码运行次数:0
运行
复制
array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
代码语言:javascript
代码运行次数:0
运行
复制
# 原始值

y_test
代码语言:javascript
代码运行次数:0
运行
复制
array([1, 1, 1, 0, 0, 0, 2, 1, 2, 0])
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-11-2,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 利用SKL实现KNN算法
  • KNN 算法原理
  • KNN优点
  • KNN缺点
  • K 值选取
  • SKL实现KNN
    • 导入数据
    • 分割数据
    • KNN训练器
    • 预测
    • 输出测试结果
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档