机器学习系列----KNN分类

目录

前言

一.KNN算法的基本原理

二.KNN分类的实现

三.总结


前言

在机器学习领域,K近邻算法(K-Nearest Neighbors, KNN)是一种非常直观且常用的分类算法。它是一种基于实例的学习方法,也被称为懒学习(Lazy Learning),因为它在训练阶段不进行任何模型的构建,所有的计算都推迟到测试阶段进行。KNN分类的核心思想是:给定一个测试样本,找到在训练集中与其距离最近的K个样本,然后根据这K个样本的标签进行预测。

本文将介绍KNN算法的基本原理、如何实现KNN分类,以及在实际使用中需要注意的几点。

一.KNN算法的基本原理

KNN算法的基本流程如下:

(1)选择距离度量:通常我们使用欧氏距离来衡量两个样本点之间的距离,但也可以选择其他距离度量,如曼哈顿距离、余弦相似度等。

(2)选择K值:选择K的大小会直接影响分类效果。K值太小容易受到噪声数据的影响,而K值过大可能导致分类结果过于平滑。

(3)找到K个邻居:对于测试样本,根据距离度量选择与之最接近的K个样本。

(4)投票决策:通过这K个邻居的类别标签进行投票,测试样本的预测标签通常由出现频率最高的类别决定。

二.KNN分类的实现

在Python中,我们可以通过 sklearn 库来快速实现KNN分类器。下面是一个使用KNN进行分类的基本示例:

导入必要的库

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

 加载数据集
我们使用sklearn自带的鸢尾花(Iris)数据集,该数据集包含150个样本,4个特征,3个类别。

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征
y = iris.target  # 标签

数据集拆分
将数据集拆分为训练集和测试集,训练集占80%,测试集占20%。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

初始化KNN分类器并训练
我们创建一个KNN分类器实例,选择K=3。

# 初始化KNN分类器,设置K=3
knn = KNeighborsClassifier(n_neighbors=3)# 在训练集上训练模型
knn.fit(X_train, y_train)

测试与评估
我们可以使用测试集来评估模型的准确性。

# 在测试集上做预测
y_pred = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy * 100:.2f}%")

KNN分类的优缺点
优点
简单易懂:KNN算法简单直观,不需要复杂的训练过程。
无参数假设:KNN不需要像其他算法那样对数据做参数假设,适应性强。
适合多分类问题:KNN能够有效地处理多类问题。
缺点
计算开销大:在测试阶段需要计算每个测试点与所有训练数据的距离,计算量大,尤其在数据量较大时,效率较低。
对噪声敏感:由于KNN依赖于距离度量,数据中的噪声点可能会影响分类结果。
需要存储整个训练集:KNN算法是懒学习,需要将训练集存储在内存中,可能会对内存消耗产生较大影响。
K值的选择与调优
选择合适的K值是KNN分类器表现的关键。过小的K值(例如1)容易过拟合,受噪声影响较大,而过大的K值会导致欠拟合。常用的选择方法是通过交叉验证来选择K值。 

from sklearn.model_selection import cross_val_score# 使用交叉验证选择K值
k_values = range(1, 21)
cv_scores = [np.mean(cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5)) for k in k_values]# 输出不同K值的交叉验证得分
for k, score in zip(k_values, cv_scores):print(f"K={k}, Cross-validation accuracy={score:.2f}")

 

 

import numpy as np
import math
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt# 计算欧几里得距离
def euclidean_distance(x1, x2):"""计算欧几里得距离x1, x2: 两个输入样本,numpy数组或列表"""return math.sqrt(np.sum((x1 - x2) ** 2))# 计算曼哈顿距离
def manhattan_distance(x1, x2):"""计算曼哈顿距离(L1距离)x1, x2: 两个输入样本,numpy数组或列表"""return np.sum(np.abs(x1 - x2))# 计算闵可夫斯基距离
def minkowski_distance(x1, x2, p=3):"""计算闵可夫斯基距离,p为距离的阶数x1, x2: 两个输入样本,numpy数组或列表p: 阶数,通常为 1 (曼哈顿距离) 或 2 (欧几里得距离)"""return np.power(np.sum(np.abs(x1 - x2) ** p), 1/p)# KNN 分类器类
class KNN:def __init__(self, k=3, distance_metric='euclidean'):"""初始化 KNN 分类器k: 最近邻的个数distance_metric: 距离度量方式,'euclidean' 为欧几里得距离,'manhattan' 为曼哈顿距离,'minkowski' 为闵可夫斯基距离"""self.k = kself.distance_metric = distance_metricdef fit(self, X_train, y_train):"""训练模型,保存训练数据X_train: 训练特征数据y_train: 训练标签数据"""self.X_train = X_trainself.y_train = y_traindef predict(self, X_test):"""对测试数据进行预测X_test: 测试特征数据返回预测标签"""predictions = [self._predict(x) for x in X_test]return np.array(predictions)def _predict(self, x):"""对单个样本进行预测x: 输入样本返回预测标签"""# 根据指定的距离度量方法计算距离if self.distance_metric == 'euclidean':distances = [euclidean_distance(x, x_train) for x_train in self.X_train]elif self.distance_metric == 'manhattan':distances = [manhattan_distance(x, x_train) for x_train in self.X_train]elif self.distance_metric == 'minkowski':distances = [minkowski_distance(x, x_train) for x_train in self.X_train]else:raise ValueError(f"Unsupported distance metric: {self.distance_metric}")# 找到最近的 k 个邻居k_indices = np.argsort(distances)[:self.k]k_nearest_labels = [self.y_train[i] for i in k_indices]# 返回最常见的标签most_common = Counter(k_nearest_labels).most_common(1)return most_common[0][0]# 加载 Iris 数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据# 切分数据集为训练集和测试集,70% 训练集,30% 测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 标准化数据,以确保不同特征的数值范围一致
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 初始化 KNN 分类器
knn = KNN(k=5, distance_metric='minkowski')  # 使用闵可夫斯基距离
knn.fit(X_train, y_train)# 预测
predictions = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, predictions)
print(f"预测准确率: {accuracy * 100:.2f}%")# 输出混淆矩阵和分类报告
print("\n混淆矩阵:")
print(confusion_matrix(y_test, predictions))print("\n分类报告:")
print(classification_report(y_test, predictions))# 绘制预测结果与真实结果对比的图表
def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues):"""绘制混淆矩阵cm: 混淆矩阵classes: 类别名"""plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)# 绘制网格thresh = cm.max() / 2.for i, j in np.ndindex(cm.shape):plt.text(j, i, format(cm[i, j], 'd'),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')# 计算混淆矩阵
cm = confusion_matrix(y_test, predictions)# 绘制混淆矩阵
plt.figure(figsize=(8, 6))
plot_confusion_matrix(cm, classes=iris.target_names)
plt.show()# 展示一些预测结果
for i in range(5):print(f"实际标签: {iris.target_names[y_test[i]]}, 预测标签: {iris.target_names[predictions[i]]}")

 

代码功能解释:
计算不同距离:

euclidean_distance:计算欧几里得距离。
manhattan_distance:计算曼哈顿距离。
minkowski_distance:计算闵可夫斯基距离,p 代表阶数,通常取 1(曼哈顿距离)或者 2(欧几里得距离)。
KNN 分类器:

在 KNN 类中,你可以选择不同的距离度量方式 ('euclidean', 'manhattan', 'minkowski'),通过 k 来设定邻居个数。
fit 方法保存训练数据,predict 方法对每个测试数据点进行预测。
_predict 方法对单个测试样本进行预测,通过计算与训练集中所有样本的距离来选择最近的 k 个邻居。
数据预处理:

使用 StandardScaler 来标准化数据,使得每个特征具有零均值和单位方差。
模型评估:

使用 accuracy_score 计算预测准确率。
使用 confusion_matrix 和 classification_report 来展示混淆矩阵和分类性能报告(包括精确度、召回率、F1 分数等)。
通过 matplotlib 绘制混淆矩阵,帮助可视化模型的分类效果。
数据集:

使用 sklearn.datasets 中的 Iris 数据集。该数据集包含 150 个样本,分别属于 3 个不同的鸢尾花种类,每个样本有 4 个特征。
输出:
准确率:模型对测试集的预测准确性。
混淆矩阵:展示真实标签与预测标签的对比。
分类报告:包含精确度、召回率、F1 分数等详细指标。
混淆矩阵图表:图形化展示分类性能。
这个实现包含了更多的功能,并且通过使用不同的距离度量方法,你可以探索 KNN 在不同设置下的表现。

三.总结

K 最近邻(KNN)算法是一种简单直观的监督学习算法,广泛应用于分类和回归问题。其核心思想是,通过计算待预测样本与训练集中的每个样本之间的距离,选择距离最近的 k 个样本(即“邻居”),然后根据这些邻居的标签或数值来进行预测。在分类问题中,KNN 通过多数投票原则决定最终分类结果;在回归问题中,则通常是取邻居标签的平均值。KNN 算法的优势在于不需要显式的训练过程,其预测过程依赖于对整个训练数据集的存储和计算,因此适合动态更新数据的场景。然而,KNN 算法的计算复杂度较高,尤其在数据集较大时,预测过程可能变得非常缓慢。为了提高效率,通常需要对数据进行预处理,如归一化或标准化,以消除不同特征尺度差异的影响。此外,K 值的选择以及距离度量方法(如欧几里得距离、曼哈顿距离等)会显著影响模型的表现,K 值过小可能导致过拟合,过大则可能导致欠拟合。KNN 的一个主要缺点是它对高维数据(即特征空间维度较大)不太敏感,因为高维空间的距离度量往往会失去区分度,导致“维度灾难”。总的来说,KNN 是一个易于理解和实现的算法,适用于样本量不大且特征维度较低的问题,但在大数据集和高维数据上可能不够高效。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/60493.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习——优化算法、激活函数、归一化、正则化

文章目录 🌺深度学习面试八股汇总🌺优化算法方法梯度下降 (Gradient Descent, GD)动量法 (Momentum)AdaGrad (Adaptive Gradient Algorithm)RMSProp (Root Mean Square Propagation)Adam (Adaptive Moment Estimation)AdamW 优化算法总结 经验和实践建议…

YOLOv11实战宠物狗分类

本文采用YOLOv11作为核心算法框架,结合PyQt5构建用户界面,使用Python3进行开发。YOLOv11以其高效的特征提取能力,在多个图像分类任务中展现出卓越性能。本研究针对5种宠物狗数据集进行训练和优化,该数据集包含丰富的宠物狗图像样本…

星期-时间范围选择器 滑动选择时间 最小粒度 vue3

星期-时间范围选择器 功能介绍属性说明事件说明实现代码使用范例 根据业务需要,实现了一个可选择时间范围的周视图。用户可以通过鼠标拖动来选择时间段,并且可以通过快速选择组件来快速选择特定的时间范围。 如图: 功能介绍 时间范围选择&…

云岚到家 秒杀抢购

目录 秒杀抢购业务特点 常用技术方案 抢券 抢券界面 进行抢券 我的优惠券列表 活动查询 系统设计 活动查询分析 活动查询界面显示了哪些数据? 面向高并发如何提高活动查询性能? 如何保证缓存一致性? 数据流 Redis数据结构设计 如…

JavaWeb常见注解

1.Controller 在 JavaWeb 开发中,Controller是 Spring 框架中的一个注解,主要用于定义控制器类(Controller),是 Spring MVC 模式的核心组件之一。它表示该类是一个 Spring MVC 控制器,用来处理 HTTP 请求并…

光伏储能微电网协调控制器

安科瑞 Acrel-Tu1990 1. 产品介绍 ACCU-100微电网协调控制器是一款专为微电网、分布式发电和储能系统设计的智能协调控制设备。该装置能够兼容包括光伏系统、风力发电、储能系统以及充电桩等多种设备的接入。它通过全天候的数据采集与分析,实时监控光伏、风能、储…

【C++课程学习】:继承:默认成员函数

🎁个人主页:我们的五年 🔍系列专栏:C课程学习 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 构造函数 🍩默认构造函数(这里指的是编译器生成的构造函数)&#…

泷羽sec学习打卡-Linux基础2

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负 关于Linux的那些事儿-Base2 一、Linux-Base2linux有哪些目录呢?不同目录下有哪些具体的文件呢…

TCP拥塞控制

TCP拥塞控制(Congestion Control) 什么是拥塞控制? 拥塞控制(Congestion Control)主要针对整个网络中的数据传输速率进行调节,防止过多的数据注入网络中,这样可以使网络中的路由器或链路不致于过载,以避免…

Unity教程(十八)战斗系统 攻击逻辑

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程(零)Unity和VS的使用相关内容 Unity教程(一)开始学习状态机 Unity教程(二)角色移动的实现 Unity教程(三)角色跳跃的实现 Unity教程&…

自动驾驶合集(更新中)

文章目录 车辆模型控制路径规划 车辆模型 车辆模型基础合集 控制 控制合集 路径规划 规划合集

网站架构知识之Ansible进阶(day022)

1.handler触发器 应用场景:一般用于分发配置文件时候,如果配置文件有变化,则重启服务,如果没有变化,则不重启服务 案列01:分发nfs配置文件,若文件发生改变则重启服务 2.when判断 用于给ans运…

整理5个优秀的微信小程序开源项目

​ 一、Bee GitHub: https://github.com/woniudiancang/bee Bee是一个餐饮点餐商城微信小程序,是针对餐饮行业推出的一套完整的餐饮解决方案,实现了用户在线点餐下单、外卖、叫号排队、支付、配送等功能,完美的使餐饮行业更高效便捷&#x…

微服务链路追踪skywalking安装

‌SkyWalking是一个开源的分布式追踪系统,主要用于监控和分析微服务架构下的应用性能。‌ 它提供了分布式追踪、服务网格遥测分析、度量聚合和可视化一体化解决方案,特别适用于微服务、云原生架构和基于容器的环境(如Docker、K8s、Mesos&…

5G的发展演进

5G发展的驱动力 什么是5G [远程会议,2020年7月10日] 在来自世界各地的政府主管部门、电信制造及运营企业、研究机构约200多名会议代表和专家们的共同见证下,ITU-R WP 5D#35e远程会议宣布3GPP 5G技术(含NB-IoT)满足IMT-2020 5G技…

matlab建模入门指导

本文以水池中鸡蛋温度随时间的变化为切入点,对其进行数学建模并进行MATLAB求解,以更为通俗地进行数学建模问题入门指导。 一、问题简述 一个煮熟的鸡蛋有98摄氏度,将它放在18摄氏度的水池中,五分钟后鸡蛋的温度为38摄氏度&#x…

开源 2 + 1 链动模式、AI 智能名片、S2B2C 商城小程序在用户留存与品牌发展中的应用研究

摘要:本文以企业和个人品牌发展中至关重要的用户留存问题为切入点,结合管理大师彼得德鲁克对于企业兴旺发达的观点,阐述了用户留存对品牌营收的关键意义。在此基础上,深入分析开源 2 1 链动模式、AI 智能名片、S2B2C 商城小程序在…

搭建Python2和Python3虚拟环境

搭建Python3虚拟环境 1. 更新pip2. 搭建Python3虚拟环境第一步:安装python虚拟化工具第二步: 创建虚拟环境 3. 搭建Python2虚拟环境第一步:安装虚拟环境模块第二步:创建虚拟环境 4. workon命令管理虚拟机第一步:安装扩…

对接阿里云实人认证

对接阿里云实人认证-身份二要素核验接口整理 目录 应用场景 接口文档 接口信息 请求参数 响应参数 调试 阿里云openApi平台调试 查看调用结果 查看SDK示例 下载SDK 遇到问题 本地调试 总结 应用场景 项目有一个提现的场景,需要用户真实的身份信息。 …

基于卷积神经网络的车辆损坏部位检测系统带gui

项目源码获取方式见文章末尾! 600多个深度学习项目资料,快来加入社群一起学习吧。 《------往期经典推荐------》 项目名称 1.【基于CNN-RNN的影像报告生成】 2.【卫星图像道路检测DeepLabV3Plus模型】 3.【GAN模型实现二次元头像生成】 4.【CNN模型实现…