距离产生美？k近邻算法python实现

1. 什么是k近邻算法？

k最近邻(k-Nearest Neighbor，kNN)分类算法是一个比较成熟也是最简单的机器学习(Machine Learning)算法之一。该方法的思路是：如果一个样本在特征空间中与k个实例最为相似(即特征空间中最邻近)，那么这k个实例中大多数属于哪个类别，则该样本也属于这个类别。

2. k近邻算法的本质

• 整个训练过程需要将所有的训练样本极其输出label存储起来，因此，空间成本很大。
• 测试过程中，每个测试样本都需要与所有的训练样本进行比较，运行时间成本很大。
• 采用距离比较的方式，分类准确率不高。

3. 数据准备

import numpy as np
import pandas as pd

data.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'species']    # 特征及类别名称

X = data.iloc[0:150, 0:4].values
y = data.iloc[0:150, 4].values
y[y == 'Iris-setosa'] = 0                                 # Iris-setosa 输出label用0表示
y[y == 'Iris-versicolor'] = 1                             # Iris-versicolor 输出label用1表示
y[y == 'Iris-virginica'] = 2                              # Iris-virginica 输出label用2表示
X_setosa, y_setosa = X[0:50], y[0:50]                     # Iris-setosa 4个特征
X_versicolor, y_versicolor = X[50:100], y[50:100]         # Iris-versicolor 4个特征
X_virginica, y_virginica = X[100:150], y[100:150]         # Iris-virginica 4个特征

import matplotlib.pyplot as plt

plt.scatter(X_setosa[:, 0], X_setosa[:, 2], color='red', marker='o', label='setosa')
plt.scatter(X_versicolor[:, 0], X_versicolor[:, 2], color='blue', marker='^', label='versicolor')
plt.scatter(X_virginica[:, 0], X_virginica[:, 2], color='green', marker='s', label='virginica')
plt.xlabel('sepal length')
plt.ylabel('petal length')
plt.legend(loc = 'upper left')
plt.show()

# training set
X_setosa_train = X_setosa[:30, :]
y_setosa_train = y_setosa[:30]
X_versicolor_train = X_versicolor[:30, :]
y_versicolor_train = y_versicolor[:30]
X_virginica_train = X_virginica[:30, :]
y_virginica_train = y_virginica[:30]
X_train = np.vstack([X_setosa_train, X_versicolor_train, X_virginica_train])
y_train = np.hstack([y_setosa_train, y_versicolor_train, y_virginica_train])

# validation set
X_setosa_val = X_setosa[30:40, :]
y_setosa_val = y_setosa[30:40]
X_versicolor_val = X_versicolor[30:40, :]
y_versicolor_val = y_versicolor[30:40]
X_virginica_val = X_virginica[30:40, :]
y_virginica_val = y_virginica[30:40]
X_val = np.vstack([X_setosa_val, X_versicolor_val, X_virginica_val])
y_val = np.hstack([y_setosa_val, y_versicolor_val, y_virginica_val])

# test set
X_setosa_test = X_setosa[40:50, :]
y_setosa_test = y_setosa[40:50]
X_versicolor_test = X_versicolor[40:50, :]
y_versicolor_test = y_versicolor[40:50]
X_virginica_test = X_virginica[40:50, :]
y_virginica_test = y_virginica[40:50]
X_test = np.vstack([X_setosa_test, X_versicolor_test, X_virginica_test])
y_test = np.hstack([y_setosa_test, y_versicolor_test, y_virginica_test])

4. kNN训练函数和预测函数

kNN的训练过程实际上是一种数据标类、数据存储的过程，不包含机器学习算法。首先我们需要定义一个类（class）来实现KNN算法模块。该类的初始化定义为：

class KNearestNeighbor(object):
def __init__(self):
pass

def train(self, X, y):
self.X_train = X
self.y_train = y

kNN的测试过程是核心部分。其中，有两点需要注意：

• 衡量距离的方式
• k值的选择

kNN距离衡量一般有两种方式：L1距离和L2距离。

L1距离的计算公式为：

d1(I1,I2)=∑p|Ip1−Ip2|d1(I1,I2)=∑p|I1p−I2p|

d1(I_1,I_2)=\sum_p|I_1^p-I_2^p|

L2距离的计算公式为：

d2(I1,I2)=∑p(Ip1−Ip2)2−−−−−−−−−−−√d2(I1,I2)=∑p(I1p−I2p)2

d2(I_1,I_2)=\sqrt{\sum_p(I_1^p-I_2^p)^2}

kNN中的k值选择至关重要，不同的k值也许能归属到不同的类别中，例如在下图中，k=3，则判定绿色实例属于红色三角形类别。

def predict(self, X, k=1)
# 计算L2距离
num_test = X.shape[0]
num_train = self.X_train.shape[0]
dists = np.zeros((num_test, num_train))    # 初始化距离函数
# because(X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train, so
d1 = -2 * np.dot(X, self.X_train.T)    # shape (num_test, num_train)
d2 = np.sum(np.square(X), axis=1, keepdims=True)    # shape (num_test, 1)
d3 = np.sum(np.square(self.X_train), axis=1)    # shape (1, num_train)
dist = np.sqrt(d1 + d2 + d3)
# 根据K值，选择最可能属于的类别
y_pred = np.zeros(num_test)
for i in range(num_test):
dist_k_min = np.argsort(dist[i])[:k]    # 最近邻k个实例位置
y_kclose = self.y_train[dist_k_min]     # 最近邻k个实例对应的标签
y_pred[i] = np.argmax(np.bincount(y_kclose))    # 找出k个标签中从属类别最多的作为预测类别

return y_pred

KNearestNeighbor类的完整定义代码如下：

class KNearestNeighbor(object):
def __init__(self):
pass

# 训练函数
def train(self, X, y):
self.X_train = X
self.y_train = y

# 预测函数
def predict(self, X, k=1):
# 计算L2距离
num_test = X.shape[0]
num_train = self.X_train.shape[0]
dists = np.zeros((num_test, num_train))    # 初始化距离函数
# because(X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train, so
d1 = -2 * np.dot(X, self.X_train.T)    # shape (num_test, num_train)
d2 = np.sum(np.square(X), axis=1, keepdims=True)    # shape (num_test, 1)
d3 = np.sum(np.square(self.X_train), axis=1)    # shape (1, num_train)
dist = np.sqrt(d1 + d2 + d3)
# 根据K值，选择最可能属于的类别
y_pred = np.zeros(num_test)
for i in range(num_test):
dist_k_min = np.argsort(dist[i])[:k]    # 最近邻k个实例位置
y_kclose = self.y_train[dist_k_min]     # 最近邻k个实例对应的标签
y_pred[i] = np.argmax(np.bincount(y_kclose.tolist()))    # 找出k个标签中从属类别最多的作为预测类别

return y_pred

5. 训练和预测

KNN.train(X_train, y_train)
y_pred = KNN.predict(X_test, k=6)
accuracy = np.mean(y_pred == y_test)
print('测试集预测准确率：%f' % accuracy)

# 训练集
plt.scatter(X_setosa_train[:, 0], X_setosa_train[:, 2], color='red', marker='o', label='setosa_train')
plt.scatter(X_versicolor_train[:, 0], X_versicolor_train[:, 2], color='blue', marker='^', label='versicolor_train')
plt.scatter(X_virginica_train[:, 0], X_virginica_train[:, 2], color='green', marker='s', label='virginica_train')
# 测试集
plt.scatter(X_setosa_test[:, 0], X_setosa_test[:, 2], color='y', marker='o', label='setosa_test')
plt.scatter(X_versicolor_test[:, 0], X_versicolor_test[:, 2], color='y', marker='^', label='versicolor_test')
plt.scatter(X_virginica_test[:, 0], X_virginica_test[:, 2], color='y', marker='s', label='virginica_test')

plt.xlabel('sepal length')
plt.ylabel('petal length')
plt.legend(loc = 4)
plt.show()

6. k近邻算法总结

k近邻算法是一种最简单最直观的分类算法。它的训练过程保留了所有样本的所有特征，把所有信息都记下来，没有经过处理和提取。而其它机器学习算法包括神经网络则是在训练过程中提取最重要、最有代表性的特征。在这一点上，kNN算法还非常不够“智能”。但是，kNN算法作为机器学习的基础算法，还是值得我们了解一下的。

0 条评论

• 斯坦福CS231n项目实战（一）：k最近邻（kNN）分类算法

版权声明：本文为博主原创文章，未经博主允许不得转载。 https://blog.csdn.n...

• 中国台湾大学林轩田机器学习技法课程学习笔记1 -- Linear Support Vector Machine

关于台湾大学林轩田老师的《机器学习基石》课程，我们已经总结了16节课的笔记。这里附上基石第一节课的博客地址： 台湾大学林轩田机器学习基石课程学习笔记1 – Th...

• 多元函数的泰勒(Taylor)展开式

多元函数的泰勒展开式 实际优化问题的目标函数往往比较复杂。为了使问题简化，通常将目标函数在某点附近展开为泰勒(Taylor)多项式来逼近原函数。 一元函数在点x...

• 基础知识 | 每日一练（177）

小林：这是个定义在 <assert.h> 中的宏, 用来测试断言。一个断言本质上是写下程序员的假设, 如果假设被违反, 那表明有个严重的程序错误。例如, 一个假...

• AI所有领域最优论文+代码查找神器：966个ML任务、8500+论文任你选

Papers With Code 由 Reddit 用户 rstoj 开发，旨在提供免费、开源的机器学习论文资源、代码和评估排行榜。目前该网站已经完成了代码与论...

• 操作系统-进程和线程

进程线程的区别 1、进程是什么？ 是具有一定独立功能的程序、它是系统进行资源分配和调度的一个独立单位，重点在系统调度和单独的单位，也就是说进程是可以独立运行...

• INotifyPropertyChanged 接口 CallerMemberName属性

使用调用方信息属性，可以获取关于调用方的信息传递给方法。 可以获取源代码、行号在源代码和调用方的成员名称的文件路径。 此信息用于跟踪，调试和创建诊断工具非常有用...

• 【答疑释惑第十一讲】开发应用主要用什么语言？

疑惑一 开发应用主要用什么语言？ 其实这个问题很多工作的，看到都觉得这个问的有点没有头绪，但是对于初学者来说还是希望更多的了解点这方面的，所谓应用开发就是处于底...

• LockSupport 工具类使用以及实现原理

JDK中的rt.jar包里面的是个LockSupport是个工具类，它的主要作用是挂起和唤醒线程，该工具类是创建锁和其他同步类的基础。