4.sklearn-K近邻算法、模型选择与调优

文章目录

  • 环境配置(必看)
  • 头文件引用
    • 1.sklearn转换器和估计器
        • 1.1 转换器 - 特征工程的父类
        • 1.2 估计器(sklearn机器学习算法的实现)
    • 2.K-近邻算法
      • 2.1 简介:
      • 2.2 K-近邻算法API
      • 2.3 K-近邻算法代码
      • 2.4 运行结果
      • 2.5 K-近邻算法优缺点
    • 3.模型选择与调优
      • 3.1 交叉验证(cross validation)
      • 3.2 网格搜索(Grid Search)
      • 3.3 交叉验证,网格搜索(模型选择与调优)API:
      • 3.4 代码
      • 3.5 运行结果
  • 本章学习资源

环境配置(必看)

Anaconda-创建虚拟环境的手把手教程相关环境配置看此篇文章,本专栏深度学习相关的版本和配置,均按照此篇文章进行安装。

头文件引用

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

1.sklearn转换器和估计器

1.1 转换器 - 特征工程的父类
1 实例化 (实例化的是一个转换器类(Transformer))
2 调用fit_transform(对于文档建立分类词频矩阵,不能同时调用)标准化:(x - mean) / std	(特征 - 均值)/ 标准差fit_transform()fit()           	计算 每一列的平均值、标准差transform()     	(x - mean) / std进行最终的转换
1.2 估计器(sklearn机器学习算法的实现)
1 实例化一个estimator
2 estimator.fit(x_train, y_train) 计算—— 调用完毕,模型生成
3 模型评估:1)直接比对真实值和预测值y_predict = estimator.predict(x_test)y_test == y_predict2)计算准确率accuracy = estimator.score(x_test, y_test)

2.K-近邻算法

2.1 简介:

KNN核心思想:你的“邻居”来推断出你的类别1 K-近邻算法(KNN)原理k = 1容易受到异常点的影响如何确定谁是邻居?计算距离:距离公式欧氏距离  --  算法默认的是使用欧式距离曼哈顿距离 绝对值距离明可夫斯基距离如果取的k值不一样?会是什么结果?k 值取得过小,容易受到异常点的影响k 值取得过大,样本不均衡的影响

2.2 K-近邻算法API

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')

API注释:

n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数
algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’}快速k近邻搜索算法,默认参数为auto,可以理解为算法自己决定合适的搜索算法。除此之外,用户也可以自己指定搜索算法ball_tree、kd_tree、brute方法进行搜索,
brute:是蛮力搜索,也就是线性扫描,当训练集很大时,计算非常耗时。
kd_tree:构造kd树存储数据以便对其进行快速检索的树形数据结构,kd树也就是数据结构中的二叉树。以中值切分构造的树,每个结点是一个超矩形,在维数小于20时效率高。
ball tree:是为了克服kd树高维失效而发明的,其构造过程是以质心C和半径r分割样本空间,每个节点是一个超球体

2.3 K-近邻算法代码

分析:

  1. x_test = transfer.transform(x_test),测试集只是使用transform进行标准化,是因为要和训练集x_train 做一样的处理,训练集调用transfer.fit_transform()计算出的均值,标准差的值均在模型中,x_test = transfer.transform(x_test)就是直接使用测试集的参数进行计算。
def knn_iris():"""用KNN算法对鸢尾花进行分类:return:"""# 1.获取数据iris = load_iris()# 2.划分数据集  参数:特征值,目标值,随机数种子x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)# 3.特征工程:标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_test = transfer.transform(x_test)                                 # 4.KNN算法预估器  n_neighbors=3就是K值等于3estimator = KNeighborsClassifier(n_neighbors=3)estimator.fit(x_train, y_train)# 5.模型评估# 方法1: 直接比对真实值和预测值y_predict = estimator.predict(x_test)print(f"y_predict:\n{y_predict}")print(f"直接比对真实值和预测值: {y_test == y_predict}")# 方法2: 计算准确率score = estimator.score(x_test, y_test)print(f"准确率为: {score}")return None

2.4 运行结果

在这里插入图片描述

2.5 K-近邻算法优缺点

优点:简单,易于理解,易于实现,无需训练
缺点:1)必须指定K值,K值选择不当则分类精度不能保证2)懒惰算法,对测试样本分类时的计算量大,内存开销大使用场景:小数据场景,几千~几万样本,具体场景具体业务去测试

3.模型选择与调优

3.1 交叉验证(cross validation)

交叉验证:将拿到的训练数据,分为训练和验证集。以下图为例:将数据分成4份,其中一份作为验证集。然后经过4()的测试,每次都更换
不同的验证集。即得到4组模型的结果,取平均值作为最终结果。又称4折交叉验证。

在这里插入图片描述

3.2 网格搜索(Grid Search)

通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,所以需要对模型预设几种超参数组合。
每组超参数都采用交叉验证来进行评估。最后选出最优参数组合建立模型。

在这里插入图片描述

3.3 交叉验证,网格搜索(模型选择与调优)API:

sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)
对估计器的指定参数值进行详尽搜索estimator:估计器对象param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}cv:指定几折交叉验证fit:输入训练数据score:准确率
结果分析:bestscore__:在交叉验证中验证的最好结果bestestimator:最好的参数模型cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果

3.4 代码

def knn_iris_gscv():"""用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证:return:"""# 1.获取数据iris = load_iris()# 2.划分数据集x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=20)# 3.特征工程:标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_test = transfer.transform(x_test)# 4.KNN算法预估器estimator = KNeighborsClassifier()# 加入网格搜索和交叉验证# 参数准备param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}   # 网格搜索# cv=10 代表10折运算(交叉验证)estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)estimator.fit(x_train, y_train)# 5.模型评估# 方法1: 直接比对真实值和预测值y_predict = estimator.predict(x_test)print(f"y_predict:\n{y_predict}")print(f"直接比对真实值和预测值: {y_test == y_predict}")# 方法2: 计算准确率score = estimator.score(x_test, y_test)print(f"准确率为: {score}")# 最佳参数:print("最佳参数: \n", estimator.best_params_)# 最佳结果:print("最佳结果: \n", estimator.best_score_)# 最佳参数:print("最佳估计器: \n", estimator.best_estimator_)# 交叉验证结果:print("交叉验证结果: \n", estimator.cv_results_)return None

3.5 运行结果

在这里插入图片描述

本章学习资源

黑马程序员3天快速入门python机器学习
我是跟着视频进行的学习,欢迎大家一起来学习!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/53249.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数组与贪心算法——605、121、122、561、455、575(5简1中)

605. 种花问题(简单) 假设有一个很长的花坛,一部分地块种植了花,另一部分却没有。可是,花不能种植在相邻的地块上,它们会争夺水源,两者都会死去。 给你一个整数数组 flowerbed 表示花坛&#xf…

python科学计算:NumPy 简介与安装

1 NumPy 是什么? NumPy(Numerical Python 的简称)是 Python 语言中最为广泛使用的科学计算库。它支持多维数组和矩阵运算,并提供丰富的数学函数库,使得数据处理和数值计算变得更加高效。 NumPy 的核心是提供了一个强…

golang zap日志模块封装sentry

我们自己写个log日志包,把zap和sentry封装到一起。 下面直接贴上主要部分代码(两个模块初始化部分的代码请自行查阅官方文档): logger.go package logimport ("github.com/getsentry/sentry-go""go.uber.org/zap…

MFC读取PC6408板卡输入信号实例

本程序基于前期我的博客文章《MFC用信号灯模拟工控机数字量输入信号实时采集实例(源码下载》 1、在TheradDlg.h中相关代码 ... private:unsigned short nAddr; ... TheradDlg.cpp中相关代码 #include "pc60002k.h"BOOL CTheradDlg::OnInitDialog() { ..…

Mapmost让你实现地图标注自由

最近在勤勤恳恳(moyuhaushui)搬砖之余,偶然间看到一个在线古籍图书馆,虽然对文言文阅读的心理障碍不亚于英文阅读理解,但网站中有很多历史图集还是引起了兴趣。比如这幅《水经注图》,顺藤摸瓜的瞧&#xff…

Java中对象拷贝的深度解析:从零拷贝到深拷贝的演进

前言 在Java编程世界中,对象的拷贝是一个基础而重要的操作。它涉及到内存管理、数据一致性以及程序的健壮性等多个方面。随着软件架构的复杂化和数据的多样化,对象拷贝的策略也从最初的简单赋值(零拷贝)发展到深拷贝,…

DataWorks数据质量监控方案

背景 日常的调度监控,可以查看实例任务的运行情况,对运行失败的实例进行告警,但是却无法对运行成功的实例进行数据质量的判断。而有些情况下,即使实例任务运行成功了,数据也仍然存在问题,这时候就需要对数…

多线程——线程安全

线程安全问题 同时满足以下两个条件时: 多个线程在操作共享的数据。操作共享数据的线程代码有多条。 当一个线程在执行操作共享数据的多条代码过程中,其他线程参与了运算,就会导致线程安全问题的产生。 解决这样的问题就是线程同步的方式来…

揭秘Taboola原生广告:欧美流量变现联盟营销金牌策略

揭秘Taboola原生广告:欧美流量变现的金牌策略 在数字营销日益精进的今天,如何高效地将网站流量转化为实际收益成为了众多欧美网站主关注的焦点。Taboola,作为原生广告领域的佼佼者,凭借其独特的广告展示方式与强大的数据驱动能力…

判断两个yaw角度之差是否超过了90度

一. 判断两个yaw角度之差是否超过了90度 要判断两个 yaw 角度之差是否超过 90 度,你可以通过计算这两个角度的差值,并将其归一化为 [-180, 180] 的范围内。接着,只需判断该差值的绝对值是否大于 90 度。 实现步骤: 计算角度差&…

上海晋名室外危化品暂存柜助力新能源行业发展

近日又有一个SAVEST室外危化品暂存柜项目成功验收交付使用。 用户在日常经营活动中涉及到气瓶和硅粉的室外安全暂存问题,4月下旬在网上看到上海晋名室外暂存柜系列很感兴趣,联系到了销售部钟经理,双方对晋名的室外暂存柜进行了高效的沟通&am…

无人机+应用综合实训室解决方案

随着无人机技术的飞速发展,其在航拍、农业、环境监测、物流运输等多个领域展现出巨大的应用潜力。为了满足职业院校及企业对无人机应用技术型人才的培养需求,唯众紧跟市场趋势,推出了全面且详尽的《无人机应用综合实训室解决方案》。本方案旨…

MACOS安装配置前端开发环境

官网下载安装Mac版本的谷歌浏览器以及VS code代码编辑器,还有在App Store中直接安装Xcode(里面自带git); node.js版本管理器nvm的下载安装如下: 参考B站:https://www.bilibili.com/video/BV1M54y1N7fx/?sp…

【学习AI-相关路程-工具使用-自我学习-jetson模型训练-图片识别-使用模型检测图片-基础样例 (5)】

【学习AI-相关路程-工具使用-自我学习-jetson&模型训练-图片识别-使用模型检测图片-基础样例 (5)】 1 -前言2 -环境说明3 -先行了解(1)整理流程了解(2)了解模型-MobileNet1、MobileNetV2 的主要特性&am…

python源码 PBOCMaster MAC的计算函数及计算过程 2des

注意最后一步要用整个key加密 计算过程: MAC: PBOC-MAC DES算法 密钥 长度16(0x10)字节 57 75 20 4D 69 61 6F 6A 75 6E 40 47 26 44 43 11 初始向量 长度8(0x08)字节 00 00 00 00 00 00 00 00 数据 长度74(0x4A)字节 43 48 45 4E 48 41 4F 2D 50 43 7…

Python股票接口实现量化交易的优势是什么

炒股自动化:申请官方API接口,散户也可以 python炒股自动化(0),申请券商API接口 python炒股自动化(1),量化交易接口区别 Python炒股自动化(2):获取…

MSP430F149实现1.8寸TFT_LCD真彩屏显示

目录 一、功能实现 二、设备准备 三、接线表设计 四、代码实现 五、实现效果 六、代码链接 一、功能实现 实现1.8寸TFT_LCD真彩屏显示。显示数字、图片、字符串等。 二、设备准备 1.TFT_LCD真彩屏(1.8寸) 该真彩屏使用SPI通信。 2.MSP430F149开…

CSRF,SSRF和重放攻击的区别

CSRF是跨站请求伪造攻击,由客户端发起 SSRF是服务器端请求伪造,由服务器发起 重放攻击时将截获的数据包进行重放,达到身份认证等目的 三种是不同的网络安全攻击方式,他们在攻击方式,目标,影响以及防御策略…

微服务CI/CD实践(五)Jenkins Docker 自动化构建部署Node服务

微服务CI/CD实践系列: 微服务CI/CD实践(一)环境准备及虚拟机创建 微服务CI/CD实践(二)服务器先决准备 微服务CI/CD实践(三)gitlab部署及nexus3部署 微服务CI/CD实践(四&#xff09…

【软件设计】常用设计模式--策略模式

软件设计模式(三) 策略模式(Strategy Pattern)1. 概念2. 模式结构3. UML 类图4. 实现方式C# 示例步骤1:定义策略接口步骤2:实现具体策略类步骤3:实现上下文类步骤4:使用策略模式 Jav…