前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >机器学习_knn算法_2

机器学习_knn算法_2

作者头像
以某
发布2023-03-07 14:03:10
4200
发布2023-03-07 14:03:10
举报
文章被收录于专栏:Jack96

鸢尾花的数据分析

  • target : 目标值,代表着类别(c)
  • datasets.load_iris()得到的不是字典是普通的hash类型,因为pandas,matplotlib,sklearn,numpy都是由c封装的,所有的库不是python写的都是由c封装的
  • 是一个元类

In [2]:

代码语言:javascript
复制
#target : 目标值,代表着类别,
代码语言:javascript
复制
#得到的不是字典是普通的hash类型
代码语言:javascript
复制
from sklearn import datasets
iris = datasets.load_iris()

In [3]:

代码语言:javascript
复制
#DESCR得到的是英文文档,需要print解析(内有大量的\n,\t)
代码语言:javascript
复制
# print(iris.DESCR)

In [4]:

代码语言:javascript
复制
data = iris.data
代码语言:javascript
复制
target = iris.target
代码语言:javascript
复制
target_names = iris.target_names

In [5]:

代码语言:javascript
复制
from sklearn.model_selection import train_test_split
  • 数据的预备部分,,切割成四部分,按()内先后的位置,(可以按比例切),测试部分占10%
  • 随机切分的

In [26]:

代码语言:javascript
复制
X_train,X_test,y_train,y_test =  train_test_split(data,target,test_size=0.1)
  • 进行训练(只是为了训练出一组数学模型)
  • k : n_neighbors=5
  • n_jobs=2 : 用几个进程(核数),核心数量越多,则运作的越快

In [27]:

代码语言:javascript
复制
代码语言:javascript
复制
from sklearn.neighbors import KNeighborsClassifier
代码语言:javascript
复制
knn = KNeighborsClassifier(n_neighbors=5,n_jobs=2)
代码语言:javascript
复制
knn.fit(X_train,y_train)

Out[27]:

代码语言:javascript
复制
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=2, n_neighbors=5, p=2,
           weights='uniform')
  • 进行评估(传入的值任然是训练时的值但是评估结果不是1,则这种评估的结果就是准确率)

In [28]:

代码语言:javascript
复制
knn.score(X_train,y_train)

Out[28]:

代码语言:javascript
复制
0.9703703703703703
  • 进行测试

In [29]:

代码语言:javascript
复制
y_pred = knn.predict(X_test)
  • 准确率判断原理(利用的是广播机制,索引对比)

In [30]:

代码语言:javascript
复制
import pandas as pd
代码语言:javascript
复制
import numpy as np
代码语言:javascript
复制
np.sum(y_test==y_pred)/len(y_test)

Out[30]:

代码语言:javascript
复制
0.9333333333333333

In [15]:

代码语言:javascript
复制
import matplotlib.pyplot as plt
代码语言:javascript
复制
plt.scatter(data[:,0],data[:,1],c=target)

Out[15]:

代码语言:javascript
复制
<matplotlib.collections.PathCollection at 0x27b0250ae80>

拟合 : 让测试的和真实的很接近(相似度)

  • 过拟合 : 训练评估的值高(1.0),测试的值(准确率)低(0.92) /0.98 对 0.89(训练的数据中没有异常值)
  • 欠拟合 : 两者都低,或则差异较大(一般不会出现这种情况,数据量及其少的时候会出现)

网格搜索(找值高的)和交叉验证(找相对较高的)(优化算法)

  • grid search & cross validation
  • GridSearchCV (可见网络搜索较重要一些)
  • estimator : 估计器,算法(要求写算法的实例,knn)
  • param_grid : 估计器的参数,k的值,要求是一个散列表(n_neighbors:[1,3,5,7,10])(散列类型,hash)
  • cv : 指定折交的次数
  • 网络搜索 : 用for循环遍历k 的值然后训练得分(计算量较大)
  • 交叉验证 : 数据部分,多次抽取数据进行验证(3/7每次抽取70%),大量的随机测试保证模型的稳定性
  • 效率低下 : k的值可以取16,一共抽取十组数据 ,效率比之前低160倍
  • 优化找到k的哪个数量值比较合适(让测量值和真实值比较接近) 模型产生的结果
  • bestparams : 最稳定,模型产生最好的结果的参数
  • bestscore :评估结果既稳定,又高的一次
  • bestestimator :最好的模型

泛化性 / 鲁棒性 / 稳定性

  • 应用与哪种数据都可以,分辨的错误率不高,

In [83]:

代码语言:javascript
复制
from sklearn.model_selection import GridSearchCV

In [84]:

代码语言:javascript
复制
knn = KNeighborsClassifier()

In [85]:

代码语言:javascript
复制
param_grid = {'n_neighbors':list(range(1,26))}

In [86]:

代码语言:javascript
复制
gc = GridSearchCV(estimator = knn,param_grid = param_grid,cv=3,n_jobs=2)

In [87]:

代码语言:javascript
复制
gc.fit(X_train,y_train)

Out[87]:

代码语言:javascript
复制
GridSearchCV(cv=3, error_score='raise',
       estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=5, p=2,
           weights='uniform'),
       fit_params=None, iid=True, n_jobs=2,
       param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0)

In [88]:

代码语言:javascript
复制
best_pa = gc.best_params_

In [89]:

代码语言:javascript
复制
best_knn = gc.best_estimator_

In [90]:

代码语言:javascript
复制
gc.best_score_

Out[90]:

代码语言:javascript
复制
0.9703703703703703

In [91]:

代码语言:javascript
复制
best_knn.score(X_train,y_train)

Out[91]:

代码语言:javascript
复制
0.9703703703703703

In [92]:

代码语言:javascript
复制
best_knn.score(X_test,y_test)

Out[92]:

代码语言:javascript
复制
0.9333333333333333

In [ ]:

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019-09-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 鸢尾花的数据分析
  • 拟合 : 让测试的和真实的很接近(相似度)
  • 网格搜索(找值高的)和交叉验证(找相对较高的)(优化算法)
  • 泛化性 / 鲁棒性 / 稳定性
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档