【机器学习】KNN算法-模型选择与调优

KNN算法-模型选择与调优

文章目录

  • KNN算法-模型选择与调优
    • 1. 交叉验证
    • 2. 超参数搜索-网格搜索(Grid Search)
    • 3. 模型选择与调优API
    • 4. 鸢尾花种类预测-代码和输出结果
    • 5. 计算距离

问题背景:KNN算法的K值不好确定

1. 交叉验证

交叉验证:将拿到的训练数据,分为训练集和验证集。以下表为例:将数据分成4份,其中一份作为验证集,然后经过4次(组)的测试,每次都更换不同的验证集。即得到4组模型的结果,取平均值作为最终的结果。这种又称作为4折交叉认证。

第一块第二块第三块第四块准确率
验证集训练集训练集训练集80%
训练集验证集训练集训练集78%
训练集训练集验证集训练集75%
训练集训练集训练集验证集82%

我们之前知道数据分为训练集和测试集,但是为了从训练得到的模型结果更加准确,做出以下处理

  • 训练集=训练集+验证集
  • 测试集=测试集

2. 超参数搜索-网格搜索(Grid Search)

通常情况下,有很多参数是要手动去指定的,如KNN算法中的K值,这种叫超参数。但是手动过程繁杂,我们可能会定义一个列表,里面有一堆K的值来遍历选择,相当于“暴力破解”。而网格搜索会采用交叉认证来进行评估,在你给定的一定范围内的K值中选出最优参数组合建立模型。

3. 模型选择与调优API

  • sklearn.model_selection.GridSearchCV(estimator,param_grid=None,cv=None)
    • 对估计器的指定参数值进行详尽搜索
    • estimator估计器对象
    • param_grid:估计器参数(dict){“n_neighbors":[1,3,5]}
    • cv:指定几折交叉验证
    • fit():输入训练数据
    • score():准确率
    • 结果分析:best_params_最佳参数,best_score_最佳结果,best_estimator_最佳估计器,cv_results_交叉验证结果

4. 鸢尾花种类预测-代码和输出结果

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV# K—近邻算法
def KNN_demo():"""sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')n_neighbors:int可选,默认为5,k_neighbors查询默认使用的邻居数algorithm:{'auto','ball_tree','kd_tree','brute'},可选用于计算最近邻居的算法:‘ball_tree’将会使用BallTree,'kd_tree'将会使用KDTree。'auto'将尝试根据传递给fit方法的值来决定最合适的算法。(不同实现方式影响效率):return:"""# 获取数据iris = load_iris()# 划分数据集x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state= 6)# 特征工程 标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_test = transfer.transform(x_test)# KNN算法预估器estimator = KNeighborsClassifier(n_neighbors= 3)estimator.fit(x_train, y_train)# 模型评估# 方法一:y_predict = estimator.predict(x_test)print("y_predict:\n", y_predict)print("直接比对真实值和预测值:\n", y_test == y_predict)# 方法二:score = estimator.score(x_test, y_test)print("准确率为:\n", score)return None# KNN添加网格搜索和交叉认证
def KNN_gscv_demo():# 获取数据iris = load_iris()# 划分数据集x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)# 特征工程 标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_test = transfer.transform(x_test)# KNN算法预估器estimator = KNeighborsClassifier()# 加入网格搜索和交叉认证param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}estimator = GridSearchCV(estimator, param_grid= param_dict, cv =10)estimator.fit(x_train, y_train)# 模型评估# 方法一:y_predict = estimator.predict(x_test)print("y_predict:\n", y_predict)print("直接比对真实值和预测值:\n", y_test == y_predict)# 方法二:score = estimator.score(x_test, y_test)print("准确率为:\n", score)# 最佳print("最佳参数为:\n", estimator.best_params_)print("最佳结果:\n", estimator.best_score_)print("最佳估计器:\n", estimator.best_estimator_)print("交叉验证结果:\n", estimator.cv_results_)# 交叉验证结果为:训练集划分训练集和验证集之后的,不是整体的,和测试集无关return Noneif __name__ == "__main__":# KNN_demo() 没有添加网格搜索和交叉认证KNN_gscv_demo()pass
y_predict:[0 2 0 0 2 1 2 0 2 1 2 1 2 2 1 1 2 1 1 0 0 2 0 0 1 1 1 2 0 1 0 1 0 0 1 2 12]
直接比对真实值和预测值:[ True  True  True  True  True  True  True  True  True  True  True  TrueTrue  True  True False  True  True  True  True  True  True  True  TrueTrue  True  True  True  True  True  True  True  True  True False  TrueTrue  True]
准确率为:0.9473684210526315
最佳参数为:{'n_neighbors': 11}
最佳结果:0.9734848484848484
最佳估计器:KNeighborsClassifier(n_neighbors=11)
交叉验证结果:{'mean_fit_time': array([0.00010171, 0.        , 0.00030091, 0.        , 0.        ,0.00020049]), 'std_fit_time': array([0.00030513, 0.        , 0.00045964, 0.        , 0.        ,0.00040097]), 'mean_score_time': array([0.00110393, 0.00069332, 0.00051594, 0.00090301, 0.00085185,0.0005013 ]), 'std_score_time': array([0.00070476, 0.00039479, 0.00065858, 0.00030101, 0.00032043,0.0005013 ]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7, 9, 11],mask=[False, False, False, False, False, False],fill_value='?',dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}, {'n_neighbors': 9}, {'n_neighbors': 11}], 'split0_test_score': array([1., 1., 1., 1., 1., 1.]), 'split1_test_score': array([0.91666667, 0.91666667, 1.        , 0.91666667, 0.91666667,0.91666667]), 'split2_test_score': array([1., 1., 1., 1., 1., 1.]), 'split3_test_score': array([1.        , 1.        , 1.        , 1.        , 0.90909091,1.        ]), 'split4_test_score': array([1., 1., 1., 1., 1., 1.]), 'split5_test_score': array([0.90909091, 0.90909091, 1.        , 1.        , 1.        ,1.        ]), 'split6_test_score': array([1., 1., 1., 1., 1., 1.]), 'split7_test_score': array([0.90909091, 0.90909091, 0.90909091, 0.90909091, 1.        ,1.        ]), 'split8_test_score': array([1., 1., 1., 1., 1., 1.]), 'split9_test_score': array([0.90909091, 0.81818182, 0.81818182, 0.81818182, 0.81818182,0.81818182]), 'mean_test_score': array([0.96439394, 0.95530303, 0.97272727, 0.96439394, 0.96439394,0.97348485]), 'std_test_score': array([0.04365767, 0.0604591 , 0.05821022, 0.05965639, 0.05965639,0.05742104]), 'rank_test_score': array([5, 6, 2, 3, 3, 1])}

5. 计算距离

K最近邻(KNN)是一种有监督的机器学习算法,它根据其K个最近邻居的大多数类别来对数据点进行分类。在使用KNN时,需要确定一个距离度量来衡量数据点之间的相似性。常用的KNN距离度量包括欧氏距离、曼哈顿距离和闵可夫斯基距离。

  1. 欧氏距离:

    • 欧氏距离是KNN中最常用的距离度量。

    • 它是欧几里得空间中两个点之间的直线距离

    • 在二维空间中,计算两个点(x1,y1)和(x2,y2)之间的欧氏距离的公式如下:

      ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 \sqrt{(x1 - x2)^2 + (y1 - y2)^2} (x1x2)2+(y1y2)2

    • 在n维空间中,公式扩展为:
      ∑ i = 1 n ( x i − y i ) 2 \sqrt{\sum_{i=1}^{n}(x_i - y_i)^2} i=1n(xiyi)2

    • 这种距离度量对特征的尺度敏感,因此在使用时重要的是标准化或归一化特征。

  2. 曼哈顿距离:

    • 它以每个维度上的坐标绝对差的总和来衡量两个点之间的距离。

    • 在二维空间中,计算两个点(x1,y1)和(x2,y2)之间的曼哈顿距离的公式如下:
      ∣ x 1 − x 2 ∣ + ∣ y 1 − y 2 ∣ |x1 - x2| + |y1 - y2| x1x2∣+y1y2∣

    • 在n维空间中,公式扩展为:
      ∑ i = 1 n ∣ x i − y i ∣ \sum_{i=1}^{n}|x_i - y_i| i=1nxiyi

    • 曼哈顿距离对异常值不太敏感,因此在数据可能不服从正态分布的情况下,它是更好的选择。

  3. 闵可夫斯基距离:

    • 闵可夫斯基距离是欧氏距离和曼哈顿距离的通用化。
    • 它包括一个参数“p”,可以调整以将公式转换为欧氏或曼哈顿距离。
    • 当p=2时,它变为欧氏距离,当p=1时,它变为曼哈顿距离。
    • 两点(x,y)之间的闵可夫斯基距离的公式如下:
      ( ∑ i = 1 n ∣ x i − y i ∣ p ) 1 / p \left(\sum_{i=1}^{n}|x_i - y_i|^p\right)^{1/p} (i=1nxiyip)1/p

默认情况下,KNN使用欧氏距离作为距离度量。如果使用不同的距离度量(例如曼哈顿或闵可夫斯基距离),可以在KNeighborsClassifier构造函数中使用“metric”参数进行指定。例如:

estimator = KNeighborsClassifier(metric='manhattan')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/126582.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

联想百应:构建“生态资源池”,打造中小企业转型第一服务平台

与3800多家服务商和100多家SaaS生态伙伴携手,累计支持超过20万中小企业智能化转型……在近日由工业和信息化部和安徽省举办的2023全国中小企业数字化转型大会上,联想集团首次公布供应链、平台、技术、生态与绿色赋能五大赋能能力和助力中小企业“链式”成…

sqlite3 关系型数据库语言 SQL 语言

SQL(Structured Query Language)语言是一种结构化查询语言,是一个通用的,功能强大的关系型数据库操作语言. 包含 6 个部分: 1.数据查询语言(DQL:Data Query Language) 从数据库的二维表格中查询数据,保留字 SELECT 是 DQL 中用的最多的语句 2.数据操作语言(DML) 最主要的关…

js编写一个函数判断所有数据类型

一、typeof 在 JavaScript 里使用 typeof 来判断数据类型,只能区分基本类型,即 “number”,”string”,”undefined”,”boolean”,”object” 五种。 对于数组、对象来说,其关系错综复杂&…

神经网络的解释方法之CAM、Grad-CAM、Grad-CAM++、LayerCAM

原理优点缺点GAP将多维特征映射降维为一个固定长度的特征向量①减少了模型的参数量;②保留更多的空间位置信息;③可并行计算,计算效率高;④具有一定程度的不变性①可能导致信息的损失;②忽略不同尺度的空间信息CAM利用…

前端 :用HTML , CSS ,JS 做一个秒表

1.HTML&#xff1a; <body><div id "content"><div id "top"><div id"time">00:00:000</div></div><div id "bottom"><div id "btn_start">开始</div><div …

04.Oracle的体系架构

Oracle的体系架构 一、主要组件 一、主要组件 下面是一张网图&#xff0c;大家可以了解一下oracle的体系架构 Oracle数据库的体系架构可以分为以下几个主要组件&#xff1a;实例&#xff08;Instance&#xff09;、数据库&#xff08;Database&#xff09;、表空间&#xff…

瑞数专题五

今日文案&#xff1a;焦虑&#xff0c;想象力过度发酵的产物。 网址&#xff1a;https://www.iyiou.com/ 专题五主要是分享瑞数6代。6代很少见&#xff0c;所以找理想哥要的&#xff0c;感谢感谢。 关于瑞数作者之前已经分享过4篇文章&#xff0c;全都收录在瑞数专栏中了&am…

21. 合并两个有序链表、Leetcode的Python实现

博客主页&#xff1a;&#x1f3c6;看看是李XX还是李歘歘 &#x1f3c6; &#x1f33a;每天不定期分享一些包括但不限于计算机基础、算法、后端开发相关的知识点&#xff0c;以及职场小菜鸡的生活。&#x1f33a; &#x1f497;点关注不迷路&#xff0c;总有一些&#x1f4d6;知…

正式启航!指导品牌开拓下一个增长蓝海

种草的商品总在不经意间推送到面前&#xff0c;深夜刷了会儿短视频&#xff0c;不小心又下单了一个不太熟悉的产品&#xff0c;明星达人素人全部入局直播带货&#xff0c;社交平台演变成购物场&#xff0c;无人幸免的兴趣电商时代强势来临。尤其到了每年一度的双11大促节点&…

数据库概念和sql语句

数据库概念和sql语句 数据&#xff1a;数&#xff1a;数字信息 据&#xff1a;属性 对一系列对象的具体属性的描述的集合 数据库&#xff1a;数据库就是用来组织&#xff08;各个数据之间是有关联&#xff0c;是按照规则组织起来的&#xff09;&#xff0c;存储和管理&…

音视频rtsp rtmp gb28181在浏览器上的按需拉流

按需拉流是从客户视角来看待音视频的产品功能&#xff0c;直观&#xff0c;好用&#xff0c;为啥hls flv大行其道也是这个原因&#xff0c;不过上述存在的问题是延迟没法降到实时毫秒级延迟&#xff0c;也不能随心所欲的控制。通过一段时间的努力&#xff0c;结合自己闭环技术栈…

C++新版本学习资源整理

链接资源推荐&#xff1a; C11/14/17/20 特性介绍 转 | 有点博客

Web APIs——日期对象的使用

1、日期对象 日期对象&#xff1a;用来表示时间的对象 作用&#xff1a;可以得到当前系统时间 1.1实例化 在代码中发现了new关键字时&#xff0c;一般将这个操作称为实例化 创建一个时间对象并获取时间 获得当前时间 const date new Date() <script>// 实例化 new //…

UE5 Android下载zip文件并解压缩到指定位置

一、下载是使用市场的免费插件 二、解压缩是使用市场的免费插件 三、Android路径问题 windows平台下使用该插件没有问题&#xff0c;只是在Android平台下&#xff0c;只有使用绝对路径才能进行解压缩&#xff0c;所以如何获得Android下的绝对路径&#xff1f;增加C文件获得And…

铁轨(Rails, ACM/ICPC CERC 1997, UVa 514)rust解法

有一个火车站&#xff0c;铁轨铺设如图6-1所示。有n节车厢从A方向驶入车站&#xff0c;按进站顺序编号为1&#xff5e;n。你的任务是判断是否能让它们按照某种特定的顺序进入B方向的铁轨并驶出车站。例如&#xff0c;出栈顺序(5 4 1 2 3)是不可能的&#xff0c;但(5 4 3 2 1)是…

python使用requests+excel进行接口自动化测试

在当今的互联网时代中&#xff0c;接口自动化测试越来越成为软件测试的重要组成部分。Python是一种简单易学&#xff0c;高效且可扩展的语言&#xff0c;自然而然地成为了开发人员的首选开发语言。而requests和xlwt这两个常用的Python标准库&#xff0c;能够帮助我们轻松地开发…

29、枚举

枚举 枚举使用场景枚举语法及特性特性&#xff1a; 手动给枚举赋值手动赋值项和未手动赋值项重复手动赋值项智能赋值数字&#xff1f;NO常数项和计算项常数枚举外部枚举 枚举使用场景 枚举类型 用于取值被限定在一定范围内的场景。 demo&#xff1a; 一周只能有七天&#xff0…

sqlLite 如何使用数据库连接池

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 一、前言 编写的一个jar包工具中&#xff…

JS(JavaScript) 实现延迟等待(sleep方法)

起因&#xff1a; 只使用 setTimeout 会产生嵌套等方面的问题&#xff0c;达不到想要的效果。 解决方法&#xff1a; 使用 async/await 还有 Promise 相结合的方式来解决问题。 直接上代码&#xff1a; function sleep(time) {return new Promise((resolve) > setTimeout…

公众号留言功能报价是多少?值得开通吗?

为什么公众号没有留言功能&#xff1f;根据要求&#xff0c;自2018年2月12日起&#xff0c;新申请的微信公众号默认无留言功能。有些人听过一个说法&#xff1a;公众号粉丝累计到一定程度或者原创文章数量累计到一定程度就可以开通留言功能。其实这个方法是2018年之前才可以&am…