
K-Nearest Neighbors Classifier with sklearn: Concept and Code

Author: 月见樽 | Published 2018-04-27 (originally 2017-11-11)

Concept

KNN (K-Nearest Neighbors) is one of the simpler machine learning classifiers: an instance-based method rather than one that fits an explicit model. The basic idea: at prediction time, compute the Euclidean (geometric) distance from the input vector to every training sample, take the K nearest training samples, and predict the class that appears most often among those K samples (a majority vote).
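For intuition, here is a minimal from-scratch sketch of the voting rule just described. It is illustrative only; the function name knn_predict and the numpy-based distance computation are my own, not part of the original post or of sklearn.

Code (Python):
import numpy as np
from collections import Counter

def knn_predict(x, train_x, train_y, k=5):
    # Euclidean distance from the query point to every training sample
    distances = np.linalg.norm(train_x - x, axis=1)
    # indices of the k nearest training samples
    nearest = np.argsort(distances)[:k]
    # majority vote: the most common label among the k neighbors
    return Counter(train_y[nearest]).most_common(1)[0][0]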

Code Implementation

Loading the dataset: the iris dataset

Code (Python):
from sklearn.datasets import load_iris
dataset = load_iris()
Code (Python):
print(dataset.data.shape)
print(dataset.DESCR)
Output:
(150, 4)
Iris Plants Database
====================

Notes
-----
Data Set Characteristics:
    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris

The famous Iris database, first used by Sir R.A Fisher

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

References
----------
   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...
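Besides data and DESCR, the object returned by load_iris also carries the integer labels and their names, which are used later for the classification report. A quick check (same dataset object as above):

Code (Python):
print(dataset.target.shape)      # (150,) integer labels 0/1/2
print(dataset.target_names)      # ['setosa' 'versicolor' 'virginica']
print(dataset.feature_names)     # names of the four measurement columns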

Data preprocessing

Splitting the data

Code (Python):
from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in newer sklearn versions
x_train, x_test, y_train, y_test = train_test_split(dataset.data, dataset.target, test_size=0.25, random_state=1)
print(x_train.shape)
print(x_test.shape)
Output:
(112, 4)
(38, 4)
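Since the three classes are perfectly balanced (50 samples each), a plain random split is usually fine, but train_test_split also accepts a stratify argument that keeps the class proportions identical in both splits. A variant (a sketch, not what the post uses):

Code (Python):
# stratify on the labels so train and test each keep the one-third-per-class balance
x_train, x_test, y_train, y_test = train_test_split(
    dataset.data, dataset.target,
    test_size=0.25, random_state=1, stratify=dataset.target)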

Standardization

Code (Python):
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
# fit the scaler on the training set only, then reuse its statistics on the test set
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
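Fitting the scaler on the training set only and reusing its statistics on the test set avoids leaking test information into preprocessing. A quick sanity check of the effect (sketch):

Code (Python):
import numpy as np
# training features now have (approximately) zero mean and unit variance per column
print(np.round(x_train.mean(axis=0), 3))
print(np.round(x_train.std(axis=0), 3))
# the test set is scaled with the training statistics, so its columns are only roughly standardized
print(np.round(x_test.mean(axis=0), 3))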

Using the K-Nearest Neighbors classifier

Code (Python):
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
knn.fit(x_train,y_train)
Output:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=5, p=2,
           weights='uniform')
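The defaults shown above use n_neighbors=5 with uniform voting and Euclidean distance (minkowski with p=2). The best k is problem-dependent; one common way to pick it is cross-validated grid search. A sketch (GridSearchCV is my addition, not part of the original post):

Code (Python):
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# try a few odd values of k with 5-fold cross-validation on the training set
param_grid = {'n_neighbors': [3, 5, 7, 9, 11]}
search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
search.fit(x_train, y_train)
print(search.best_params_, search.best_score_)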

Model evaluation

Built-in scoring

Code (Python):
print(knn.score(x_test,y_test))
Output:
0.973684210526
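For a classifier, score is plain accuracy: the fraction of test samples whose predicted label matches the true label. A manual check (sketch):

Code (Python):
import numpy as np
# accuracy computed by hand; should match knn.score(x_test, y_test)
y_pre = knn.predict(x_test)
print(np.mean(y_pre == y_test))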

Evaluation with classification_report

Code (Python):
from sklearn.metrics import classification_report
y_pre = knn.predict(x_test)
print(classification_report(y_test,y_pre,target_names=dataset.target_names))
Output:
             precision    recall  f1-score   support

     setosa       1.00      1.00      1.00        13
 versicolor       1.00      0.94      0.97        16
  virginica       0.90      1.00      0.95         9

avg / total       0.98      0.97      0.97        38
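The report is consistent with a single mistake: one versicolor sample predicted as virginica (hence versicolor recall 15/16 ≈ 0.94 and virginica precision 9/10 = 0.90). A confusion matrix makes this explicit (sketch):

Code (Python):
from sklearn.metrics import confusion_matrix
# rows are true classes, columns are predicted classes
print(confusion_matrix(y_test, y_pre))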
