分批次训练和评估神经网络模型

【背景】 

训练神经网络模型的时候,特征组合太多,电脑的资源会不足,所以采用分批逐步进行。已经处理过的批次保存下来,在下一次跳过,只做新加入的批次训练。

选择最优模型组合在中间结果的范围内选择,这样能保证所有的特征都能得到组合,所有的组合都能得到训练和评估。

【流程】

+-------------------------------------+
|          开始 (Start)               |
+-------------------------------------+|v
+-------------------------------------+
| 读取中间结果 (loss_records)          |
+-------------------------------------+|v
+-------------------------------------+
| 计算总的特征组合数量               |
| (total_combinations)               |
+-------------------------------------+|v
+-------------------------------------+
| 计算批次数量 (num_batches)          |
+-------------------------------------+|v
+-------------------------------------+
| 初始化进度条                       |
+-------------------------------------+|v
+-------------------------------------+
| 清理多余记录                        |
| (Clean extra records)               |
+-------------------------------------+|v
+-------------------------------------+
| 遍历每个批次 (for each batch)       |
+-------------------------------------+|v
+-------------------------------------+
| 获取当前批次特征组合和数据          |
+-------------------------------------+|v
+-------------------------------------+
| 检查当前批次是否已处理              |
| (if batch in loss_records)          |
+-------------------+-----------------+
|       否          |        是       |
|                   |                 |
v                   |                 |
+-------------------------------------+|
| 调用 train_and_evaluate_torch        |
+-------------------------------------+||                      |v                      |
+-------------------------------------+|
| 更新所有评估结果                    | |
+-------------------------------------+ ||                     | vv                     +-------------------------------------+
+-------------------------------------+| 跳过已处理的批次,更新评估结果    |
| 保存中间结果                        |+-------------------------------------+
| (save intermediate results)         |
+-------------------------------------+|v
+-------------------------------------+
| 更新进度条                          |
+-------------------------------------+|v
+-------------------------------------+
| 所有批次处理完成                    |
| (All batches processed)             |
+-------------------------------------+|v
+-------------------------------------+
| 保存最佳模型和特征组合到Excel        |
| (save_result_to_excel)              |
+-------------------------------------+|v
+-------------------------------------+
|               结束 (End)            |
+-------------------------------------+

【需求】

读取中间结果
执行特征工程
遍历传入的特征组合

 对比中间结果和新传入的特征组合,
 找出和新传入的特征组合的差异,包括新增的和不再用的
 执行训练和评估,针对新增的,同步中间数据,中间结果中也包括预测值和模型参数(因为我希望从中选出最优模型,并记录,其中也包括参数信息和预测值)
 从最新的评估数据(包括新的和中间结果中的), 选出最优的特征组合,保存到excel 

 【代码】

import os
import json
import pandas as pd
from tqdm import tqdm
import logging# 读取中间结果以防程序中途停止
loss_records = {}
if os.path.exists(loss_records_file):try:with open(loss_records_file, "r") as f:loss_records = json.load(f)print('~~~~~~~~从中间文件中读取到的loss_records:', loss_records)# 确保键是字符串,并转换回元组形式loss_records = {deserialize_features(k): v for k, v in loss_records.items()}print('~~~~~~~~转换回元组形式的loss_records:', loss_records)print("成功加载 loss_records.json")except json.JSONDecodeError as e:print(f"JSONDecodeError: {e}. 重置 loss_records.json 文件内容。")loss_records = {}with open(loss_records_file, "w") as f:json.dump(loss_records, f)# 获取所有特征组合的总数
total_combinations = len(feature_combinations)# 计算批次数量
num_batches = (total_combinations + combination_batch_size - 1) // combination_batch_size# 进度条初始化
pbar = tqdm(total=total_combinations, desc='特征组合训练进度', position=0, leave=True)
all_evaluation_results = []
new_feature_set = set(feature_combinations)# 删除 loss_records 中多余的记录
loss_records = {k: v for k, v in loss_records.items() if deserialize_features(k) in new_feature_set}
print('Cleaned loss_records:', loss_records)for batch_index in range(num_batches):start = batch_index * combination_batch_sizeend = min(start + combination_batch_size, total_combinations)current_batch = feature_combinations[start:end]current_normalized_data = normalized_data[start:end]print('current_batch: ', current_batch)print('loss_records: ', loss_records)# 检查当前批次是否已处理过if all(features in loss_records for features in current_batch):# 更新进度条pbar.update(len(current_batch))print('跳过已经处理过的批次')# 将已处理过的结果添加到所有评估结果中for features in current_batch:serialized_features = serialize_features(features)if serialized_features in loss_records:results = loss_records[serialized_features]all_evaluation_results.append({'features': features,'mse': results['MSE'],'mae': results['MAE'],'r2': results['R2']})continueprint('----没有跳过----已经处理过的批次')# 调用 train_and_evaluate_torch 函数处理当前批次的特征组合evaluation_results = train_and_evaluate_torch(current_batch, current_normalized_data, param_model, scaler_close, evaluation_results, n, data_obj, parameter_period, loss_records)all_evaluation_results.extend(evaluation_results)# 保存中间结果for features in current_batch:serialized_features = serialize_features(features)print(f'Serializing features: {features} -> {serialized_features}')# 提取结果并保存results = next(item for item in evaluation_results if item['features'] == features)if 'best_metrics' in results:best_metrics = results['best_metrics']loss_records[serialized_features] = {'MSE': convert_numpy_types(best_metrics['mse']),'MAE': convert_numpy_types(best_metrics['mae']),'R2': convert_numpy_types(best_metrics['r2'])}else:loss_records[serialized_features] = {'MSE': convert_numpy_types(results['mse']),'MAE': convert_numpy_types(results['mae']),'R2': convert_numpy_types(results['r2'])}# 输出当前的 loss_records 以进行调试print('Current loss_records before saving: ', loss_records)with open(loss_records_file, "w") as f:json.dump(loss_records, f)# 再次读取并检查文件内容,确保保存正确with open(loss_records_file, "r") as f:loaded_loss_records = json.load(f)print('Loaded loss_records after saving: ', loaded_loss_records)# 更新进度条pbar.update(len(current_batch))print("所有批次处理完成。")
pbar.close()# 最佳模型和每个特征组合的最佳模型保存到excel
save_result_to_excel(strategy_name, all_evaluation_results, OUTPUT_FILE_NEURAL_NETWORK_PATH, weights)def save_result_to_excel(strategy_name, evaluation_results, file_path, weights=None):"""数据保存到excel.Parameters:- evaluation_results 评估数据- file_path excel文件名称,用来保存测试报告Returns:None"""# print('评估数据evaluation_results:', evaluation_results)strategy_func = strategy_mapping.get(strategy_name)if strategy_func:num_params = len(inspect.signature(strategy_func).parameters)if weights and num_params > 1:best_result = strategy_func(evaluation_results, weights)print("best_result assigned successfully:", best_result)else:best_result = strategy_func(evaluation_results)print("best_result assigned successfully:", best_result)print('>>>>>>>>>>保存best_result>>>>>>>>>', best_result)print()    try:  # 创建一个空列表来存储评估过程的结果evaluation_process_data = []# 添加评估过程中的结果for result in evaluation_results:evaluation_process_data.append({'Features': result['features'],'Best Parameters': result['best_params'],'Best Metrics': result['best_metrics']})# 创建DataFrame来存储评估过程的结果df_evaluation_process = pd.DataFrame(evaluation_process_data)print('训练过程的数据:df_evaluation_process', df_evaluation_process)# 创建一个空的DataFrame来存储最佳模型的结果df_best_model_results = pd.DataFrame(columns=['Features', 'Best Predictions'])if best_result is not None:df_best_model_results.loc[0] = {'Features': best_result['features'],  # 使用best_result中的特征信息'Best Predictions': best_result['predictions']}# 倒置最佳模型结果DataFrame的行列df_best_model_results_transposed = df_best_model_results.transpose()# 创建一个新的 DataFrame,用于存储转置后的数据以及其含义df_with_labels = pd.DataFrame(columns=['Label', 'Value'])# 将原始表头作为索引,添加到新 DataFrame 中for feature in df_best_model_results_transposed.index:# 获取转置后数据的值,而不包括索引和数据类型信息value = df_best_model_results_transposed.loc[feature].values[0]df_with_labels = pd.concat([df_with_labels, pd.DataFrame({'Label': [feature], 'Value': [value]})], ignore_index=True)# 保存最佳模型的结果到Excel文件with pd.ExcelWriter(file_path, engine='xlsxwriter') as writer:df_with_labels.to_excel(writer, sheet_name='Best Model Results', index=False)print('执行了保存数据到excel,路径是:') print(file_path)    else:print("best_result is None, cannot save to excel")logging.error("best_result is None, cannot save to excel")except Exception as e:print(f"保存测试结果到excel: {e}")logging.error(f"save result to excel: {e}") else:print('Invalid strategy name:', strategy_name)

要点

  1. 清理多余记录:在处理批次之前,根据新的特征组合清理 loss_records 中多余的记录。
  2. 更新所有评估结果:即使跳过已处理的批次,也将其评估结果添加到 all_evaluation_results 中,以确保最终的最佳模型选择是基于所有特征组合。
  3. 保存最佳结果到Excel:保持 save_result_to_excel 函数逻辑不变,确保从所有评估结果中选出最优模型并保存。

这样可以确保即使跳过了一些已处理的批次,最终的最优模型仍然是从所有特征组合中选出的,并且中间结果不会包含多余的记录。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/27974.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探索Docker容器网络

Docker容器已经成为现代应用部署的核心工具。理解Docker的网络模型对于实现高效、安全的容器化应用至关重要。在这篇博客中,我们将深入探讨Docker的网络架构,并通过一些代码例子来揭示其底层实现。 Docker网络模式 Docker提供了多种网络模式&#xff0c…

【Kubernetes】Helm--包管理工具

​​​​​​​ 微服务是什么? 微服务把大包解耦成小包,使用的时候使用java -jar包启动服务 Helm 什么是Helm? 在没使用 helm 之前,向 kubernetes 部署应用,我们要依次部署 deployment、svc 等,步骤较繁…

Spring (58)什么是Spring Kafka

Spring Kafka 是一个基于 Spring 框架的项目,它提供了对 Apache Kafka 的集成支持。Kafka 是一个分布式流媒体平台,专门用于构建实时数据管道和流应用程序。Spring Kafka 提供了一种简单的抽象来发送和接收消息,使得与 Kafka 交云进行通讯变得…

GPRS与4G网络:技术差异与应用选择

在移动通信的发展历程中,GPRS(General Packet Radio Service)和4G(Fourth-Generation)技术都扮演着举足轻重的角色。虽然两者都旨在提供无线数据传输服务,但在数据传输速率、延迟和覆盖范围等方面&#xff…

(游戏:三个数的加法)编写程序,随机产生三个一位整数,并提示用户输入这三个整数的和,判断用户输入的和是否正确。

(游戏:三个数的加法)编写程序,随机产生三个一位整数,并提示用户输入这三个整 数的和,判断用户输入的和是否正确。 package myjava; import java.math.*; import java.util.Scanner; public class cy {public static void main(String[]args)…

ssl安全证书免费申请方法,非自签证书

注意1: 如果一个域名一定时间内申请超过5次,会被锁定至少1周时间,还有就是一个IP一天太频繁发起安全证书,也有可能被锁订单 注意2: 申请的免费证书只有90天,后续版本再补充自动续签 前提条件 需要将要认证s…

Swift开发——循环执行方式

本文将介绍 Swift 语言的循环执行方式 01、循环执行方式 在Swift语言中,主要有两种循环执行控制方式: for-in结构和while结构。while结构又细分为当型while结构和直到型while结构,后者称为repeat-while结构。下面首先介绍for-in结构。 循环控制方式for-in结构可用于区间中的…

ceph scrub 错误记录

目的 记录 ceph scrub 错误问题解决 ceph scrub 故障故障信息 cluster:id: xxx-xxx-xxxhealth: HEALTH_ERR2 scrub errorsPossible data damage: 2 pg inconsistentmessage 日志信息 # egrep -i medium|i\/o error|sector|Prefailure /var/log/messages Jun 15 00:23:37 m…

跨境电商中的IP隔离是什么?怎么做?

一、IP地址隔离的概念和原理 当我们谈论 IP 地址隔离时,我们实际上是在讨论一种网络安全策略,旨在通过技术手段将网络划分为不同的区域或子网,每个区域或子网都有自己独特的 IP 地址范围。这种划分使网络管理员可以更精细地控制哪些设备或用…

Type-C接口显示器:C口高效连接与无限可能 LDR

Type-C显示器C接口的未来:高效连接与无限可能 随着科技的飞速发展,我们的日常生活和工作中对于高效、便捷的连接方式的需求日益增加。在这样的背景下,Type-C接口显示器凭借其卓越的性能和广泛的兼容性,正逐渐崭露头角&#xff0c…

Java中ArrayList(顺序表)的自我实现(如果想知道Java中怎么自我实现ArrayList,那么只看这一篇就足够了!)

前言:在 Java 编程中,ArrayList 是一种非常常用的数据结构,它提供了动态数组的实现方式,可以方便地存储和操作数据。相比于传统的数组,ArrayList 具有更多的灵活性和便利性,可以根据需要动态地调整大小&…

axios打通fastapi和vue,实现前后端分类项目开发

axios axios是一个前后端交互的工具,负责在前端代码,调用后端接口,将后端的数据请求到本地以后进行解析,然后传递给前端进行处理。 比如,我们用fastapi写了一个接口,这个接口返回了一条信息: …

后端项目怎么做?怎么准备面试,看这篇就够了!

近期群友都在海投,广撒网,为的就是等一个面试机会,等一个offer。 当收到面试通知的时候,大家一定要好好把握机会。 机会很重要,给你机会,没有把握住,那就比较尴尬了。 对于研发岗位来说&…

Hadoop 2.0:主流开源云架构(三)

目录 四、Hadoop 2.0体系架构(一)Hadoop 2.0公共组件Common(二)分布式文件系统HDFS(三)分布式操作系统Yarn(四)Hadoop 2.0安全机制简介 四、Hadoop 2.0体系架构 (一&…

如何解决mfc100u.dll丢失问题,关于mfc100u.dll丢失的多种解决方法

在计算机使用过程中,我们常常会遇到一些错误提示,其中之一就是“计算显示缺失mfc100u.dll”。这个问题可能会影响到我们的正常使用,因此了解它的原因、表现以及解决方法是非常重要的。小编将详细介绍计算显示缺失mfc100u.dll的问题&#xff0…

音视频集式流媒体边缘分布式集群拉流管理

一直以来,由于srs zlm等开源软件采用传统直播协议,即使后面实现了webrtc转发,由于信令交互较弱,使得传统的安防监控方案需要在公网云平台上部署大型流媒体服务器,而且节点资源不能统一管理调度,缺乏灵活性和…

基于PPO的强化学习超级马里奥自动通关

目录 一、环境准备 二、训练思路 1.训练初期: 2.思路整理及改进: 思路一: 思路二: 思路三: 思路四: 3.训练效果: 三、结果分析 四、完整代码 训练代码: 测试代码&#x…

2024.ZCPC.M题 计算三角形个数

题目描述: 小蔡有一张三角形的格子纸,上面有一个大三角形。这个边长为 的大三角形, 被分成 个边长为 1 的小三角形(如图一所示)。现在,小蔡选择了一条水平边 删除(如图二所示),请你找出图上剩余…

C#多线程与函数对象的实例

在C#中,通过使用委托和多线程可以实现传递函数对象给线程进行执行。下面是一个简单的实例,演示如何在多线程中使用函数对象: using System; using System.Threading;class Program {static void Main(){// 创建一个委托,用于传递…

RestTemplate远程请求的艺术

1 简说 编程是一门艺术,追求优雅的代码就像追求优美的音乐。 很多有多年工作经验的开发者,在使用RestTemplate之前常常使用HttpClient,然而接触了RestTemplate之后,却愿意放弃多年相处的“老朋友”,转向RestTemplate。那么一定是RestTemplate有它的魅力,有它的艺术风范。…