Keras使用sklearn中的交叉验证和网格搜索

Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的机器学习库,它基于Scipy,用于高效的数值计算。scikit-learn是一个功能齐全的通用机器学习库,并提供了许多在开发深度学习过程中非常有帮助的方法。例如scikit-learn提供了很多用于选择模型和对模型调参的方法,这些方法同样适用于深度学习。

Keras提供了一个Wrapper,将Keras的深度学习模型包装成scikit-learn中的分类模型或回归模型,以便于使用scikit-learn中的方法和函数。对于深度学习模型的包装是通过KerasClassifier(分类模型)和KerasRegressor(回归模型)来实现的。KerasClassifier和KerasRegressor类使用参数build_fn,指定用来创建模型的函数的名称。

Keras的一般构建流程:

model = Sequential() # 定义模型
model.add(Dense(units=64, activation='relu', input_dim=100)) # 定义网络结构
#第一层网络:输出尺寸64,输入尺寸100,activation激活函数relu
model.add(Dense(units=10, activation='softmax')) # 定义网络结构
#第二层网络:输出尺寸10,输入是上一层的输出尺寸64,activation激活函数softmax
model.compile(loss='categorical_crossentropy', # 定义loss函数、优化方法、评估标准optimizer='sgd',metrics=['accuracy'])
#输入训练样本和标签,迭代5次,每次迭代32个数据
model.fit(x_train, y_train, epochs=5, batch_size=32) # 训练模型
loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128) # 评估模型
classes = model.predict(x_test, batch_size=128) # 使用训练好的数据进行预测

参数意义:

keras.layers.Dense(units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)

units: 正整数,输出空间维度。
activation: 激活函数。 若不指定,则不使用激活函数 (即,「线性」激活: a(x) = x)。
use_bias: 布尔值,该层是否使用偏置向量。
kernel_initializer: kernel 权值矩阵的初始化器。
bias_initializer: 偏置向量的初始化器。
kernel_regularizer: 运用到 kernel 权值矩阵的正则化函数 。
bias_regularizer: 运用到偏置向的的正则化函数 。
activity_regularizer: 运用到层的输出的正则化函数 。
kernel_constraint: 运用到 kernel 权值矩阵的约束函数 。
bias_constraint: 运用到偏置向量的约束函数。

Keras调用scikit-learn实现交叉验证:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score, KFold
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model():# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11, activation='relu'))model.add(Dense(units=8, activation='relu'))model.add(Dense(units=1, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, epochs=150, batch_size=10, verbose=0)# 10折交叉验证
kfold = KFold(n_splits=10, shuffle=True, random_state=10)
result = cross_val_score(model, X, Y, cv=kfold)

 Keras调用scikit-learn实现模型调参

在构建深度学习模型时,如何配置一个最优模型一直是进行一个项目的重点。在机器学习中,可以通过算法自动调优这些配置参数,在这里将通过Keras的包装类,借助scikit-learn的网格搜索算法评估神经网络模型的不同配置,并找到最佳评估性能的参数组合。creat_model()函数被定义为具有两个默认值的参数(optimizer和init)的函数,创建模型后,定义要搜索的参数的数值数组,包括优化器(optimizer)、权重初始化方案(init)、epochs和batch_size。

在scikit-learn中的GridSearchCV需要一个字典类型的字段作为需要调整的参数,默认采用3折交叉验证来评估算法,由于4个参数需要进行调参,因此将会产生4✖️3个模型。
Keras调用scikit-learn实现GridSearchCV网格搜索:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model(optimizer='adam,init='glorot_uniform'):# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11,kernel_initializer=init, activation='relu'))model.add(Dense(units=8, kernel_initializer=init, activation='relu'))model.add(Dense(units=1, kernel_initializer=init, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, verbose=0)# 构建需要调整的参数
param_gird = {}
param_grid['optimizer'] = ['rmsprop','adam']
param_grid['init'] = ['glorot_uniform', 'normal', 'uniform']
param_gird['epochs'] = [50, 100, 150, 200]
param_gird['batch_size'] = [5, 10, 20]# 调参
grid = GridSearchCV(estimator=model, param_gird=param_grid)
result = grid.fit(X, Y)# 输出结果
print('Best: %f using %s' % (result.best_score_, result.best_params_))

关于Epochs和batch_size的解释?

Epochs是神经网络训练过程中的一个重要超参数,定义为向前和向后传播中所有批次的单次训练迭代。简单说,一个Epoch是将所有的数据输入网络完成一次向前计算及反向传播。在训练过程中,数据会被“轮”多少次,即应当完整遍历数据集多少次(一次为一个Epoch)。如果Epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果Epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。所以,选择适当的Epoch数量需要在充分训练和避免过拟合之间找到平衡。
 

假设我们有1000个数据样本,每次我们送入10个数据进行训练(也就是batch_size为10)。那么完成一个Epoch,我们需要进行100次迭代(也就是100次前向传播和100次反向传播)。具体来说,我们需要将所有的数据都送入神经网络进行一次前向传播和反向传播,所以一次Epoch相当于所有数据集/batch size=N次迭代。
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/236448.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【单调栈】LeetCode2030:含特定字母的最小子序列

作者推荐 map|动态规划|单调栈|LeetCode975:奇偶跳 涉及知识点 单调栈 题目 给你一个字符串 s ,一个整数 k ,一个字母 letter 以及另一个整数 repetition 。 返回 s 中长度为 k 且 字典序最小 的子序列,该子序列同时应满足字母 letter 出…

SpringMVC01

SpringMVC 1. 学习⽬标2. 什么叫MVC?3. SpringMVC 框架概念与特点4. SpringMVC 请求流程5. Spring MVC 环境搭建6. URL 地址映射配置7. 参数绑定8. JSON 数据开发JSON普通数组步骤1:pom.xml添加依赖步骤2: 修改配置⽂件步骤3. 注解使⽤ 1. 学习⽬标 2. 什…

【介质】DWPD 每日整盘写入次数 衡量SSD固态硬盘寿命

SSD DWPD什么意思? 经常看到厂商发布的SSD产品有DWPD这个指标,DWPD是什么意思? DWPD DWPD,(Diskful Writes Per Day),每日整盘写入次数,指在预期寿命内可每日完整写入SSD固态硬盘所有容量的次数。 也有…

树莓派,opencv,Picamera2利用舵机云台追踪特定颜色对象(PID控制)

一、需要准备的硬件 Raspiberry 4b两个SG90 180度舵机(注意舵机的角度,最好是180度且带限位的,切勿选360度舵机)二自由度舵机云台(如下图)Raspiberry CSI 摄像头 组装后的效果: 二、项目目标…

Python 中如何编写类型提示

哈喽大家好,我是咸鱼 我们知道 Python 是一门具有动态特性的语言,在编写 Python 代码的时候不需要显式地指定变量的类型 这样做虽然方便,但是降低了代码的可阅读性,在后期 review 代码的时候容易对变量的类型产生混淆&#xff0…

解决文件导出过大->java压缩zip文件--封装工具类

阿丹: 在业务逻辑中的数据存在一部分业务场景,在导出文件或者视频的时候需要将文件暂存在服务器上再上传到oss对象存储或者fastdfs中让用户来下载使用。但是出现的问题就是如果目标文件过大,文件的上传云端和下载本地就会时间拉长&#xff0c…

使用python读取EXCEL放假日历并制作订阅文件

前言 不想升级IOS,苦于找不到新的日历订阅url,小菜鸡百度来百度去发现ics这东西可以自己做一个,惊喜于看到了这篇文章--使用python获取日历信息并制作订阅文件_https: //github.com/lk-itween/calendar-CSDN博客 感谢作者大大。就想自己写一…

滑动窗口双指针

力扣 209 找出该数组中满足其总和大于等于 target 的长度最小的 连续子数组 [numsl, numsl1, ..., numsr-1, numsr] ,并返回其长度。如果不存在符合条件的子数组,返回 0 。 类似窗口滑动 j代表的是窗口的结束位置 i表示开始位置 在while循环中是寻找最…

服务器数据恢复-昆腾存储StorNext文件系统下raid5数据恢复案例

服务器数据恢复环境: 昆腾某型号存储,StorNext文件存储系统。 共有9个分别配置了24块磁盘的磁盘柜,其中8个磁盘柜存放普通数据,1个磁盘柜存放元数据。 存放元数据的磁盘柜中的24块磁盘组建了8组RAID1阵列和1组4盘RAID10阵列&#…

NCV8460ADR2G在汽车和工业应用中高压侧驱动如何破?

NCV8460ADR2G是一款完全保护的高压侧驱动器,可用于开关各种负载,如灯泡、电磁阀和其他致动器。该器件可以通过有源电流限制和高温关断针对过载情况进行内部保护。 诊断状态输出引脚提供了高温以及开关状态开路负载情况的数字故障指示。 特性:…

Nginx的location路径与proxy_pass匹配规则

location路径与proxy_pass匹配规则 路径替换总结当访问地址是 http://127.0.0.1/api/user/list注意:location后斜杆与proxy_pass后斜杆"/"问题,最好要么两者都加斜杆,要么都不加 路径替换 配置proxy_pass时,可以实现URL路径的部分…

22 Vue3中使用v-for遍历对象

概述 使用v-for遍历对象在真实的开发中比较少见,了解即可。 对象我更喜欢统一称之为字典,假如你哪天发现我在某个前端的教程中把对象叫做字典,请你知道这两个是同一个玩意儿。 所谓字典,就是一种key-value类型的结构的统称。 …

队列(C语言版)

一.队列的概念及结构 队列:只允许在一端进行插入数据操作,在另一端进行删除数据操作的特殊线性表,队列具有 先进先出 FIFO(First In First Out) 入队列:进行插入操作的一端称为 队尾 出队列:进行删除操作的一端称为…

网络安全之Linux环境配置及Linux基础知识讲解<三>

目录 一.下载安装Vmware二.下载安装Kali三.Linux目录结构四.Linux文件属性五.文件目录管理六.vim编辑器 一.下载安装Vmware Vmware官网:https://www.vmware.com 二.下载安装Kali Kali包含数百种工具,可用于各种信息安全任务,例如渗透测试、…

vue中判断应该import 哪个js或css

import Vue from vue import App from ./App.vueif (window.myFlag true) {import(./jsFile1.js).then((module) > {// 执行导入的jsFile1.js模块的代码module.default.init()}) } else {import(./jsFile2.js).then((module) > {// 执行导入的jsFile2.js模块的代码modul…

vue导出element表格,xlsx和xlsx-style生成xlsx文件并修改样式

1.下载依赖 npm install xlsx --save npm install file-saver --save npm install xlsx-style --save2.先修改xlsx-style的源码,一旦引入xlsx-style则会报错 xlsx-style使用中常见问题及解决办法: xlsx-style使用中常见问题及解决办法-CSDN博客 在\n…

SpringBoot 多环境开发配置文件

在开发过程中,往往开发环境和生产环境需要不同的配置。为了兼容两种运行环境,提高开发效率,可以使用多环境开发配置文件。 配置文件结构大概是这样: application.yml -主启动配置文件(用于控制使用哪种环境配…

Flink系列之:Apache Kafka SQL 连接器

Flink系列之:Apache Kafka SQL 连接器 一、Apache Kafka SQL 连接器二、依赖三、创建Kafka 表四、可用的元数据五、连接器参数六、特性七、Topic 和 Partition 的探测八、起始消费位点九、有界结束位置十、CDC 变更日志(Changelog) Source十一…

Java:获取当前线程的线程组

代码示例: package com.thb;public class Demo4 {public static void main(String[] args) {ThreadGroup threadGroup Thread.currentThread().getThreadGroup();System.out.println(threadGroup.getName());} }运行输出:

“2024山西智博会”由中国人工智能学会和省科学技术协会联合主办

近日,山西省政府新闻办近日举行了“山西加快转型发展”系列主题新闻发布会的第六场发布会,同时也是“推动数字经济发展壮大”专场发布会。在发布会上,省委、省政府强调了数字经济的重要性,并将其作为重组要素资源、重塑经济结构、…