机器学习——K最近邻算法(KNN)

机器学习——K最近邻算法(KNN)

文章目录

  • 前言
  • 一、原理
  • 二、距离度量方法
    • 2.1. 欧氏距离
    • 2.2. 曼哈顿距离
    • 2.3. 闵可夫斯基距离
    • 2.4. 余弦相似度
    • 2.5. 切比雪夫距离
    • 2.6. 马哈拉诺比斯距离
    • 2.7. 汉明距离
  • 三、在MD编辑器中输入数学公式(额外)
  • 四、代码实现
    • 2.1. 用KNN算法进行分类
    • 2.2. 用KNN算法进行回归
  • 五、模型的保存和加载
  • 总结


前言

在传统机器学习中,KNN算法是一种基于实例的学习算法,能解决分类和回归问题,而本文将介绍一下KNN即K最近邻算法。


在这里插入图片描述

一、原理

K最近邻(KNN)算法是一种基于实例的学习算法,用于分类和回归问题。它的原理是根据样本之间的距离来进行预测
核心思想是通过找到与待分类样本最相似的K个训练样本,来确定待分类样本的类别或者预测其数值。

假设存在一个样本数据集(训练集),并且样本集中每个数据都存在标签(即知道样本集中数据的分类情况)
KNN算法的步骤如下:

  1. 计算距离:对于给定的未知样本(没有标签值的测试集),计算它与训练集中每个样本的距离。常用的距离度量方法有欧氏距离、曼哈顿距离等。

  2. 选择K值:选择一个合适的K值,即要考虑的最近邻的数量。

  3. 选择最近邻:从训练集中选择K个距离最近的样本。

  4. 进行投票或计算平均值:对于分类问题,根据最近邻的标签进行投票,选取票数最多的标签作为预测结果。对于回归问题,根据最近邻的值计算平均值作为预测结果。

按我的理解其实就是将待分类的样本与训练集中的每个样本去计算距离,然后从训练集中选择K个与待分类样本最靠近的几个样本,然后再根据选取得最靠近的几个样本得标签值进行投票来分类。
对于回归问题,则统计K个最近邻样本的数值,然后通过平均或加权平均的方式计算出待分类样本的数值。

如图所示(可看出K值的选择对结果有很大的影响):
在这里插入图片描述

当K=3时,根据距离计算,待分类的样本点被划为黄色那一类;(因为2>1)
当K=5时, 根据距离计算,待分类的样本点被划为红色那一类;(因为3>2)

二、距离度量方法

参考文献
https://zhuanlan.zhihu.com/p/354289511

以下是一些常见的距离度量方法:

2.1. 欧氏距离

欧氏距离(Euclidean Distance):欧氏距离是最常见的距离度量方法,它是两个向量之间的直线距离。对于两个n维向量x和y,欧氏距离的计算公式为:

d ( x , y ) = ∑ i = 1 n ( x i − y i ) 2 d(x,y) = \sqrt{\sum_{i=1}^{n}(x_{i}-y_{i})^{2}} d(x,y)=i=1n(xiyi)2

其中,xi和yi分别表示向量x和y的第i个元素。
例如当n = 2 时,这就是中学学的二维平面中两点之间距离公式的计算了。

2.2. 曼哈顿距离

曼哈顿距离(Manhattan Distance):曼哈顿距离是两个向量之间的城市街区距离,也称为L1距离。对于两个n维向量x和y,曼哈顿距离的计算公式为:
d ( x , y ) = ∑ i = 1 n ∣ x i − y i ∣ d(x,y) = \sum_{i=1}^{n} |x_{i} -y_{i}| d(x,y)=i=1nxiyi

2.3. 闵可夫斯基距离

闵可夫斯基距离(Minkowski Distance):闵可夫斯基距离是欧氏距离和曼哈顿距离的一般化形式,它可以根据参数p的不同取值变化为不同的距离度量方法。对于两个n维向量x和y,闵可夫斯基距离的计算公式为:
d ( x , y ) = ∑ i = 1 n ∣ x i − y i ∣ p p d(x,y) = \sqrt[p]{\sum_{i=1}^{n}|x_{i}-y_{i}|^{p}} d(x,y)=pi=1nxiyip

其中,xi和yi分别表示向量x和y的第i个元素,p为参数,当p=2时,闵可夫斯基距离等价于欧氏距离;当p=1时,闵可夫斯基距离等价于曼哈顿距离。

2.4. 余弦相似度

余弦相似度(Cosine Similarity):余弦相似度是衡量两个向量方向相似程度的度量方法,它计算两个向量之间的夹角余弦值。对于两个n维向量x和y,余弦相似度的计算公式为:

c o s ( θ ) = ∑ i = 1 n ( x i ∗ y i ) ∑ i = 1 n ( x i ) 2 ∗ ∑ i = 1 n ( y i ) 2 cos(\theta ) = \frac{\sum_{i=1}^{n}(x_{i} * y_{i})}{\sqrt{\sum_{i=1}^{n}(x_{i})^{2}}*\sqrt{\sum_{i=1}^{n}(y_{i})^{2}}} cos(θ)=i=1n(xi)2 i=1n(yi)2 i=1n(xiyi)

2.5. 切比雪夫距离

切比雪夫距离(Chebyshev Distance):切比雪夫距离是两个向量之间的最大绝对差距。对于两个n维向量x和y,切比雪夫距离的计算公式为:
d ( x , y ) = m a x i ( ∣ p i − q i ∣ ) d(x,y) = \underset{i}{max}(|p_{i} -q_{i}|) d(x,y)=imax(piqi)

2.6. 马哈拉诺比斯距离

马哈拉诺比斯距离(Mahalanobis Distance):马哈拉诺比斯距离是一种考虑特征之间相关性的距离度量方法。它首先通过计算协方差矩阵来衡量特征之间的相关性,然后计算两个向量在经过协方差矩阵变换后的空间中的欧氏距离。对于两个n维向量x和y,马哈拉诺比斯距离的计算公式为:

d = ( x ⃗ − y ⃗ ) T S − 1 ( x ⃗ − y ⃗ ) d = \sqrt{(\vec{x}-\vec{y})^{T}S^{-1}(\vec{x}-\vec{y})} d=(x y )TS1(x y )

其中,x和y分别表示向量x和y,S为x和y的协方差矩阵。

2.7. 汉明距离

汉明距离(Hamming Distance):汉明距离是用于比较两个等长字符串之间的差异的度量方法。对于两个等长字符串x和y,汉明距离的计算公式为:
d = 1 N ∑ i = 1 n 1 x i ≠ y i d = \frac{1}{N}\sum_{i=1}^{n}1_{x_{i}\neq y_{i}} d=N1i=1n1xi=yi

三、在MD编辑器中输入数学公式(额外)

在使用markdown文本编辑器时,对于数学公式的书写一般是使用到LaTeX这个排版系统,基于latex语法构建数学公式。

这对我这种刚开始接触的初学者是不友好的(在这之前还要学习LateX语法…)。
$$

$$
在这之间填入数学公式对应的LaTeX语法,就能获得对应的数学公式

对应的LaTeX语法可以从另一个编辑器——富文本编辑器 中获得:

在这里插入图片描述
将LaTeX公式复制过来,d(x,y) = \sqrt{\sum_{i=1}{n}(x_{i}-y_{i}){2}}
$$

$$
放于这两个之间,可以得到对应公式:

d ( x , y ) = ∑ i = 1 n ( x i − y i ) 2 d(x,y) = \sqrt{\sum_{i=1}^{n}(x_{i}-y_{i})^{2}} d(x,y)=i=1n(xiyi)2

嗯…,其实我也不太清楚为何我的Mardown编辑器中没有像富文本编辑器中那样的公式编辑器,(或许是要下载插件吗?),不用管这么多,能用就行。

四、代码实现

2.1. 用KNN算法进行分类

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=3)
#metric= "minkowski",距离度量默认是闵可夫斯基距离# 拟合模型
knn.fit(X_train, y_train)# 预测
y_pred = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
Accuracy: 0.9833333333333333

2.2. 用KNN算法进行回归

from sklearn.neighbors import KNeighborsRegressor
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 加载数据集
boston = load_boston()
X = boston.data
y = boston.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建KNN回归器
knn = KNeighborsRegressor(n_neighbors=3)
# 拟合模型
knn.fit(X_train, y_train)
# 预测
y_pred = knn.predict(X_test)# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print("MSE:", mse)
MSE: 21.65955337690632#计算R方值
print(knn.score(X_test,y_test))
0.7046442656646525#绘图展示
import matplotlib.pyplot as plt
plt.style.use("ggplot")
plt.scatter(y_test,y_pred)
plt.plot([min(y_test),max(y_test)],[min(y_pred),max(y_pred)],"k--",color = "green", lw = 2,)
plt.xlabel("y_test")
plt.ylabel("y_pred")
plt.show()

在这里插入图片描述

均方误差:

M S E = ∑ i = 1 n ( y t − y p ) 2 n MSE = \frac{\sum_{i=1}^{n}(y_t - y_p)^{2}}{n} MSE=ni=1n(ytyp)2

再用线性回归试一下:

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(X_train, y_train)
coefficients = model.coef_
intercept = model.intercept_# 构建回归公式
equation = f"y = {intercept} + {coefficients[0]}*x1 + {coefficients[1]}*x2 + ..."# 计算R^2值
r2_score = model.score(X_test, y_test)
print("R^2值:", r2_score)
R^2值: 0.6687594935356289

这些模型都是十分简单的模型,还未经过参数的调优和算法的优化。

五、模型的保存和加载

#模型的保存和加载
import pickle
with open("model.pkl","wb") as f:pickle.dump(knn,f)
with open("model.pkl","rb") as f:knn_loaded = pickle.load(f)print(knn_loaded.score(X_test,y_test))
0.7046442656646525

总结

本文从KNN算法的原理:(根据样本之间的距离来预测)出发,介绍了一些常见的距离度量方法,另外也介绍了一下在Markdown编辑器中输入数学公式,最后就是KNN算法在python中的分类和回归代码的实现。最后的最后就是模型的保存和加载。

道可道,非常道;名可名,非常名。

–2023-9-10 筑基篇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/74697.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

门面设计模式

github:GitHub - QiuliangLee/pattern: 设计模式 1 什么是门面设计模式 门面设计模式是一种软件设计模式,也被称为外观(Facade)模式。它提供了一个简单的接口,让客户端能够访问复杂系统中的一组接口。通过门面模式&a…

企业架构LNMP学习笔记15

客户端缓存: B/S架构里,Browser是浏览器,就是客户端。 客户端缓存告知浏览器获取服务段的信息是在某个区间时间段是有效的。 每次请求从服务器拿一遍数据,数据没有变化,影响带宽,影响时间。刷新又要去加载…

Java中快速排序的优化技巧:随机取样、三数取中和插入排序

目录 快速排序基础 优化1:随机取样 优化2:三数取中 优化3:插入排序 总结: 快速排序(Quick Sort)是一种高效的排序算法,它的平均时间复杂度为O(n log n)。然而,在某些情况下&…

rust中的reborrow和NLL

reborrow 我们看下面这段代码 fn main() {let mut num 123;let ref1 &mut num; // 可变引用add(ref1); // 传递给 add 函数println!("{}", ref1); // 再次使用ref1 }fn add(num: &mut i32) {println!("{}", *num); }我们…

机器学习——生成分类数据的坐标系边界需要用到的技术方法

0、前言: 如果遇到一种应用场景需要将x轴数据和y轴数据所有点映射到坐标系中,需要得到坐标系中x和y映射的坐标点,就要用到meshgrid把x和y映射到坐标系中,然后把得到的结果用ravel把结果转成一维的。用np.c_()把x数据和y数据堆叠在…

Python实现猎人猎物优化算法(HPO)优化BP神经网络回归模型(BP神经网络回归算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 猎人猎物优化搜索算法(Hunter–prey optimizer, HPO)是由Naruei& Keynia于2022年提出的一种最新的…

spring boot-Resolved element must not contain multiple elements 警告

首先强调一下,此问题不影响程序运行。 报错信息: package org.springframework.util; ...public abstract class Assert ...public static void state(boolean expression, String message) {if (!expression) {throw new IllegalStateException(messa…

人工智能TensorFlow PyTorch物体分类和目标检测合集【持续更新】

1. 基于TensorFlow2.3.0的花卉识别 基于TensorFlow2.3.0的花卉识别Android APP设计_基于安卓的花卉识别_lilihewo的博客-CSDN博客 2. 基于TensorFlow2.3.0的垃圾分类 基于TensorFlow2.3.0的垃圾分类Android APP设计_def model_load(img_shape(224, 224, 3)_lilihewo的博客-CS…

flink 端到端一致性

背景 我们经常会混淆flink提供的状态一致性保证和数据端到端一致性保证的关系,总以为他们表达的是同一个意思,事实上,他们不是一个含义,flink只能保证其维护的内部状态的一致性,而数据端到端的一致性需要数据源&#…

数学建模:多目标优化算法

🔆 文章首发于我的个人博客:欢迎大佬们来逛逛 数学建模:多目标优化算法 多目标优化 分别求权重方法 算法流程: 两个目标权重求和,化为单目标函数,然后求解最优值 min ⁡ x ∑ i 1 m w i F i ( x ) s.…

I - Protecting the Flowers

Farmer John went to cut some wood and left N (2 ≤ N ≤ 100,000) cows eating the grass, as usual. When he returned, he found to his horror that the cluster of cows was in his garden eating his beautiful flowers. Wanting to minimize the subsequent damage, F…

南大通用数据库-Gbase-8a-学习-38-常规日志(general log)

目录 一、环境信息 二、general log的用途 三、general log相关参数介绍 四、LInux环境模拟实验 1、查看参数配置 2、开启general log 3、输入测试SQL 4、查看文件级别general log 5、改为表级别general log 6、再次输入测试SQL 7、查看gbase.general_log 一、环境信…

微信小程序开发教学系列(4)- 抖音小程序组件开发

章节四:抖音小程序组件开发 在本章中,我们将深入探讨抖音小程序的组件开发。组件是抖音小程序中的基本构建块,它们负责展示数据和与用户交互。了解组件的开发方法和使用技巧是进行抖音小程序开发的重要一步。 4.1 抖音小程序的基本组件 抖…

iOS接入IJKPlayer遇到的问题汇总

这里有一个我自己编译的IJKMediaFramework,能解决目前Github上反馈很多常见的IJKPlayer使用问题(包含播放异常,UI主线程Crash等),替换自己项目中的IJKMediaFramework即可链接: https://pan.baidu.com/s/1UO-YfN_1YIDOX81bgW8bag?pwdvq4u 提取…

题目:2695.包装数组

​​题目来源: leetcode题目,网址:2695. 包装数组 - 力扣(LeetCode) 解题思路: 按要求模拟即可。 解题代码: /*** param {number[]} nums*/ var ArrayWrapper function(nums) {this.valuenu…

HTML事件列表

鼠标事件 属性描述DOMonclick当用户点击某个对象时调用的事件句柄。2oncontextmenu在用户点击鼠标右键打开上下文菜单时触发ondblclick当用户双击某个对象时调用的事件句柄。2onmousedown鼠标按钮被按下。2onmouseenter当鼠标指针移动到元素上时触发。2onmouseleave当鼠标指针…

安装samba服务器

1.实验目的 (1)了解SMB和NETBIOS的基本原理 (2)掌握Windows和Linux之间,Linux系统之间文件共享的基本方法。 2.实验内容 (1)安装samba服务器。 (2)配置samba服务器的…

unity 控制Dropdown的Arrow箭头变化

Dropdown打开下拉菜单会以“Template”为模板创建一个Dropdown List,在“Template”上添加一个脚本在Start()中执行下拉框打开时的操作,在OnDestroy()中执行下拉框收起时的操作即可。 效果代码如下用于控制Arrow旋转可以根据自己的想法进行修改&#xff…

【RuoYi移动端】uni-app中实现生成二维码功能(代码示例)

完整示例&#xff1a; <template><view><view class"titleBar">执法检查“通行码”信息</view><view class"twoCode"><canvas canvas-id"qrcode"></canvas></view></view> </templat…

HashMap知识总结

HashMap: 1. 扰动函数hash值右移16位与原hash值做异或运算得出的新hash值散列程度高. 2. 负载因子0.75,就是说一个数组初始化new HashMap(17)容量会比17最小2的n次方大,就是32,想要已空间换时间,就是负载因子小于0.75这样的话hash冲突更低,但是扩容频率更高.3 扩容,jdk…