KNN算法:从思想到实现(附代码)

引言

K最近邻算法(K Nearest Neighbors, KNN)是一种简单而有效的机器学习算法,用于分类和回归问题。其核心思想基于“近朱者赤,近墨者黑”,即通过测量不同特征值之间的距离来进行分类或预测数值。本文将详细介绍KNN的核心概念、使用方法及其在sklearn中的实现,并展示如何自己动手编写一个简单的KNN算法。

新样本
寻找K个最近邻
分类问题:多数表决
回归问题:均值计算

KNN 核心思想

如何做一个样本的推理?首先需要明确这个问题是分类问题还是回归问题

  • 分类问题
    x0到底属于哪一类呢?
    对于一个新的样本x0,KNN首先找到与其最近的K个邻居,然后统计这些邻居中出现次数最多的类别作为x0的类别。

  • 回归问题
    x0到底是多少?
    对于回归问题,KNN同样找到与新样本x0​最近的K个邻居,但这次它计算的是这K个邻居标签的平均值,以此作为x0的预测值。

算法特点

惰性学习:KNN几乎没有训练过程,主要工作是在推理阶段完成,因此被称为惰性学习算法,(在推理时直接硬计算,这不属于典型的人工智能!)。

sklearn进行KNN操作

  • 分类问题示例
# 分类任务(鸢尾花数据集)
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 加载数据集
X, y = load_iris(return_X_y=True)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# 初始化并训练模型
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X=X_train, y=y_train)
# 预测及评估
y_pred = knn.predict(X=X_test)
acc = (y_pred == y_test).mean()
print(acc)
  • 回归问题示例
# 回归任务(波士顿房价数据集)
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor# 加载数据
data = pd.read_csv(filepath_or_buffer="boston_house_prices.csv", skiprows=1)
X = data.drop(columns=["MEDV"]).to_numpy()
y = data["MEDV"].to_numpy()# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)# 初始化并训练模型
knn = KNeighborsRegressor(n_neighbors=5)
knn.fit(X=X_train, y=y_train)# 预测及评估
y_pred = knn.predict(X=X_test)
# 计算的是平均绝对误差(MAE)
# 首先,通过abs(y_pred - y_test)计算预测值与实际值之间的绝对差异,即对于每一个测试样本,计算其预测值减去真实值的绝对值。
# 然后,使用.mean()函数计算这些绝对差异的平均值,得到的就是MAE。
mae = abs(y_pred - y_test).mean()
# 这段代码用于计算均方误差(MSE)
# 首先,通过(y_pred - y_test) ** 2计算预测值与实际值之差的平方,这样做的目的是为了放大较大误差的影响,并且消除负号(因为误差可能是正也可能是负)。
# 之后,使用.mean()函数计算这些平方差的平均值,得到的就是MSE。
mse = ((y_pred - y_test) ** 2).mean()
print(mae, mse)
# 4.756078431372549,51.74387450980392

波士顿房价数据集
在这里插入图片描述
MAE衡量的是预测值与实际值之间差距的平均水平,不考虑差距的方向(正负),只关心差距的大小。MAE 提供了一个直观的平均误差度量,易于理解,因为它直接表示了预测值与真实值之间的平均绝对距离。

MSE同样衡量了预测值与实际值之间的差距,但由于采用了平方,它对较大的误差更加敏感,这意味着如果模型的某些预测非常不准确,MSE会显著增大。

这两个指标都是评估回归模型性能的重要工具,它们帮助我们了解模型预测值与真实值之间的接近程度。选择哪个指标取决于具体的业务需求以及你希望如何权衡不同大小的预测误差。例如,如果你认为较大的误差应该被更重地惩罚,那么MSE可能是一个更好的选择;反之,如果你更关注整体的平均表现,那么MAE可能更适合。

手写KNN算法实现

  • 分类问题实现
from collections import Counter
import numpy as np
class MyKNeighborsClassifier(object):"""自定义KNN分类算法"""def __init__(self, n_neighbors=5):"""初始化方法:- 接收 超参数"""self.n_neighbors = n_neighborsdef fit(self, X, y):"""训练过程"""self.X = Xself.y = ydef predict(self, X):"""推理过程这段代码实现了 KNN 分类的核心逻辑:1.对于输入的新样本集 X 中的每一个样本 x,2.计算它与训练集中所有样本之间的欧几里得距离,3.找出距离最近的 K 个邻居,4.统计这些邻居的目标变量(标签)中哪个类别出现次数最多,5.将出现次数最多的类别作为当前新样本 x 的预测类别,6.将所有新样本的预测结果汇总成一个数组并返回。"""# X:[batch_size, num_features]# 第一步:寻找样本的 K个邻居# 第二步:对 K 个邻居的标签进行投票results = []# 循环遍历输入的新样本集 X 中的每一个样本 x。for x in X:#计算距离# self.X:训练数据集中的所有样本。# (self.X - x):计算训练集中的每个样本与当前新样本 x 在各个特征上的差异。# ** 2:对差异进行平方操作,确保所有值都是正数。# .sum(axis=1):沿着每个样本的所有特征维度求和,得到每个训练样本到新样本 x 的欧几里得距离的平方和。# ** 0.5:取平方根,得到实际的欧几里得距离。distance = ((self.X - x) ** 2).sum(axis=1) ** 0.5# 找到最近的 K 个邻居# distance.argsort():返回按升序排列的距离索引列表,即距离最近的样本排在前面。# [:self.n_neighbors]:选取前 self.n_neighbors 个最小距离对应的索引,即找到最近的 K 个邻居的索引。idxes = distance.argsort()[:self.n_neighbors]# 获取这些邻居的标签# self.y:训练集中对应样本的标签(目标变量)。# 使用之前找到的最近邻居的索引 idxes 来获取这些邻居的标签 labels。labels = self.y[idxes]# 对 K 个邻居的标签进行投票# 使用 collections.Counter 对这 K 个邻居的标签进行计数,统计每个标签出现的次数。# Counter(labels).most_common(1):返回一个列表,包含最常见的标签及其出现次数,按频率降序排列。most_common(1) 只返回出现次数最多的那个标签及其计数。# [0][0]:从上述列表中提取出最常见的标签(第一个元素的第一个值),作为最终的预测标签 final_label。final_label = Counter(labels).most_common(1)[0][0]results.append(final_label)return np.array(results)
## 使用自定义KNN进行分类
my_knn = MyKNeighborsClassifier(n_neighbors=5)
my_knn.fit(X=X_train, y=y_train)
y_pred = my_knn.predict(X=X_test)
print((y_pred == y_test).mean())
  • 回归问题实现
class MyKNeighborsRegressor(object):"""自定义KNN回归算法"""def __init__(self, n_neighbors=5):"""初始化方法:- 接收 超参数"""self.n_neighbors = n_neighborsdef fit(self, X, y):"""训练过程"""self.X = Xself.y = ydef predict(self, X):"""推理过程这段代码实现了 KNN 回归的核心逻辑:1.对于输入的新样本集 X 中的每一个样本 x,2.计算它与训练集中所有样本之间的欧几里得距离,3.找出距离最近的 K 个邻居,4.计算这些邻居的目标变量(标签)的平均值作为预测结果,5.将所有新样本的预测结果汇总成一个数组并返回。"""# X:[batch_size, num_features]# 第一步:寻找样本的 K个邻居# 第二步:对K个邻居的标签取均值results = []# 循环遍历输入的新样本集 X 中的每一个样本 xfor x in X:# 计算距离# self.X:训练数据集中的所有样本。# (self.X - x):计算训练集中的每个样本与当前新样本 x 在各个特征上的差异。# ** 2:对差异进行平方操作,以确保所有值都是正数。# .sum(axis=1):沿着每个样本的所有特征维度求和,得到每个训练样本到新样本 x 的欧几里得距离的平方和。# ** 0.5:取平方根,得到实际的欧几里得距离。distance = ((self.X - x) ** 2).sum(axis=1) ** 0.5# 找到最近的 K 个邻居# distance.argsort():返回按升序排列的距离索引列表,即距离最近的样本排在前面。# [:self.n_neighbors]:选取前 self.n_neighbors 个最小距离对应的索引,即找到最近的 K 个邻居的索引。idxes = distance.argsort()[:self.n_neighbors]# 获取这些邻居的标签# self.y:训练集中对应样本的标签(目标变量)。# 使用之前找到的最近邻居的索引 idxes 来获取这些邻居的标签 labels。labels = self.y[idxes]# 计算最终预测值# 对于回归问题,计算这 K 个最近邻居的标签值的平均值作为当前新样本 x 的预测值 final_label。final_label = labels.mean()results.append(final_label)return np.array(results)# 使用自定义KNN进行回归
knn = MyKNeighborsRegressor(n_neighbors=5)
knn.fit(X=X_train, y=y_train)
y_pred = knn.predict(X=X_test)
mae = abs(y_pred - y_test).mean()
mse = ((y_pred - y_test) ** 2).mean()
print(mae, mse)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/69135.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

专业学习|一文了解并实操自适应大邻域搜索(讲解代码)

一、自适应大邻域搜索概念介绍 自适应大邻域搜索(Adaptive Large Neighborhood Search,ALNS)是一种用于解决组合优化问题的元启发式算法。以下是关于它的详细介绍: -自适应大领域搜索的核心思想是:破坏解、修复解、动…

TensorFlow深度学习实战(6)——回归分析详解

TensorFlow深度学习实战(6)——回归分析详解 0. 前言1. 回归分析简介2. 线性回归2.1 简单线性回归2.2 多重线性回归2.3 多元线性回归 3. 构建基于线性回归的神经网络3.1 使用 TensorFlow 进行简单线性回归3.2 使用 TensorFlow 进行多元线性回归和多重线性…

2024年12月 Scratch 图形化(二级)真题解析 中国电子学会全国青少年软件编程等级考试

202412 Scratch 图形化(二级)真题解析 中国电子学会全国青少年软件编程等级考试 一、单选题(共25题,共50分) 第 1 题 小猫初始位置和方向如下图所示,下面哪个选项能让小猫吃到老鼠?( ) A. B. …

Java 面试合集(2024版)

种自己的花,爱自己的宇宙 目录 第一章-Java基础篇 1、你是怎样理解OOP面向对象??? 难度系数:? 2、重载与重写区别??? 难度系数:? 3、接口与抽象类的区别??? 难度系数:? 4、深拷贝与浅拷贝的理解??? 难度系数&…

Math Reference Notes: 符号函数

1. 符号函数的定义 符号函数(Sign Function) sgn ( x ) \text{sgn}(x) sgn(x) 是一个将实数 ( x ) 映射为其 符号值(即正数、负数或零)的函数。 它的定义如下: sgn ( x ) { 1 如果 x > 0 0 如果 x 0 − 1 如…

一文了解边缘计算

什么是边缘计算? 我们可以通过一个最简单的例子来理解它,它就像一个司令员,身在离炮火最近的前线,汇集现场所有的实时信息,经过分析并做出决策,及时果断而不拖延。 1.什么是边缘计算? 边缘计算…

108,【8】 buuctf web [网鼎杯 2020 青龙组]AreUSerialz

进入靶场 <?php // 包含 flag.php 文件&#xff0c;通常这个文件可能包含敏感信息&#xff0c;如 flag include("flag.php");// 高亮显示当前文件的源代码&#xff0c;方便查看代码结构和逻辑 highlight_file(__FILE__);// 定义一个名为 FileHandler 的类&#x…

《redis哨兵机制》

【redis哨兵机制导读】上一节介绍了redis主从同步的机制&#xff0c;但大家有没有想过一种场景&#xff0c;比如&#xff1a;主库突然挂了&#xff0c;那么按照读写分离的设计思想&#xff0c;此时redis集群只有从库才能提供读服务&#xff0c;那么写服务该如何提供&#xff0c…

【赵渝强老师】Spark RDD的依赖关系和任务阶段

Spark RDD彼此之间会存在一定的依赖关系。依赖关系有两种不同的类型&#xff1a;窄依赖和宽依赖。 窄依赖&#xff1a;如果父RDD的每一个分区最多只被一个子RDD的分区使用&#xff0c;这样的依赖关系就是窄依赖&#xff1b;宽依赖&#xff1a;如果父RDD的每一个分区被多个子RD…

开源数据分析工具 RapidMiner

RapidMiner是一款功能强大且广泛应用的数据分析工具&#xff0c;其核心功能和特点使其成为数据科学家、商业分析师和预测建模人员的首选工具。以下是对RapidMiner的深度介绍&#xff1a; 1. 概述 RapidMiner是一款开源且全面的端到端数据科学平台&#xff0c;支持从数据准备、…

蓝桥杯备考:二维前缀和算法模板题(二维前缀和详解)

【模板】二维前缀和 这道题如果我们暴力求解的话&#xff0c;时间复杂度就是q次查询里套两层循环最差的时候要遍历整个矩阵也就是O&#xff08;q*n*m) 由题目就是10的11次方&#xff0c;超时 二维前缀和求和的公式&#xff08;创建需要用到&#xff09;f[i][j]就是从&#xf…

3-track_hacker/2018网鼎杯

3-track_hacker 打开附件 使用Wireshark打开。过滤器过滤http,看里面有没有flag.txt 发现有 得到&#xff1a;eJxLy0lMrw6NTzPMS4n3TVWsBQAz4wXi base64解密 import base64 import zlibc eJxLy0lMrw6NTzPMS4n3TVWsBQAz4wXi decoded base64.b64decode(c) result zlib.deco…

第二十章 存储函数

目录 一、概述 二、语法 三、示例 一、概述 前面章节中&#xff0c;我们详细讲解了MySQL中的存储过程&#xff0c;掌握了存储过程之后&#xff0c;学习存储函数则肥仓简单&#xff0c;存储函数其实是一种特殊的存储过程&#xff0c;也就是有返回值的存储过程。存储函数的参数…

Linux:文件系统(软硬链接)

目录 inode ext2文件系统 Block Group 超级块&#xff08;Super Block&#xff09; GDT&#xff08;Group Descriptor Table&#xff09; 块位图&#xff08;Block Bitmap&#xff09; inode位图&#xff08;Inode Bitmap&#xff09; i节点表&#xff08;inode Tabl…

java求职学习day27

数据库连接池 &DBUtils 1.数据库连接池 1.1 连接池介绍 1) 什么是连接池 实际开发中 “ 获得连接 ” 或 “ 释放资源 ” 是非常消耗系统资源的两个过程&#xff0c;为了解决此类性能问题&#xff0c;通常情况我们 采用连接池技术&#xff0c;来共享连接 Connection 。…

机器学习--2.多元线性回归

多元线性回归 1、基本概念 1.1、连续值 1.2、离散值 1.3、简单线性回归 1.4、最优解 1.5、多元线性回归 2、正规方程 2.1、最小二乘法 2.2、多元一次方程举例 2.3、矩阵转置公式与求导公式 2.4、推导正规方程0的解 2.5、凸函数判定 成年人最大的自律就是&#xff1a…

Docker 部署 ClickHouse 教程

Docker 部署 ClickHouse 教程 背景 ClickHouse 是一个开源的列式数据库管理系统&#xff08;DBMS&#xff09;&#xff0c;主要用于在线分析处理&#xff08;OLAP&#xff09;。它专为大数据的实时分析设计&#xff0c;支持高速的查询性能和高吞吐量。ClickHouse 以其高效的数…

建表注意事项(2):表约束,主键自增,序列[oracle]

没有明确写明数据库时,默认基于oracle 约束的分类 用于确保数据的完整性和一致性。约束可以分为 表级约束 和 列级约束&#xff0c;区别在于定义的位置和作用范围 复合主键约束: 主键约束中有2个或以上的字段 复合主键的列顺序会影响索引的使用&#xff0c;需谨慎设计 添加…

Google C++ Style / 谷歌C++开源风格

文章目录 前言1. 头文件1.1 自给自足的头文件1.2 #define 防护符1.3 导入你的依赖1.4 前向声明1.5 内联函数1.6 #include 的路径及顺序 2. 作用域2.1 命名空间2.2 内部链接2.3 非成员函数、静态成员函数和全局函数2.4 局部变量2.5 静态和全局变量2.6 thread_local 变量 3. 类3.…

【HTML入门】Sublime Text 4与 Phpstorm

文章目录 前言一、环境基础1.Sublime Text 42.Phpstorm(1)安装(2)启动Phpstorm(3)“启动”码 二、HTML1.HTML简介(1)什么是HTML(2)HTML版本及历史(3)HTML基本结构 2.HTML简单语法(1)HTML标签语法(2)HTML常用标签(3)表格(4)特殊字符 总结 前言 在当今的软件开发领域&#xff0c…