上一篇文章中,我们了解了机器学习的一系列基本概念。
机器学习的基本概念 本文中我们来介绍最简单的分类算法:k 近邻算法(kNN)
k 近邻算法是一种采用测量不同特征值之间的距离的方法对样本进行分类的算法。 他的工作原理是,存在一个样本数据集合,并且每个数据都存在分类标签,对于没有标签的新数据,将这个新数据的每个特征与样本集中的数据对应的特征进行比较,然后提取样本集中特征最相似的数据(最近邻)的分类标签。 通常来说,我们只选择样本数据集中前 k 个最相近的数据,这就是 k 近邻算法的得名,通常 k 都不大于 20,在这 k 个数据中,出现次数最多的分类就输出作为新数据的分类。
计算样本数据的距离是非常简单的:
上面的公式进行推广就是欧式距离计算公式:
k 近邻算法具有下面三个优点: 1. 简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归 2. 可用于数值型数据和离散型数据 3. 训练时间复杂度为 O(n),无数据输入假定 4. 对异常值不敏感
但是,k近邻算法也具有下面的缺点: 1. 计算复杂性高;空间复杂性高 2. 样本不平衡问题(即有些类别的样本数量很多,而其它样本的数量很少) 3. 一般数值很大的时候不用这个,计算量太大 4. 单个样本不能太少,否则容易发生误分 5. 无法给出数据的内在含义
我们用 KNN 算法来实现一个电影分类的模型。 在电影中,打斗镜头和亲吻镜头是频繁出现的,但是我们不能认为有打斗镜头就把电影分类为动作片,也不能认为有亲吻镜头就认为电影是爱情片。 那么,假如此时有一部电影,其中有 18 个打斗镜头,90 个亲吻镜头,他究竟是爱情片还是动作片呢? 下面我们就用 KNN 算法来实现这个模型和预测。
电影名称 | 打斗镜头 | 接吻镜头 | 电影类型 |
---|---|---|---|
1 | 3 | 104 | 爱情片 |
2 | 2 | 100 | 爱情片 |
3 | 99 | 5 | 动作片 |
4 | 98 | 2 | 动作片 |
未知电影 | 18 | 90 | 未知 |
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import operator
# noinspection PyShadowingNames
def createDataSet(): # 创建数据集
group = np.array([[3, 104], [2, 100], [99, 5], [98, 2]])
labels = ['爱情片', '爱情片', '动作片', '动作片']
return group, labels
def knn(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0] # shape[0]返回dataSet的行数
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet # tile(inX,(a,b))函数将inX重复a行,重复b列,以便用于和所有样本作差
sqDiffMat = diffMat ** 2 # 作差后平方
sqDistances = sqDiffMat.sum(axis=1) # sum()求和函数,sum(0)每列所有元素相加,sum(1)每行所有元素相加
distances = sqDistances ** 0.5 # 开平方,求欧式距离
sortedDistIndicies = distances.argsort() # argsort函数返回的是数组值从小到大的索引值
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]] # 取出前k个距离对应的标签
# 计算每个类别的样本数。字典get()函数返回指定键的值,如果值不在字典中返回默认值0
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0] # 返回字典的第一条的key,也即是测试样本所属类别
if __name__ == '__main__':
group, labels = createDataSet() # 创建数据集
print('group:\n', group) # 打印数据集
print('labels:', labels)
zhfont = matplotlib.font_manager.FontProperties(fname=r'c:\windows\fonts\simsun.ttc') # 设置中文字体路径
plt.figure(figsize=(10, 8)) # 可视化
ax = plt.subplot(111) # 图片在第一行,第一列的第一个位置
ax.scatter(group[0:2, 0], group[0:2, 1], color='red', s=50)
ax.scatter(group[2:4, 0], group[2:4, 1], color='blue', s=50)
ax.scatter(18, 90, color='orange', s=50)
plt.annotate('which class?', xy=(18, 90), xytext=(3, 2), arrowprops=dict(facecolor='black', shrink=0.05), )
plt.xlabel('打斗镜头', fontproperties=zhfont)
plt.ylabel('接吻镜头', fontproperties=zhfont)
plt.title('电影分类可视化', fontproperties=zhfont)
plt.show()
testclass = knn([18, 90], group, labels, 3) # 用未知的样本来测试算法
print('测试结果:', testclass) # 打印测试结果
运行结果展示了:
group:
[[ 3 104]
[ 2 100]
[ 99 5]
[ 98 2]]
labels: ['爱情片', '爱情片', '动作片', '动作片']
测试结果: 爱情片
代码清晰的分为三个部分: 1. 数据准备 — createDataSet 2. 算法实现 — knn 3. 结果输出 — main
最重要的就是 knn 函数,他承担了 knn 算法的实现。 很巧妙的是,通过将样本数据进行 np.tile 操作 — 把单一样本变成样本矩阵,从而通过矩阵操作实现了多个样本与测试数据之间的作差、平方操作。 在此之后,对求得的结果进行排序,取出前 K 个最近的结果中出现次数最多的 label,那就是我们的预测结果了。
虽然 KNN 算法非常简单,但 sklearn 包中有着封装好的现成实现,可以直接传递参数进行调用。 下面我们来看看如何使用 sklearn 来进行 KNN 算法的实现。
Sklearn 的全称是 Scikit learn,是机器学习领域当中最知名的python模块之一。 sklearn包含了很多机器学习的算法:
由于 Sklearn 对算法进行了非常完备的封装,一个复杂的算法,只需要调用现成的函数传入参数即可,可以大大减少我们开发和调试的时间与精力。
sklearn 安装较为简单,只要执行下面的命令即可:
pip install sklearn
但有些环境下,会报出错误,通常在下面的网站中下载安装对应版本的 whl 文件再执行上面的命令安装即可:
下面是 KNN 算法的 sklearn 官方文档地址:
sklearn.neighbors.KNeighborsClassifier 就是实现 k 近邻算法的类:
class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights=’uniform’, algorithm=’auto’, leaf_size=30, p=2, metric=’minkowski’, metric_params=None, n_jobs=None, **kwargs)
KNneighborsClassifier 具有以下类构造参数:
在使用构造参数构造出 KNeighborsClassifier 类对象以后,调用其类成员函数就可以完成模型的构建与调用了:
>>> X = [[0], [1], [2], [3]]
>>> y = [0, 0, 1, 1]
>>> from sklearn.neighbors import KNeighborsClassifier
>>> neigh = KNeighborsClassifier(n_neighbors=3)
>>> neigh.fit(X, y)
KNeighborsClassifier(...)
>>> print(neigh.predict([[1.1]]))
[0]
>>> print(neigh.predict_proba([[0.9]]))
[[0.66666667 0.33333333]]
官方文档中还有很多其他的深入的例子可以供进一步研究:
Peter Harrington 《机器学习实战》。 http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html。 https://blog.csdn.net/c406495762/article/details/75172850。 https://blog.csdn.net/u013829973/article/details/77942942。