【人工智能-初级】第3章 k-最近邻算法(KNN):分类和Python实现

文章目录

    • 一、KNN算法简介
    • 二、KNN算法的工作原理
      • 2.1 欧氏距离
    • 三、K值的选择
    • 四、KNN算法的优缺点
      • 4.1 优点
      • 4.2 缺点
    • 五、Python实现KNN分类
      • 5.1 导入必要的库
      • 5.2 加载数据集并进行预处理
      • 5.3 创建KNN分类器并进行训练
      • 5.4 模型预测与评估
      • 5.5 可视化K值对模型性能的影响
    • 六、总结
      • 6.1 学习要点
      • 6.2 练习题

一、KNN算法简介

K-最近邻算法(K-Nearest Neighbors,简称KNN)是一种简单而有效的监督学习算法,主要用于分类和回归问题。在分类问题中,KNN算法通过计算测试样本与训练样本之间的距离,找到距离测试样本最近的 k 个训练样本,然后通过这 k 个样本的类别进行投票决定测试样本的类别。在回归问题中,KNN则是通过这些最近邻的平均值来预测输出。

KNN是一种基于实例的学习算法,它没有显式的模型训练过程,而是直接利用所有训练数据进行预测。正因为其简单和直观的特点,KNN广泛用于各种应用中,包括图像分类、文本分类和推荐系统等。

二、KNN算法的工作原理

KNN的工作原理主要包含以下几个步骤:

  1. 计算距离:计算测试样本与训练样本之间的距离,通常使用欧氏距离(Euclidean Distance),也可以使用曼哈顿距离(Manhattan Distance)或余弦相似度(Cosine Similarity)等。

  2. 选择最近的K个邻居:根据距离大小,选择与测试样本距离最近的 k 个训练样本。

  3. 投票决定类别:对于分类问题,KNN通过这 k 个邻居的类别进行投票,将类别最多的作为预测结果。对于回归问题,则通过最近 k 个点的平均值来得到预测值。

2.1 欧氏距离

欧氏距离是最常用的距离度量方法之一,用于度量两个样本点之间的直线距离。对于两个点 AB,其坐标分别为 (x1, y1)(x2, y2),欧氏距离的计算公式为:

d ( A , B ) = ( x 2 − x 1 ) 2 + ( y 2 − y 1 ) 2 d(A, B) = \sqrt{(x2 - x1)^2 + (y2 - y1)^2} d(A,B)=(x2x1)2+(y2y1)2

在多维空间中,同样可以使用欧氏距离,公式如下:

d ( A , B ) = ∑ i = 1 n ( x i A − x i B ) 2 d(A, B) = \sqrt{\sum_{i=1}^n (x_{i}^{A} - x_{i}^{B})^2} d(A,B)=i=1n(xiAxiB)2

其中,n 是样本特征的维数。

三、K值的选择

K值的选择对于KNN算法的效果非常重要。如果 K 值太小,模型容易受到噪声数据的影响,导致过拟合(overfitting);如果 K 值太大,模型则会变得过于平滑,导致欠拟合(underfitting)。因此,我们需要通过交叉验证等方法来选择最合适的 K 值。

通常,K值取奇数,特别是在二分类问题中,以避免投票结果出现平局的情况。

四、KNN算法的优缺点

4.1 优点

  1. 简单易懂:KNN算法的原理非常简单,容易理解和实现。
  2. 无训练过程:KNN不需要显式的模型训练,可以直接用于预测,适用于小规模数据集。
  3. 适用性广:KNN可以处理多分类问题和回归问题,并且适用于多种距离度量方法。

4.2 缺点

  1. 计算复杂度高:对于每一个测试样本,KNN都需要计算与所有训练样本的距离,当数据集很大时,计算开销非常大。
  2. 内存消耗大:KNN需要存储所有的训练数据,因此对内存的要求较高。
  3. 对特征尺度敏感:KNN对特征的尺度比较敏感,如果特征之间的尺度相差较大,可能会导致距离度量不准确,因此在使用KNN之前通常需要对数据进行归一化处理。

五、Python实现KNN分类

下面我们将通过Python实现一个简单的KNN分类模型,使用 scikit-learn 库来帮助我们完成这一任务。

5.1 导入必要的库

首先,我们需要导入一些必要的库:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
  • numpy:用于数值计算。
  • matplotlib:用于数据可视化。
  • sklearn.datasets:用于加载 Iris 数据集,这是一个经典的多分类数据集。
  • train_test_split:用于将数据集拆分为训练集和测试集。
  • StandardScaler:用于数据标准化。
  • KNeighborsClassifier:KNN分类器。
  • accuracy_score, confusion_matrix:用于评估模型的准确率和混淆矩阵。

5.2 加载数据集并进行预处理

我们使用 Iris 数据集,这是一个常用的多分类数据集,包含三类花(山鸢尾、变色鸢尾、维吉尼亚鸢尾),每类有50个样本。

# 加载Iris数据集
data = load_iris()
X = data.data
y = data.target# 将数据集拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 对特征进行标准化处理
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
  • load_iris():加载Iris数据集,X 是特征矩阵,y 是标签。
  • train_test_split:将数据集拆分为训练集和测试集,20%的数据用于测试。
  • StandardScaler:对数据进行标准化,使每个特征具有零均值和单位方差,减少特征间的尺度差异。

5.3 创建KNN分类器并进行训练

我们创建一个KNN分类器,设定 k=3,并用训练集进行模型训练。

# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=3)# 训练模型
knn.fit(X_train, y_train)
  • KNeighborsClassifier(n_neighbors=3):创建KNN分类器,并设置邻居数为3。
  • knn.fit(X_train, y_train):用训练数据拟合KNN模型。

5.4 模型预测与评估

使用测试集进行预测,并评估模型的性能。

# 对测试集进行预测
y_pred = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型的准确率: {accuracy * 100:.2f}%")# 混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
print("混淆矩阵:\n", conf_matrix)
  • knn.predict(X_test):对测试集进行预测。
  • accuracy_score:计算预测的准确率。
  • confusion_matrix:计算混淆矩阵,用于评估分类器在每个类别上的表现。

5.5 可视化K值对模型性能的影响

为了选择合适的K值,我们可以绘制不同K值下模型准确率的变化图。

# 尝试不同的K值,计算模型的准确率
k_values = range(1, 26)
accuracies = []for k in k_values:knn = KNeighborsClassifier(n_neighbors=k)knn.fit(X_train, y_train)y_pred = knn.predict(X_test)accuracies.append(accuracy_score(y_test, y_pred))# 绘制准确率变化图
plt.plot(k_values, accuracies, marker='o')
plt.xlabel('K值')
plt.ylabel('准确率')
plt.title('不同K值下的模型准确率')
plt.grid(True)
plt.show()

通过运行上述代码,我们可以看到不同 K 值对模型准确率的影响,从而选择最优的 K 值。

六、总结

KNN是一种简单直观的监督学习算法,适用于分类和回归问题。它通过计算测试样本与训练样本之间的距离,找到最近的K个邻居进行投票决定类别。在实现KNN时,我们需要注意特征的尺度和K值的选择。KNN的优点是简单、易于理解,但其计算复杂度较高,尤其在大规模数据集上。因此,KNN更适用于小规模数据集。

6.1 学习要点

  1. KNN原理:通过距离度量,找到测试样本的最近邻并投票决定其类别。
  2. 距离度量方法:欧氏距离是最常用的距离度量方法。
  3. K值选择:K值太小容易过拟合,K值太大容易欠拟合,可以通过交叉验证选出最优的K值。
  4. Python实现:可以使用 scikit-learn 库中的 KNeighborsClassifier 轻松实现KNN分类。

6.2 练习题

  1. 使用KNN算法对 Iris 数据集进行回归,尝试使用不同的K值,观察模型表现的变化。
  2. 尝试使用曼哈顿距离或余弦相似度作为KNN中的距离度量方法,比较其与欧氏距离的性能差异。
  3. 使用 sklearn.datasets 模块中的 load_wine 数据集,构建一个KNN分类模型,预测葡萄酒的类别。

如果您觉得本文有帮助,欢迎继续学习本专栏的其他内容,下一篇文章将为您介绍逻辑回归及其Python实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/57542.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服务器磁盘爆满?别慌,教你轻松清理!

服务器磁盘爆满?别慌,教你轻松清理! 简介 服务器磁盘空间告急,网站访问缓慢,甚至无法正常运行?别担心,这篇文章将为你提供一份详细的清理指南,帮助你快速释放服务器磁盘空间&#x…

【算法】Bellman-Ford单源最短路径算法(附动图)

目录 一、性质 二、思路 三、有边路限制的最短路 一、性质 适用于含有负权边的图(Dijkstra不适用) 更简单,但效率慢 如果对应路径存在负权回路则没有最短路径(可用于判断图中是否存在负权回路) 相比于spfa&#…

[分享] SQL在线编辑工具(好用)

在线SQL编写工具(无广告) - 在线SQL编写工具 - Web SQL - SQL在线编辑格式化 - WGCLOUD

物联网实训项目:绿色家居套件

1、基本介绍 绿色家居通过物联网技术将家中的各种设备连接到一起,提供家电控制、照明控制、电话远程控制、室内外遥控、防盗报警、环境监测、暖通控制、红外转发以及可编程定时控制等多种功能和手段。绿色家居提供全方位的信息交互功能,甚至为各种能源费…

使用DeepSpeed进行单机多卡训练

这是你提供的DeepSpeed单机多卡训练步骤的Markdown格式: 使用 DeepSpeed 进行单机多卡训练的主要步骤 1. 安装 DeepSpeed 确保你已经安装了 DeepSpeed 及其依赖: pip install deepspeed设置模型并集成 DeepSpeed 在模型的定义和训练循环中集成 Deep…

solana phantom NFT图片显示不出来?

solana phantom NFT图片显示不出来? 问题 同样是jpeg格式图片,一个phatom可以显示,一个不可以显示为什么,nft图片格式大小有要求吗? 问题分析 Phantom 官网有一些关于 NFT 集成的文档,其中可能会有关于图片大小限制…

049_python基于Python的热门微博数据可视化分析

目录 系统展示 开发背景 代码实现 项目案例 获取源码 博主介绍:CodeMentor毕业设计领航者、全网关注者30W群落,InfoQ特邀专栏作家、技术博客领航者、InfoQ新星培育计划导师、Web开发领域杰出贡献者,博客领航之星、开发者头条/腾讯云/AW…

@tarojs/components 和 taro-ui 中的组件之间的区别

1. 来源与用途: tarojs/components:Taro 官方提供的基础组件库,包含了微信小程序、H5 等不同平台的通用组件(如 View, Input, Button, Form 等)。这些组件是跨平台的,并提供了与微信小程序等平台原生组件类…

15分钟学Go 第7天:控制结构 - 条件语句

第7天:控制结构 - 条件语句 在Go语言中,控制结构是程序逻辑的重要组成部分。通过条件语句,我们可以根据不同的条件采取不同的行动。今天我们将详细探讨Go语言中的两种主要条件结构:if语句和switch语句。理解这些控制结构对于编写…

CTA-GAN:基于生成对抗网络对颈动脉和主动脉的非增强CT影像进行血管增强

写在前面 目前只分析了文章的大体内容和我个人认为的比较重要的细节,代码实现还没仔细看,后续有时间会补充代码细节部分。 文章地址:Generative Adversarial Network-based Noncontrast CT Angiography for Aorta and Carotid Arteries 代…

JAVA基础面试题准备

一些常见的JAVA基础题,面试中遇到过的会加*显示。 JAVA基础 1.Java中重载和重写的区别?* 2.int 和Integer类型这两个区别吗? 为什么需要有Integer类型: int和Integer类型的区别: 3.遍历list有那些方式吗?…

python如何提取MYSQL数据,并在完成数据处理后保存?

在现代数据驱动的世界中,数据分析已成为企业决策的重要组成部分。 Python作为一种强大的编程语言,因其丰富的库和简单的语法,广泛应用于数据分析、数据清洗和数据可视化等领域。 本文将详细介绍如何使用Python提取MySQL数据库中的数据,并进行数据分析、数据清洗、汇总等操…

【Linux】进程信号(下)

目录 一、信号的阻塞 1.1 信号在内核中的保存方式 1.2 sigset_t信号集 (1)信号集操作 (2)sigprocmask函数 (3)sigpending函数 二、信号的处理 2.1 用户态和内核态 2.2 重谈进程地址空间 三、信号…

盘点2024年4款高清稳定的Windows10录屏工具。

Windows10电脑录屏在生活当中还是挺重要的,无论是教育领域的制作教程,还是游戏玩家记录精彩瞬间,亦或是商务人士进行演示,录屏都能发挥巨大作用。如果设备自带的一些工具无法完成录屏需求的话,这里帮大家找了几款好用到…

AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务

AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务 RAG(Retrieval-Augmented Generation,如下图所示)检索增强生成,即大模型LLM在回答问题时,会先从大量的文档中检索出相关信息,然后基于这些检索出…

W25Q64的学习

24位地址意味着系统有24根地址线,每根地址线可以取两种状态(0或1),所以系统可以形成 2242^{24}224 个不同的地址组合。每个地址对应一个存储单元,通常是1字节。 在大多数现代计算机体系结构中,地址指向的…

万家数科:零售业务信息化融合的探索|OceanBase案例

本文作者:马琳,万家数科数据库专家。 万家数科商业数据有限公司,作为华润万家旗下的信息技术企业,专注于零售行业,在为华润万家提供服务的同时,也积极面向市场,为零售商及其生态系统提供全面的核…

挖矿病毒来势汹汹

病毒来了, 我的个人站点使用了 wordpress, 它的不知哪个漏洞让黑客攻入了我的站点 使用 top 命令看到了有不明进程始终占据了 100% 的 CPU snapshot 1 snapshot 2 通过以下 "三板斧"可以查杀这个进程 先用 top (shiftp) 查找占据 CPU 最多的进程根据其进程号 pid 查看…

【数据结构】宜宾大学-计院-实验四

栈和队列之(栈的基本操作) 实验目的:实验内容:实验结果:实验报告:(及时撰写实验报告):实验测试结果:代码实现1.0:(C/C)【含注释】代码…

QGIS之三十二DEM地形导出三维模型gltf

效果 1、准备数据 (1)dem.tif (2)dom.tif 2、qgis加载dem和dom数据 3、安装插件 插件步骤可以参考这篇文章 QGIS之二十四安装插件 安装了Qgis2threejs插件,结果