蜗壳机器学习指南:KNN

KNN的原理

KNN的原理非常简单。首先我们有训练集的数据,这些数据都是有标签的。然后对于一个测试集的样本,我们想知道它的标签是什么。我们将该测试集样本的特征与训练集所有样本的特征进行比对,找一个距离最近的,将该训练集样本的标签作为这个测试集样本的标签。这就是所谓的最近邻算法,但实际中我们常常选取前K个最近的样本,然后选择这K个最相似数据中出现次数最多的标签作为测试集样本的标签。这就是所谓的K近邻算法。

KNN实战

首先我们定义一个KNN的类KnnClassifier,完成初始化方法,引入训练集数据和其标签。

前面提到过,我们需要衡量一个测试集样本到所有训练集样本的距离,我们现在来定义这个距离。这里我们使用欧式距离来衡量。

接下来我们来定义分类过程。为什么我们没有一个专门的训练过程呢?这是因为KNN只需要将训练集的信息记录下来,并不需要其它的真正意义上的训练操作;即KNN没有显示的训练过程,它是“懒惰学习”的代表,它在训练阶段只是把数据保存下来,训练时间开销为0,等收到测试样本后进行处理。

defclassify(self, sample, k=3):

"""分类

给定一个测试集样本进行分类

Args:

sample: 测试集样本,array类型

k: KNN中的k

Returns:

测试sample的标签

Raises:

pass

"""

#计算测试集样本到所有训练集的距离

dist = np.array([self.L2dist(sample, s)forsinself.samples])

#按照距离进行排序

ndx = dist.argsort()

#选取k个最大的

votes = defaultdict(int)

foriinrange(k):

label = self.labels[ndx[i]]

votes[label] +=1

returnmax(votes)

然后我们来构造数据集并利用上面完成的模型进行分类,且对于测试集得出准确率。

n=200

#构造两个正态分布训练数据集

class_1 =0.6* randn(n,2)

class_2 =1.2* randn(n,2) + np.array([5,1])

#训练数据集的标签

labels = np.hstack((np.ones(n),-np.ones(n)))

#构造两个正态分布测试数据集

test_1 =0.6* randn(n,2)

test_2 =1.2* randn(n,2) + np.array([5,1])

#测试数据集的标签

test_labels = np.hstack((np.ones(n),-np.ones(n)))

model = KnnClassifier(np.vstack((class_1, class_2)), labels)

#计算Accuracy

preds = []

foritem in np.vstack((test_1, test_2)):

pred = model.classify(item)

preds.append(pred)

preds = np.array(preds)

idx = (preds == test_labels)

accuracy = sum(idx) / len(idx)

print("Accuracy: ", accuracy)

最后考虑如何展示结果:对于两种类别我们分布用红色和绿色表示,且对于预测正确的样本,使用星花表示;对于预测错误的样本,使用圆表示。

plt.figure(figsize=(10,6))

plt.scatter(class_1_right[:,], class_1_right[:,1],marker ='*',

color ='r')

plt.scatter(class_1_wrong[:,], class_1_wrong[:,1],marker ='o',

color ='r')

plt.scatter(class_2_right[:,], class_2_right[:,1],marker ='*',

color ='g')

plt.scatter(class_2_wrong[:,], class_2_wrong[:,1],marker ='o',

color ='g')

plt.xticks([]),plt.yticks([])

plt.show()

最终的结果如图所示

编辑:蜗壳小肥 马宁馨

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180829G0N6A400?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券