机器学习中参数优化调试方法

1  超参数优化

图片

调参即超参数优化,是指从超参数空间中选择一组合适的超参数,以权衡好模型的偏差(bias)和方差(variance),从而提高模型效果及性能。常用的调参方法有:

  • 人工手动调参

  • 网格/随机搜索(Grid / Random Search)

  • 贝叶斯优化(Bayesian Optimization)

图片

注:超参数 vs 模型参数差异 超参数是控制模型学习过程的(如网络层数、学习率);模型参数是通过模型训练学习后得到的(如网络最终学习到的权重值)。

2  人工调参

手动调参需要结合数据情况及算法的理解,选择合适调参的优先顺序及参数的经验值。

不同模型手动调参思路会有差异,如随机森林是一种bagging集成的方法,参数主要有n_estimators(子树的数量)、max_depth(树的最大生长深度)、max_leaf_nodes(最大叶节点数)等。(此外其他参数不展开说明) 对于n_estimators:通常越大效果越好。参数越大,则参与决策的子树越多,可以消除子树间的随机误差且增加预测的准度,以此降低方差与偏差。对于max_depth或max_leaf_nodes:通常对效果是先增后减的。取值越大则子树复杂度越高,偏差越低但方差越大。

图片

3 网格/随机搜索

图片

  • 网格搜索(grid search),是超参数优化的传统方法,是对超参数组合的子集进行穷举搜索,找到表现最佳的超参数子集。

  • 随机搜索(random search),是对超参数组合的子集简单地做固定次数的随机搜索,找到表现最佳的超参数子集。对于规模较大的参数空间,采用随机搜索往往效率更高。

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier# 选择模型 
model = RandomForestClassifier()
# 参数搜索空间
param_grid = {'max_depth': np.arange(1, 20, 1),'n_estimators': np.arange(1, 50, 10),'max_leaf_nodes': np.arange(2, 100, 10)}
# 网格搜索模型参数
grid_search = GridSearchCV(model, param_grid, cv=5, scoring='f1_micro')
grid_search.fit(x, y)
print(grid_search.best_params_)
print(grid_search.best_score_)
print(grid_search.best_estimator_)
# 随机搜索模型参数
rd_search = RandomizedSearchCV(model, param_grid, n_iter=200, cv=5, scoring='f1_micro')
rd_search.fit(x, y)
print(rd_search.best_params_)
print(rd_search.best_score_)
print(rd_search.best_estimator_)

4 贝叶斯优化

贝叶斯优化(Bayesian Optimization) 与网格/随机搜索最大的不同,在于考虑了历史调参的信息,使得调参更有效率。(但在高维参数空间下,贝叶斯优化复杂度较高,效果会近似随机搜索。)

图片

4.1 算法简介

贝叶斯优化思想简单可归纳为两部分:

  • 高斯过程(GP):以历史的调参信息(Observation)去学习目标函数的后验分布(Target)的过程。

  • 采集函数(AC):由学习的目标函数进行采样评估,分为两种过程:1、开采过程:在最可能出现全局最优解的参数区域进行采样评估。2、勘探过程:兼顾不确定性大的参数区域的采样评估,避免陷入局部最优。

4.2 算法流程

for循环n次迭代:采集函数依据学习的目标函数(或初始化)给出下个开采极值点 Xn+1;评估超参数Xn+1得到表现Yn+1;加入新的Xn+1、Yn+1数据样本,并更新高斯过程模型;

图片

"""
随机森林分类Iris使用贝叶斯优化调参
"""
import numpy as np
from hyperopt import hp, tpe, Trials, STATUS_OK, Trials, anneal
from functools import partial
from hyperopt.fmin import fmin
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifierdef model_metrics(model, x, y):""" 评估指标 """yhat = model.predict(x)return  f1_score(y, yhat,average='micro')def bayes_fmin(train_x, test_x, train_y, test_y, eval_iters=50):"""bayes优化超参数eval_iters:迭代次数"""def factory(params):"""定义优化的目标函数"""fit_params = {'max_depth':int(params['max_depth']),'n_estimators':int(params['n_estimators']),'max_leaf_nodes': int(params['max_leaf_nodes'])}# 选择模型model = RandomForestClassifier(**fit_params)model.fit(train_x, train_y)# 最小化测试集(- f1score)为目标train_metric = model_metrics(model, train_x, train_y)test_metric = model_metrics(model, test_x, test_y)loss = - test_metricreturn {"loss": loss, "status":STATUS_OK}# 参数空间space = {'max_depth': hp.quniform('max_depth', 1, 20, 1),'n_estimators': hp.quniform('n_estimators', 2, 50, 1), 'max_leaf_nodes': hp.quniform('max_leaf_nodes', 2, 100, 1)}# bayes优化搜索参数best_params = fmin(factory, space, algo=partial(anneal.suggest,), max_evals=eval_iters, trials=Trials(),return_argmin=True)# 参数转为整型best_params["max_depth"] = int(best_params["max_depth"])best_params["max_leaf_nodes"] = int(best_params["max_leaf_nodes"])best_params["n_estimators"] = int(best_params["n_estimators"])return best_params#  搜索最优参数
best_params = bayes_fmin(train_x, test_x, train_y, test_y, 100)
print(best_params)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/114286.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试-Redis-缓存雪崩

问:什么是缓存雪崩 ? 答:缓存过期是指设置缓存时都采用了同一过期时间,导致缓存在莫一时刻同时失效,从而请求全部全部打到数据库中,导致数据库压力过大而挂机。 它与缓存击穿的区别是:缓存击穿是一个key…

01、MySQL-------性能优化

目录 一、影响性能的相关因素存储过程: 二、sql优化1>、Mysql系统架构2>、引擎区别: 3>、索引1、什么是索引?联合主键索引理解:索引长度理解:什么是慢查询? 1)、索引理解2)…

Rust 泛型

泛型 Generics泛型详解 使用泛型参数&#xff0c;有一个先决条件&#xff0c;必需在使用前对其进行声明&#xff1a; fn largest<T>(list: &[T]) -> T {该泛型函数的作用是从列表中找出最大的值&#xff0c;其中列表中的元素类型为 T。首先 largest<T> 对…

SpringBoot结合Druid实现SQL监控

1、前言 SpringBoot不用我多介绍了吧&#xff0c;目前后端最流行的框架。后端开发人员最基本的要求。 Druid数据库连接池&#xff0c;出自国内 ”java圣地" 阿里巴巴。 Druid是一个用于大数据实时查询和分析的高容错、高性能开源分布式系统&#xff0c;旨在快速处理大规模…

CentOS 7.9 安装 MySQL 8 配置模板

1. 服务器主机 BIOS 关闭 NUMA 2. 系统版本&#xff1a;CentOS Linux release 7.9.2009 (Core)&#xff1b;MySQL 8.0.22 3. 修改系统核心参数 # 编辑 /etc/sysctl.conf 文件&#xff0c;添加以下参数&#xff1a; fs.aio-max-nr524288 vm.swappiness0 net.ipv6.conf.all.di…

新年学新语言Go之五

一、前言 Go虽然不算是面向对象语言&#xff0c;但它支持面向对象一些特性&#xff0c;面向接口编程是Go一个很重要的特性&#xff0c;而Go的接口与Java的接口区别很大&#xff0c;Go的接口比较复杂&#xff0c;这里仅用一个最简单例子做介绍&#xff0c;复杂的我也还没学。 …

filebeat(8.9.0)采集日志到logstash,由logstash发送的es

filebeat采集日志到logstash&#xff0c;由logstash发送的es 下载并配置filebeat下载配置logback.xml logstash配置 下载并配置filebeat 下载 参考 配置 filebeat.inputs: - type: filestreamenabled: truepaths:# 日志文件目录- D:\modellog\elkdemo\*\*.logparsers:# 多…

LeetCode2409——统计共同度过的日子数

博主的解法过于冗长&#xff0c;是一直对着不同的案例debug修改出来的&#xff0c;不建议学习。虽然提交成功了&#xff0c;但是自己最后都不知道写的是啥了哈哈哈。 package keepcoding.leetcode.leetcode2409; /*Alice 和 Bob 计划分别去罗马开会。给你四个字符串 arriveA…

分发糖果[困难]

优质博文&#xff1a;IT-BLOG-CN 一、题目 n个孩子站成一排。给你一个整数数组ratings表示每个孩子的评分。你需要按照以下要求&#xff0c;给这些孩子分发糖果&#xff1a; 【1】每个孩子至少分配到1个糖果。 【2】相邻两个孩子评分更高的孩子会获得更多的糖果。 请你给每个孩…

浅谈 docker run 命令中的 -i -t 和 -d 选项

以 docker Ubuntu 镜像为例&#xff0c;ubuntu镜像启动时默认执行的命令是"/bin/bash"。 文章目录 不带任何选项带 -i 选项带 -i 和 -t 选项-d 选项 不带任何选项 rootubuntu20:~# docker run ubuntu:20.04 rootubuntu20:~# docker ps CONTAINER ID IMAGE …

操作系统【OS】进程的控制【进程的创建、终止、阻塞、唤醒】

定义和过程 对应事件 创建 允许一个进程创建另一个进程允许子进程继承父进程所拥有的资源创建进程的过程如下&#xff1a; 申请一个空白的 PCB&#xff0c;并向 PCB 中填写一些控制和管理进程的信息&#xff0c;比如进程的唯一标识等&#xff1b;为该进程分配运行时所必需的…

源码安装Openlava 4.0

安装需求 基本硬件配置建议&#xff1a;CPU 4核或以上&#xff08;LSF 没有最低 CPU 需求&#xff0c;此处只是建议&#xff09;内存 8G或以上&#xff08; 当没有作业在运行时&#xff0c; Linux x86-64 上集群中的 LSF 守护程序将使用大约 488 MB 内存。&#xff09;交换空…

DDOS攻击的有效防护方式有哪些?

DDoS攻击简介&#xff1a; DDoS攻击&#xff0c;即分布式拒绝服务攻击&#xff08;Distributed Denial of Service&#xff09;&#xff0c;是一种网络攻击&#xff0c;旨在通过向目标服务器发送大量恶意请求&#xff0c;使服务器资源耗尽&#xff0c;无法满足合法用户的需求&a…

网络协议--ARP:地址解析协议

4.1 引言 本章我们要讨论的问题是只对TCP/IP协议簇有意义的IP地址。数据链路如以太网或令牌环网都有自己的寻址机制&#xff08;常常为48 bit地址&#xff09;&#xff0c;这是使用数据链路的任何网络层都必须遵从的。一个网络如以太网可以同时被不同的网络层使用。例如&#…

git创建与合并分支

文章目录 创建与合并分支分支管理的概念实际操作 解决冲突分支管理策略Bug分支Feature分支多人协作 创建与合并分支 分支管理的概念 分支在实际中有什么用呢&#xff1f;假设你准备开发一个新功能&#xff0c;但是需要两周才能完成&#xff0c;第一周你写了50%的代码&#xf…

shell的case选择

shell笔记 case语法结构 case语法结构 Caseesac语句与其他语言中的switch.case 语句类似&#xff0c;是一种多分支选择结构。case语句匹配一个值或一个模式&#xff0c;如果匹配成功&#xff0c;执行相匹配的命令。 case语法结构&#xff1a; case expr in #expr为表达式&am…

《动手学深度学习 Pytorch版》 9.6 编码器-解码器架构

为了处理这种长度可变的输入和输出&#xff0c; 可以设计一个包含两个主要组件的编码器-解码器&#xff08;encoder-decoder&#xff09;架构&#xff1a; 编码器&#xff08;encoder&#xff09;&#xff1a;它接受一个长度可变的序列作为输入&#xff0c;并将其转换为具有固定…

React +AntD + From组件重复提交数据(已解决)

开发场景&#xff1a; react Hooks andt 提交form表单内容给数据库(使用antd的form组件) 问题描述 提交是异步的&#xff0c;请提交方式是POST 方式 提交表单内容给后端&#xff0c;却产生了两次提交记录&#xff08;当然&#xff0c;数据新增了两条数据&#xff09;。可以…

基于WebRTC的程序因虚拟内存不足导致闪退问题的排查以及解决办法的研究

目录 1、WebRTC简介 2、问题现象描述 3、将Windbg附加到目标进程上分析 3.1、Windbg没有附加到主程序进程上&#xff0c;没有感知到异常或中断 3.2、Windbg感知到了中断&#xff0c;中断在DebugBreak函数调用上 3.3、32位进程用户态虚拟地址和内核态虚拟地址的划分 …