模型的选择与调优(网格搜索与交叉验证)

1、为什么需要交叉验证

  • 交叉验证目的:为了让被评估的模型更加准确可信

2、什么是交叉验证(cross validation)

  • 交叉验证:将拿到的训练数据,分为训练和验证集。以下图为例:将数据分成4份,其中一份作为验证集。然后经过4次(组)的测试,每次都更换不同的验证集。即得到4组模型的结果,取平均值作为最终结果。又称4折交叉验证。
    • 训练集:训练集+验证集
    • 测试集:测试集
      在这里插入图片描述
      问题:那么这个只是对于参数得出更好的结果,那么怎么选择或者调优参数呢?

3、超参数搜索-网格搜索(Grid Search)

通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,网格搜索帮我们实现了这个调参过程,首先需要对模型预设几种超参数组合,每组超参数都采用交叉验证来进行评估,最后选出最优参数组合建立模型。
在这里插入图片描述

3.1、模型选择与调优 API

  • sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)
    • 对估计器的指定参数值进行详尽搜索
    • estimator:估计器对象
    • param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}
    • cv:指定几折交叉验证
    • fit:输入训练数据
    • score:准确率
  • 结果分析:
    • bestscore:在交叉验证中验证的最好结果_
    • bestestimator:最好的参数模型
    • cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果

3.2、网格搜索与交叉验证代码

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler"""
用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证
:return:
"""
# 1)获取数据
iris = load_iris()# 2)划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=22)# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)# 4)KNN算法预估器
estimator = KNeighborsClassifier()# 加入网格搜索与交叉验证
# 参数准备
param_dict = {"n_neighbors": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
estimator.fit(x_train, y_train)# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)# 最佳参数:best_params_
print("最佳参数:\n", estimator.best_params_)
# 最佳结果:best_score_
print("最佳结果:\n", estimator.best_score_)
# 最佳估计器:best_estimator_
print("最佳估计器:\n", estimator.best_estimator_)
# 交叉验证结果:cv_results_
print("交叉验证结果:\n", estimator.cv_results_)

在这里插入图片描述

4、facebook 签到位置预测

在这里插入图片描述
在这里插入图片描述

  • 数据介绍:将根据用户的位置,准确性和时间戳预测用户正在查看的业务。
  • train.csv
    • row_id:登记事件的ID
    • xy:坐标
    • 准确性:定位准确性
    • 时间:时间戳
    • place_id:业务的ID,这是您预测的目标

官网:https://www.kaggle.com/navoshta/grid-knn/data

4.1、流程分析

对于数据做一些基本处理(这里所做的一些处理不一定达到很好的效果,我们只是简单尝试,有些特征我们可以根据一些特征选择的方式去做处理)

1、缩小数据集范围 DataFrame.query()(选择性处理!)
2、删除没用的日期数据 DataFrame.drop(可以选择保留)
3、将签到位置少于n个用户的删除

place_count = data.groupby('place_id').count()
tf = place_count[place_count.row_id > 3].reset_index()
data = data[data['place_id'].isin(tf.place_id)]

4、分割数据集
5、标准化处理
6、k-近邻预测

4.2、代码

import pandas as pd
# 1、获取数据
data = pd.read_csv("train.csv")
data.head()

在这里插入图片描述

# 1)处理时间特征
time_value = pd.to_datetime(data["time"], unit="s")
date = pd.DatetimeIndex(time_value)
data["day"] = date.day
data["weekday"] = date.weekday
data["hour"] = date.hour
data.head()

在这里插入图片描述

# 2)过滤签到次数少的地点
place_count = data.groupby("place_id").count()["row_id"]
data_final = data[data["place_id"].isin(place_count[place_count > 3].index.values)]
data_final.head()

在这里插入图片描述

# 筛选特征值和目标值
x = data_final[["x", "y", "accuracy", "day", "weekday", "hour"]]
y = data_final["place_id"]

在这里插入图片描述

# 数据集划分
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y)
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)# 4)KNN算法预估器
estimator = KNeighborsClassifier()# 加入网格搜索与交叉验证
# 参数准备
param_dict = {"n_neighbors": [3, 5, 7, 9]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=3)
estimator.fit(x_train, y_train)# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)# 最佳参数:best_params_
print("最佳参数:\n", estimator.best_params_)
# 最佳结果:best_score_
print("最佳结果:\n", estimator.best_score_)
# 最佳估计器:best_estimator_
print("最佳估计器:\n", estimator.best_estimator_)
# 交叉验证结果:cv_results_
print("交叉验证结果:\n", estimator.cv_results_)

这个结果数据量比较大,毕竟两千万训练数据了,各位可自行试验及调参;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/109920.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VulnHub lazysysadmin

一、信息收集 1.nmap扫描开发端口 开放了:22、80、445 访问80端口,没有发现什么有价值的信息 2.扫描共享文件 enum4linux--扫描共享文件 使用: enum4linux 192.168.103.182windows访问共享文件 \\192.168.103.182\文件夹名称信息收集&…

UWB安全数据通讯STS-加密、身份认证

DW3000系列才能支持UWB安全数据通讯,DW1000不支持 IEEE 802.15.4a没有数据通讯安全保护机制,IEEE 802.15.4z中指定的扩展得到增强(在PHY/RF级别):增添了一个重要特性“扰频时间戳序列(STS)”&a…

软件开发“自我毁灭”的七宗罪

软件开发是一门具有挑战性的学科,它建立在数以百万计的参数、变量、库以及更多必须绝对正确的因素之上。即便是一个字符不合适,整个堆栈也会随之瓦解。 多年来,软件开发团队已经想出了一些完成工作的规则。从复杂的方法论到新兴的学科和哲学…

百度地图高级进阶开发:圆形区域周边搜索地图监听事件(覆盖物重叠显示层级\图像标注监听事件、setZIndex和setTop方法)

百度地图API 使用百度地图API添加多覆盖物渲染时,会出现覆盖物被相互覆盖而导致都无法触发它们自己的监听;在百度地图API里,map的z-index为0,但是触发任意覆盖物的监听如click时也必定会触发map的监听; 项目需求 在…

最详细STM32,cubeMX 点亮 led

这篇文章将详细介绍 如何在 stm32103 板子上点亮一个LED. 文章目录 前言一、开发环境搭建。二、LED 原理图解读三、什么是 GPIO四、cubeMX 配置工程五、解读 cubeMX 生成的代码六、延时函数七、控制引脚状态函数点亮 LED 八、GPIO 的工作模式九、为什么使用推挽输出驱动 LED总结…

基于svg+js实现简单动态时钟

实现思路 创建SVG容器&#xff1a;首先&#xff0c;创建一个SVG容器元素&#xff0c;用于容纳时钟的各个部分。指定SVG的宽度、高度以及命名空间。 <svg width"200" height"200" xmlns"http://www.w3.org/2000/svg"><!-- 在此添加时钟…

排查手机应用app微信登录问题不跳转失败原因汇总及其解决方案

经过最近我发的文章,我个人觉得解决了不少小问题,因为最近很小白的问题已经没有人私聊问我了,我总结了一下排查手机应用app微信登录问题不跳转失败的原因汇总及其解决方案在这篇文章中,分析微信登录不跳转的原因,并提供解决方案。希望通过这篇文章,能够帮助大家顺利解决这…

紫光同创FPGA实现UDP协议栈网络视频传输,带录像和抓拍功能,基于YT8511和RTL8211,提供2套PDS工程源码和技术支持

目录 1、前言免责声明 2、相关方案推荐我这里已有的以太网方案紫光同创FPGA精简版UDP方案紫光同创FPGA带ping功能UDP方案紫光同创FPGA精简版UDP视频传输方案 3、设计思路框架OV5640摄像头配置及采集数据缓冲FIFOUDP协议栈详解MAC层发送MAC发送模式MAC层接收ARP发送ARP接收ARP缓…

DCDC Buck电路地弹造成的影响

很多读者都应该听过地弹&#xff0c;但是实际遇到的地弹的问题应该很少。本案例就是一个地弹现象导致电源芯片工作不正常的案例。 問題描述 如下图1 &#xff0c;产品其中一个供电是12V转3.3V的电路&#xff0c;产品发货50K左右以后&#xff0c;大约有1%的产品无法启动&#…

4.MidBook项目经验之MonogoDB和easyExcel导入导出

1.数据字典(固定的数据,省市级有层级关系的) //mp表如果没有这个字段,防报错,eleUI需要这个字段TableField(exist false) //父根据id得到子数据 ,从controller开始自动生成代码-->service//hasChildren怎么判断,只需要判断children的parentid的count数量>0就可以了//优化…

vue.js - 断开发送的请求,解决接口重复请求数据错误问题(vue中axios多次相同请求中断上一个)

描述 进入页面时第一个接口还在请求,立即切换tab请求第二个接口。但是第二个接口响应比第一个接口响应快,页面展示的时第一个接口的数据,如图: 解决方法 判断如果是相同的接

visual studio安装时候修改共享组件、工具和SDK路径方法

安装了VsStudio后,如果自己修改了Shared路径&#xff0c;当卸载旧版本&#xff0c;需要安装新版本时发现&#xff0c;之前的Shared路径无法进行修改&#xff0c;这就很坑爹了&#xff0c;因为我运行flutter程序的时候&#xff0c;报错找不到windows sdk的位置&#xff0c;所以我…

怎么把flac音频变为mp3?

怎么把flac音频变为mp3&#xff1f;FLAC音频格式在许多平台和应用程序中都得到支持和应用。FLAC音频格式被广泛支持和应用。许多平台、设备和应用程序都支持FLAC格式&#xff0c;如Windows、macOS和Linux操作系统、各种音乐播放器软件、智能手机和平板电脑、在线音乐平台和流媒…

c++小知识

内联函数 inline 用来替换宏函数 不能分文件编辑 在c语言中#define NULL 0在c中使用nullptr表示空指针class内存的大小计算规则使用的是内存对齐 没有成员&#xff0c;但是还有1个字节&#xff0c;我们使用这个来标记他是个类 类成员函数不存在于类中 为什么每个对象使用的…

基于Linux安装Hive

Hive安装包下载地址 Index of /dist/hive 上传解压 [rootmaster opt]# cd /usr/local/ [rootmaster local]# tar -zxvf /opt/apache-hive-3.1.2-bin.tar.gz重命名及更改权限 mv apache-hive-3.1.2-bin hivechown -R hadoop:hadoop hive配置环境变量 #编辑配置 vi /etc/pro…

云安全—云计算基础

0x00 前言 学习云安全&#xff0c;那么必然要对云计算相关的内容进行学习和了解&#xff0c;所以云安全会分为两个部分来进行&#xff0c;首先是云计算先关的内容。 0x01 云计算 广泛传播 云计算最早大范围传播是2006年&#xff0c;8月&#xff0c;在圣何塞【1】举办的SES&a…

SpringMVC的响应处理

目录 传统同步业务数据的响应 请求资源转发 请求资源重定向 响应数据模型 直接回写数据给客户端 前后端分离异步业务数据响应 在前面的文章中&#xff0c;我们已经介绍了Spring接收请求的部分&#xff0c;接下来看Spring如何给客户端响应数据 传统同步业务数据的响应 准…

Redis数据结构之quicklist

前言 为了节省内存&#xff0c;Redis 推出了 ziplist 数据类型&#xff0c;采用一种更加紧凑的方式来存储 hash、zset 元素。因为查找的时间复杂度是 O(N)&#xff0c;且写入需要重新分配内存&#xff0c;所以它仅适用于小数据量的存储&#xff0c;而且它还存在 连锁更新 的风…

34 机器学习(二):数据准备|knn

文章目录 数据准备数据下载数据切割转换器估计器 kNN正常的流程网格多折交叉训练原理讲解距离度量欧式距离(Euclidean Distance)曼哈顿距离(Manhattan Distance)切比雪夫距离 (Chebyshev Distance)还有一些自定义的距离 就请读者自行研究 再识K-近邻算法API选择n邻居的思辨总结…

C# Winform编程(5)菜单栏和工具栏

菜单和菜单组件 添加菜单编辑菜单菜单栏和工具栏 添加菜单 将MenuStrip控件拖拽到Form窗体顶部添加菜单 编辑菜单 添加菜单项&#xff0c;编辑菜单属性等功能。 右键单击已添加的菜单项可以弹出右键菜单&#xff1a; 可以设置菜单图标&#xff0c;使能菜单&#xff0c;显示…