前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >不要太强!全面总结 KNN !!

不要太强!全面总结 KNN !!

作者头像
Python编程爱好者
发布2023-11-29 13:54:01
3580
发布2023-11-29 13:54:01
举报

哈喽,我是Johngo~

今儿准备了 KNN 的文章给到大家,因为后台很多人问到了关于KNN相关的内容细节!

恰好需要的同学,认真读一读;未来需要的同学收个藏,给自己准备一下。期待反馈!

老规矩大家伙如果觉得近期文章还不错!欢迎大家点个赞、转个发,让更多的朋友看到。

那咱们就开始吧~

先来入个门

KNN(K-Nearest Neighbors)算法最擅长处理的场景是小到中等规模的数据集,特别是当数据集中的样本具有清晰的区分特征且样本间的相似度(或距离)能够有效地反映其类别关系时。这种算法适用于那些特征空间中样本分布较为紧密且具有一定规律的情况,如在一些分类和回归任务中,尤其是在医学、生物信息学和图像识别等领域,其中样本的特征向量直接影响到分类的准确性。

KNN 的主要优势在于它的简单性、直观性和对于数据分布没有严格假设的灵活性,但它在处理大规模数据集、高维数据或者样本特征间相似度不明显的情况时效率较低。

KNN 的一点原理

KNN 算法的工作原理相当直接:就是对于给定的测试数据点,算法搜索训练集中距离这个测试点最近的 K 个点,基于这些邻近点的已知类别来预测测试点的类别。对于分类任务,采用“投票制”(即选择这 K 个点中最常见的类别);对于回归任务,则通常采用平均值。

距离的计算可以采用不同的度量,如欧几里得距离(最常用)、曼哈顿距离、切比雪夫距离等。

欧几里得距离的公式如下:

d(p, q) = \sqrt{(q_1 - p_1)^2 + (q_2 - p_2)^2 + \cdots + (q_n - p_n)^2}

其中

p

q

是两个样本点,

n

是特征的数量。

模型训练

KNN 算法的“训练”过程实际上非常简单,因为它不需要构建一个显式的模型。所有的工作都发生在预测阶段。然而,确定最佳的 K 值以及选择合适的距离度量是 KNN 的关键。

一个简单示例

以下是使用 Python 和 Scikit-Learn 实现 KNN 分类器的一个示例。

数据集:著名的鸢尾花(Iris)

代码语言:javascript
复制
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

# 加载数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 创建 KNN 分类器实例
k = 3
knn = KNeighborsClassifier(n_neighbors=k)

# 训练模型
knn.fit(X_train, y_train)

# 进行预测
y_pred = knn.predict(X_test)

# 评估模型
print("Accuracy:", accuracy_score(y_test, y_pred))

# 可视化
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred)  # 仅取前两个特征进行可视化
plt.title("KNN Classification")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

KNN 的优缺点

优点

  1. 简单易懂:KNN 算法非常直观且易于理解。
  2. 无需训练:由于 KNN 是一种基于实例的学习,不需要显式的训练过程。
  3. 适应性强:适用于多类别问题,对数据的分布没有假设,既可以用于分类也可以用于回归。

缺点

  1. 计算成本高:KNN 在每次分类时都需要计算测试数据与每个训练样本的距离。
  2. 存储空间大:需要存储整个训练数据集。
  3. 对不平衡数据敏感:如果某些类别的样本远多于其他类别,KNN 会偏向于多数类。
  4. 维度灾难:在高维数据中性能下降。

尽管 KNN 在理论上简单直观,但在实际应用中,选择合适的 K 值、距离度量以及处理高维数据和大数据集等问题,仍需要细致的考虑。此外呢,随着数据规模的增长,KNN 的计算和存储成本可能会成为限制因素。因此,在实际应用中,经常需要与其他技术(如维度降低、数据预处理等)结合使用。

案例 1:手写数字识别

手写数字识别通常用于演示分类算法的效果。在这个案例中,我们使用 KNN 算法来识别手写数字。KNN 是一种简单且有效的分类方法,它通过查找最近的 K 个训练样本来预测新样本的类别。

数据集

我们将使用 scikit-learn 库中的手写数字数据集 (load_digits) 进行演示。这个数据集包含了 1797 个 8x8 像素的手写数字图像。

计算步骤

  1. 加载数据集并划分为训练集和测试集。
  2. 对于测试集中的每个样本,计算其与训练集中所有样本的欧几里得距离。
  3. 找出距离最近的 K 个训练样本。
  4. 根据这 K 个样本的类别,通过投票机制决定测试样本的类别。

代码演示

代码语言:javascript
复制
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt

# 加载数据集
digits = datasets.load_digits()

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=0)

# 创建 KNN 分类器
knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn.fit(X_train, y_train)

# 进行预测
y_pred = knn.predict(X_test)

# 显示几个测试图像及其预测结果
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, prediction in zip(axes, X_test, y_pred):
    ax.set_axis_off()
    image = image.reshape(8, 8)
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    ax.set_title(f'Predicted: {prediction}')
plt.show()

代码细节解释

  1. 数据加载与划分
    • 使用 datasets.load_digits() 加载数据集。
    • 使用 train_test_split() 将数据集划分为训练集和测试集。
  2. 创建 KNN 分类器
    • KNeighborsClassifier(n_neighbors=3) 创建一个 KNN 分类器实例,其中 n_neighbors=3 表示选择最近的 3 个邻居。
  3. 模型训练
    • 使用 fit() 方法训练模型。这里没有显式的训练过程,KNN 只是存储训练数据。
  4. 模型预测与可视化
    • 使用 predict() 方法对测试集进行预测。
    • 使用 Matplotlib 展示测试图像及其预测的类别。

总结一下

KNN 在手写数字识别上表现良好,原因是它能有效地找出相似的样本。然而,KNN 的计算成本随着数据集大小的增加而显著增加。此外,它对数据缩放敏感,因此在实际应用中可能需要先进行数据标准化或归一化处理。

案例2:文本分类

文本分类是自然语言处理中的一个基本任务,它涉及将文本文档分配给一个或多个类别。

这个案例中,依然使用 KNN 算法来执行文本分类任务。我们将以新闻组文档分类为例,目标是根据文档内容将其分类到不同的新闻组。

在文本分类中,KNN 算法通过查找与测试文档最接近的 K 个训练文档,然后根据这些邻近文档的已知类别来预测测试文档的类别。

一点公式

在文本分类中,常用的距离度量是余弦相似度,其公式如下:

\text{Cosine Similarity}(A, B) = \frac{A \cdot B}{\|A\| \|B\|} = \frac{\sum_{i=1}^{n} A_i \times B_i}{\sqrt{\sum_{i=1}^{n} A_i^2} \times \sqrt{\sum_{i=1}^{n} B_i^2}}

其中,

A

B

是两个文档的向量表示,

n

是向量的维度。

数据集

我们将使用 scikit-learn 库中的 '20 Newsgroups' 数据集进行演示。这个数据集是一个包含约 20,000 个新闻组文档的集合,分布在 20 个不同的新闻组。

计算步骤

  1. 加载数据集并划分为训练集和测试集。
  2. 将文本数据转换为向量表示(例如使用 TF-IDF)。
  3. 对于测试集中的每个文档,计算其与训练集中所有文档的相似度。
  4. 找出相似度最高的 K 个训练文档。
  5. 根据这 K 个文档的类别,通过投票机制决定测试文档的类别。

代码演示

代码语言:javascript
复制
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

# 加载数据集
newsgroups = fetch_20newsgroups(subset='all')
X, y = newsgroups.data, newsgroups.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 创建一个 TfidfVectorizer 和 KNN 的 pipeline
knn_pipeline = make_pipeline(TfidfVectorizer(), KNeighborsClassifier(n_neighbors=3))

# 训练模型
knn_pipeline.fit(X_train, y_train)

# 进行预测
y_pred = knn_pipeline.predict(X_test)

# 评估模型
print("Classification report:\n", classification_report(y_test, y_pred))

# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)

# 绘制混淆矩阵
plt.figure(figsize=(10, 10))
sns.heatmap(cm, annot=True, fmt="d", cmap='Blues')
plt.title('Confusion Matrix')
plt.ylabel('Actual Labels')
plt.xlabel('Predicted Labels')
plt.show()

代码细节

数据加载与划分

  • 使用 fetch_20newsgroups() 加载 '20 Newsgroups' 数据集。
  • 使用 train_test_split() 将数据集划分为训练集和测试集。

文本向量化

  • TfidfVectorizer() 被用来将文本数据转换为 TF-IDF 特征。

创建 KNN 分类器

  • KNeighborsClassifier(n_neighbors=3) 创建一个 KNN 分类器实例,其中 n_neighbors=3 表示选择最相似的 3 个邻居。

模型训练和预测

  • 使用 fit() 方法训练模型,并用 predict() 进行预测。

模型评估

代码语言:javascript
复制
Classification report:
               precision    recall  f1-score   support

           0       0.68      0.91      0.78       163
           1       0.60      0.82      0.69       190
           2       0.63      0.73      0.68       200
           3       0.65      0.71      0.68       196
           4       0.73      0.71      0.72       201
           5       0.83      0.72      0.77       198
           6       0.78      0.58      0.66       206
           7       0.85      0.86      0.86       177
           8       0.89      0.80      0.84       189
           9       0.83      0.88      0.85       171
          10       0.90      0.92      0.91       233
          11       0.81      0.92      0.86       190
          12       0.85      0.70      0.77       207
          13       0.92      0.81      0.86       203
          14       0.91      0.90      0.90       191
          15       0.87      0.77      0.82       198
          16       0.90      0.85      0.87       155
          17       0.90      0.94      0.92       196
          18       0.86      0.80      0.83       170
          19       0.82      0.71      0.76       136

    accuracy                           0.80      3770
   macro avg       0.81      0.80      0.80      3770
weighted avg       0.81      0.80      0.80      3770
  • 使用 classification_report() 来评估模型的性能。

绘制图像:

  • 在这段代码中,confusion_matrix 用于计算混淆矩阵,而 seaborn.heatmap 用于绘制矩阵图。每个格子的数值表示实际类别和预测类别的匹配数量,颜色的深浅表示数量的大小。

总结一下

使用 KNN 进行文本分类可以获得不错的效果,特别是当文档集较小且文档之间的区分度较高时。然而,由于 KNN 算法在预测阶段需要与训练集中的每个样本进行比较,因此在大型数据集上可能会有较高的计算成本。此外,合适的 K 值选择和距离度量方法对于优化性能至关重要。

最后绘制混淆矩阵,显示实际类别与预测类别之间的关系。混淆矩阵可以帮助我们理解分类模型在各个类别上的表现,包括哪些类别容易被混淆。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2023-11-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Johngo学长 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 先来入个门
    • KNN 的一点原理
      • 模型训练
        • 一个简单示例
          • KNN 的优缺点
          • 案例 1:手写数字识别
            • 数据集
              • 计算步骤
                • 代码演示
                  • 代码细节解释
                    • 总结一下
                    • 案例2:文本分类
                      • 一点公式
                        • 数据集
                          • 计算步骤
                            • 代码演示
                              • 代码细节
                                • 总结一下
                                相关产品与服务
                                腾讯云服务器利旧
                                云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
                                领券
                                问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档