机器学习系列----KNN分类

目录

前言

一.KNN算法的基本原理

二.KNN分类的实现

三.总结


前言

在机器学习领域,K近邻算法(K-Nearest Neighbors, KNN)是一种非常直观且常用的分类算法。它是一种基于实例的学习方法,也被称为懒学习(Lazy Learning),因为它在训练阶段不进行任何模型的构建,所有的计算都推迟到测试阶段进行。KNN分类的核心思想是:给定一个测试样本,找到在训练集中与其距离最近的K个样本,然后根据这K个样本的标签进行预测。

本文将介绍KNN算法的基本原理、如何实现KNN分类,以及在实际使用中需要注意的几点。

一.KNN算法的基本原理

KNN算法的基本流程如下:

(1)选择距离度量:通常我们使用欧氏距离来衡量两个样本点之间的距离,但也可以选择其他距离度量,如曼哈顿距离、余弦相似度等。

(2)选择K值:选择K的大小会直接影响分类效果。K值太小容易受到噪声数据的影响,而K值过大可能导致分类结果过于平滑。

(3)找到K个邻居:对于测试样本,根据距离度量选择与之最接近的K个样本。

(4)投票决策:通过这K个邻居的类别标签进行投票,测试样本的预测标签通常由出现频率最高的类别决定。

二.KNN分类的实现

在Python中,我们可以通过 sklearn 库来快速实现KNN分类器。下面是一个使用KNN进行分类的基本示例:

导入必要的库

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

 加载数据集
我们使用sklearn自带的鸢尾花(Iris)数据集,该数据集包含150个样本,4个特征,3个类别。

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征
y = iris.target  # 标签

数据集拆分
将数据集拆分为训练集和测试集,训练集占80%,测试集占20%。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

初始化KNN分类器并训练
我们创建一个KNN分类器实例,选择K=3。

# 初始化KNN分类器,设置K=3
knn = KNeighborsClassifier(n_neighbors=3)# 在训练集上训练模型
knn.fit(X_train, y_train)

测试与评估
我们可以使用测试集来评估模型的准确性。

# 在测试集上做预测
y_pred = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy * 100:.2f}%")

KNN分类的优缺点
优点
简单易懂:KNN算法简单直观,不需要复杂的训练过程。
无参数假设:KNN不需要像其他算法那样对数据做参数假设,适应性强。
适合多分类问题:KNN能够有效地处理多类问题。
缺点
计算开销大:在测试阶段需要计算每个测试点与所有训练数据的距离,计算量大,尤其在数据量较大时,效率较低。
对噪声敏感:由于KNN依赖于距离度量,数据中的噪声点可能会影响分类结果。
需要存储整个训练集:KNN算法是懒学习,需要将训练集存储在内存中,可能会对内存消耗产生较大影响。
K值的选择与调优
选择合适的K值是KNN分类器表现的关键。过小的K值(例如1)容易过拟合,受噪声影响较大,而过大的K值会导致欠拟合。常用的选择方法是通过交叉验证来选择K值。 

from sklearn.model_selection import cross_val_score# 使用交叉验证选择K值
k_values = range(1, 21)
cv_scores = [np.mean(cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5)) for k in k_values]# 输出不同K值的交叉验证得分
for k, score in zip(k_values, cv_scores):print(f"K={k}, Cross-validation accuracy={score:.2f}")

 

 

import numpy as np
import math
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt# 计算欧几里得距离
def euclidean_distance(x1, x2):"""计算欧几里得距离x1, x2: 两个输入样本,numpy数组或列表"""return math.sqrt(np.sum((x1 - x2) ** 2))# 计算曼哈顿距离
def manhattan_distance(x1, x2):"""计算曼哈顿距离(L1距离)x1, x2: 两个输入样本,numpy数组或列表"""return np.sum(np.abs(x1 - x2))# 计算闵可夫斯基距离
def minkowski_distance(x1, x2, p=3):"""计算闵可夫斯基距离,p为距离的阶数x1, x2: 两个输入样本,numpy数组或列表p: 阶数,通常为 1 (曼哈顿距离) 或 2 (欧几里得距离)"""return np.power(np.sum(np.abs(x1 - x2) ** p), 1/p)# KNN 分类器类
class KNN:def __init__(self, k=3, distance_metric='euclidean'):"""初始化 KNN 分类器k: 最近邻的个数distance_metric: 距离度量方式,'euclidean' 为欧几里得距离,'manhattan' 为曼哈顿距离,'minkowski' 为闵可夫斯基距离"""self.k = kself.distance_metric = distance_metricdef fit(self, X_train, y_train):"""训练模型,保存训练数据X_train: 训练特征数据y_train: 训练标签数据"""self.X_train = X_trainself.y_train = y_traindef predict(self, X_test):"""对测试数据进行预测X_test: 测试特征数据返回预测标签"""predictions = [self._predict(x) for x in X_test]return np.array(predictions)def _predict(self, x):"""对单个样本进行预测x: 输入样本返回预测标签"""# 根据指定的距离度量方法计算距离if self.distance_metric == 'euclidean':distances = [euclidean_distance(x, x_train) for x_train in self.X_train]elif self.distance_metric == 'manhattan':distances = [manhattan_distance(x, x_train) for x_train in self.X_train]elif self.distance_metric == 'minkowski':distances = [minkowski_distance(x, x_train) for x_train in self.X_train]else:raise ValueError(f"Unsupported distance metric: {self.distance_metric}")# 找到最近的 k 个邻居k_indices = np.argsort(distances)[:self.k]k_nearest_labels = [self.y_train[i] for i in k_indices]# 返回最常见的标签most_common = Counter(k_nearest_labels).most_common(1)return most_common[0][0]# 加载 Iris 数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据# 切分数据集为训练集和测试集,70% 训练集,30% 测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 标准化数据,以确保不同特征的数值范围一致
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 初始化 KNN 分类器
knn = KNN(k=5, distance_metric='minkowski')  # 使用闵可夫斯基距离
knn.fit(X_train, y_train)# 预测
predictions = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, predictions)
print(f"预测准确率: {accuracy * 100:.2f}%")# 输出混淆矩阵和分类报告
print("\n混淆矩阵:")
print(confusion_matrix(y_test, predictions))print("\n分类报告:")
print(classification_report(y_test, predictions))# 绘制预测结果与真实结果对比的图表
def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues):"""绘制混淆矩阵cm: 混淆矩阵classes: 类别名"""plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)# 绘制网格thresh = cm.max() / 2.for i, j in np.ndindex(cm.shape):plt.text(j, i, format(cm[i, j], 'd'),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')# 计算混淆矩阵
cm = confusion_matrix(y_test, predictions)# 绘制混淆矩阵
plt.figure(figsize=(8, 6))
plot_confusion_matrix(cm, classes=iris.target_names)
plt.show()# 展示一些预测结果
for i in range(5):print(f"实际标签: {iris.target_names[y_test[i]]}, 预测标签: {iris.target_names[predictions[i]]}")

 

代码功能解释:
计算不同距离:

euclidean_distance:计算欧几里得距离。
manhattan_distance:计算曼哈顿距离。
minkowski_distance:计算闵可夫斯基距离,p 代表阶数,通常取 1(曼哈顿距离)或者 2(欧几里得距离)。
KNN 分类器:

在 KNN 类中,你可以选择不同的距离度量方式 ('euclidean', 'manhattan', 'minkowski'),通过 k 来设定邻居个数。
fit 方法保存训练数据,predict 方法对每个测试数据点进行预测。
_predict 方法对单个测试样本进行预测,通过计算与训练集中所有样本的距离来选择最近的 k 个邻居。
数据预处理:

使用 StandardScaler 来标准化数据,使得每个特征具有零均值和单位方差。
模型评估:

使用 accuracy_score 计算预测准确率。
使用 confusion_matrix 和 classification_report 来展示混淆矩阵和分类性能报告(包括精确度、召回率、F1 分数等)。
通过 matplotlib 绘制混淆矩阵,帮助可视化模型的分类效果。
数据集:

使用 sklearn.datasets 中的 Iris 数据集。该数据集包含 150 个样本,分别属于 3 个不同的鸢尾花种类,每个样本有 4 个特征。
输出:
准确率:模型对测试集的预测准确性。
混淆矩阵:展示真实标签与预测标签的对比。
分类报告:包含精确度、召回率、F1 分数等详细指标。
混淆矩阵图表:图形化展示分类性能。
这个实现包含了更多的功能,并且通过使用不同的距离度量方法,你可以探索 KNN 在不同设置下的表现。

三.总结

K 最近邻(KNN)算法是一种简单直观的监督学习算法,广泛应用于分类和回归问题。其核心思想是,通过计算待预测样本与训练集中的每个样本之间的距离,选择距离最近的 k 个样本(即“邻居”),然后根据这些邻居的标签或数值来进行预测。在分类问题中,KNN 通过多数投票原则决定最终分类结果;在回归问题中,则通常是取邻居标签的平均值。KNN 算法的优势在于不需要显式的训练过程,其预测过程依赖于对整个训练数据集的存储和计算,因此适合动态更新数据的场景。然而,KNN 算法的计算复杂度较高,尤其在数据集较大时,预测过程可能变得非常缓慢。为了提高效率,通常需要对数据进行预处理,如归一化或标准化,以消除不同特征尺度差异的影响。此外,K 值的选择以及距离度量方法(如欧几里得距离、曼哈顿距离等)会显著影响模型的表现,K 值过小可能导致过拟合,过大则可能导致欠拟合。KNN 的一个主要缺点是它对高维数据(即特征空间维度较大)不太敏感,因为高维空间的距离度量往往会失去区分度,导致“维度灾难”。总的来说,KNN 是一个易于理解和实现的算法,适用于样本量不大且特征维度较低的问题,但在大数据集和高维数据上可能不够高效。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/60493.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习——优化算法、激活函数、归一化、正则化

文章目录 🌺深度学习面试八股汇总🌺优化算法方法梯度下降 (Gradient Descent, GD)动量法 (Momentum)AdaGrad (Adaptive Gradient Algorithm)RMSProp (Root Mean Square Propagation)Adam (Adaptive Moment Estimation)AdamW 优化算法总结 经验和实践建议…

vue登陆验证

导航守卫:直白的说,导航守卫就是路由跳转过程中的一些钩子函数,这些函数能让你在跳转过程中操作一些其他 的事的时机,这就是导航守卫。 比如最常见的登录权限验证,当用户满足条件时,才让其进入导航&…

YOLOv11实战宠物狗分类

本文采用YOLOv11作为核心算法框架,结合PyQt5构建用户界面,使用Python3进行开发。YOLOv11以其高效的特征提取能力,在多个图像分类任务中展现出卓越性能。本研究针对5种宠物狗数据集进行训练和优化,该数据集包含丰富的宠物狗图像样本…

星期-时间范围选择器 滑动选择时间 最小粒度 vue3

星期-时间范围选择器 功能介绍属性说明事件说明实现代码使用范例 根据业务需要,实现了一个可选择时间范围的周视图。用户可以通过鼠标拖动来选择时间段,并且可以通过快速选择组件来快速选择特定的时间范围。 如图: 功能介绍 时间范围选择&…

上海ABC行测试面试题回忆版本

11.14号去ABC面试,流程上先做个半个小时的笔试,然后是排队面试。这次做笔试的人很多,有JAVA,大数据,前端,测试,我是最后一批测试。现场没有敢拍照。面试的时候,一共8个面试官&#x…

云岚到家 秒杀抢购

目录 秒杀抢购业务特点 常用技术方案 抢券 抢券界面 进行抢券 我的优惠券列表 活动查询 系统设计 活动查询分析 活动查询界面显示了哪些数据? 面向高并发如何提高活动查询性能? 如何保证缓存一致性? 数据流 Redis数据结构设计 如…

【大数据测试HBase数据库 — 详细教程(含实例与监控调优)】

大数据测试HBase数据库 1. 环境准备与安装1.1 安装 HBase 环境1.1.1 下载与安装 HBase1.1.2 配置 HBase 2. 功能测试2.1 创建表和插入数据2.2 查询数据2.3 更新数据2.4 删除数据2.5 查看表格结构 3. 性能测试3.1 使用 HBase 自带的性能测试工具3.2 使用 YCSB 进行性能测试 4. 容…

JavaWeb常见注解

1.Controller 在 JavaWeb 开发中,Controller是 Spring 框架中的一个注解,主要用于定义控制器类(Controller),是 Spring MVC 模式的核心组件之一。它表示该类是一个 Spring MVC 控制器,用来处理 HTTP 请求并…

vue3+elementplus+虚拟树el-tree-v2+多条件筛选过滤filter-method

筛选条件 <el-inputv-model"searchForm.searchTreeValue"input"searchTreeData"style"flex: 1; margin-right: 0.0694rem"placeholder"请输入要搜索的设备"clearable/><imgclass"refresh-img"src"com_refres…

光伏储能微电网协调控制器

安科瑞 Acrel-Tu1990 1. 产品介绍 ACCU-100微电网协调控制器是一款专为微电网、分布式发电和储能系统设计的智能协调控制设备。该装置能够兼容包括光伏系统、风力发电、储能系统以及充电桩等多种设备的接入。它通过全天候的数据采集与分析&#xff0c;实时监控光伏、风能、储…

【C++课程学习】:继承:默认成员函数

&#x1f381;个人主页&#xff1a;我们的五年 &#x1f50d;系列专栏&#xff1a;C课程学习 &#x1f389;欢迎大家点赞&#x1f44d;评论&#x1f4dd;收藏⭐文章 目录 构造函数 &#x1f369;默认构造函数&#xff08;这里指的是编译器生成的构造函数&#xff09;&#…

React核心概念与特点

React是由Facebook开发并维护的一个用于构建用户界面的开源JavaScript库。它以其独特的组件化架构、高效的性能优化以及灵活的状态管理方式&#xff0c;在前端开发领域占据了重要地位。本文将对React的核心概念、特点以及关键知识点进行全面解析&#xff0c;以帮助读者更好地理…

泷羽sec学习打卡-Linux基础2

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负 关于Linux的那些事儿-Base2 一、Linux-Base2linux有哪些目录呢&#xff1f;不同目录下有哪些具体的文件呢…

TCP拥塞控制

TCP拥塞控制&#xff08;Congestion Control&#xff09; 什么是拥塞控制&#xff1f; 拥塞控制(Congestion Control)主要针对整个网络中的数据传输速率进行调节&#xff0c;防止过多的数据注入网络中&#xff0c;这样可以使网络中的路由器或链路不致于过载&#xff0c;以避免…

自闭症机构解析:去机构是否是最好的选择?

在探讨自闭症儿童的教育与康复问题时&#xff0c;一个常被提及的话题是&#xff1a;将孩子送入专业的自闭症干预机构&#xff0c;是否真的是最好的选择&#xff1f;这个问题&#xff0c;对于每一个自闭症家庭而言&#xff0c;都显得尤为沉重且复杂。星贝育园康复中心&#xff0…

Unity教程(十八)战斗系统 攻击逻辑

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程&#xff08;零&#xff09;Unity和VS的使用相关内容 Unity教程&#xff08;一&#xff09;开始学习状态机 Unity教程&#xff08;二&#xff09;角色移动的实现 Unity教程&#xff08;三&#xff09;角色跳跃的实现 Unity教程&…

自动驾驶合集(更新中)

文章目录 车辆模型控制路径规划 车辆模型 车辆模型基础合集 控制 控制合集 路径规划 规划合集

网站架构知识之Ansible进阶(day022)

1.handler触发器 应用场景&#xff1a;一般用于分发配置文件时候&#xff0c;如果配置文件有变化&#xff0c;则重启服务&#xff0c;如果没有变化&#xff0c;则不重启服务 案列01&#xff1a;分发nfs配置文件&#xff0c;若文件发生改变则重启服务 2.when判断 用于给ans运…

整理5个优秀的微信小程序开源项目

​ 一、Bee GitHub: https://github.com/woniudiancang/bee Bee是一个餐饮点餐商城微信小程序&#xff0c;是针对餐饮行业推出的一套完整的餐饮解决方案&#xff0c;实现了用户在线点餐下单、外卖、叫号排队、支付、配送等功能&#xff0c;完美的使餐饮行业更高效便捷&#x…

LeetCode【0029】两数相除

本文目录 1 中文题目2 求解方法&#xff1a;位运算2.1 方法思路2.2 Python代码2.3 复杂度分析 3 题目总结 1 中文题目 给定两个整数&#xff0c;被除数 dividend 和除数 divisor。将两数相除&#xff0c;要求 不使用 乘法、除法和取余运算。 整数除法应该向零截断&#xff0c…