深度学习篇---模型训练(1)


文章目录

  • 前言
  • 一、库导入与配置部分
    • 介绍
  • 二、超参数配置
    • 简介
  • 三、模型定义
    • 1. 改进残差块
    • 2. 完整CNN模型
  • 四、数据集类
  • 五、数据加载函数
  • 六、训练函数
  • 七、验证函数
  • 八、检查点管理
  • 九、主函数
  • 十、执行入口
  • 十一、关键设计亮点总结
    • 1.维度管理
    • 2.数据标准化
    • 3.动态学习率
    • 4.梯度剪裁
    • 5.检查点系统
    • 6.结果可追溯
    • 7.工业级健壮性
    • 8.高效数据加载


前言

本文再网络结构(1)的基础上,完善数据读取、数据增强、数据处理、模型训练、断点训练等功能。


一、库导入与配置部分

import torch
import torch.nn as nn  # PyTorch核心神经网络模块
import pandas as pd    # 数据处理
import numpy as np     # 数值计算
from torch.utils.data import Dataset, DataLoader  # 数据加载工具
from sklearn.preprocessing import StandardScaler  # 数据标准化
from sklearn.model_selection import train_test_split  # 数据分割
from torch.optim.lr_scheduler import ReduceLROnPlateau  # 动态学习率调整
from collections import Counter  # 统计类别分布
import csv  # 结果记录
import time  # 时间戳生成
import joblib  # 模型/参数持久化

介绍

导入Pytorch核心神经网路模块、数据处理库和数值处理库数据标准化、数据分割、动态学习率调整、统计类别分布、结果记录、时间戳生成、模型/参数持久化。

二、超参数配置

config = {"batch_size": 256,        # 每批数据量"num_workers": 128,       # 数据加载并行进程数"lr": 1e-3,               # 初始学习率"weight_decay": 1e-4,     # L2正则化强度"epochs": 200,            # 最大训练轮数"patience": 15,           # 早停等待轮数"min_delta": 0.001,       # 视为改进的最小精度提升"grad_clip": 5.0,         # 梯度裁剪阈值"num_classes": None       # 自动计算类别数
}

简介

设置每批数据量、数据加载并行进程数、初始学习率、L2正则化强度、最大训练轮数、早停等待轮数、视为改进的最小精度提升、梯度剪裁阈值、自动计算类别数。

三、模型定义

1. 改进残差块

class ImprovedResBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()  # 初始化父类# 第一个卷积层self.conv1 = nn.Conv1d(in_channels, out_channels, 5, stride, 2)# 参数解释:输入通道,输出通道,卷积核大小5,步长,填充2(保持尺寸)self.bn1 = nn.BatchNorm1d(out_channels)  # 批量归一化# 第二个卷积层self.conv2 = nn.Conv1d(out_channels, out_channels, 3, 1, 1)# 3x1卷积,步长1,填充1保持尺寸self.bn2 = nn.BatchNorm1d(out_channels)self.relu = nn.ReLU()  # 激活函数# 下采样路径(当需要调整维度时)self.downsample = nn.Sequential(nn.Conv1d(in_channels, out_channels, 1, stride),  # 1x1卷积调整维度nn.BatchNorm1d(out_channels)) if in_channels != out_channels or stride != 1 else None# 当输入输出通道不同或步长>1时启用def forward(self, x):identity = x  # 保留原始输入作为残差# 主路径处理x = self.relu(self.bn1(self.conv1(x)))  # Conv1 -> BN1 -> ReLUx = self.bn2(self.conv2(x))  # Conv2 -> BN2(无激活)# 调整残差路径维度if self.downsample:identity = self.downsample(identity)x += identity  # 残差连接return self.relu(x)  # 最终激活

2. 完整CNN模型

class EnhancedCNN(nn.Module):def __init__(self, input_channels, seq_len, num_classes):super().__init__()# 初始特征提取层self.initial = nn.Sequential(nn.Conv1d(input_channels, 64, 7, stride=2, padding=3),  # 快速下采样nn.BatchNorm1d(64),nn.ReLU(),nn.MaxPool1d(3, 2, 1)  # 核3,步长2,填充1,输出尺寸约为输入1/4)# 残差块堆叠self.blocks = nn.Sequential(ImprovedResBlock(64, 128, stride=2),  # 通道翻倍,尺寸减半ImprovedResBlock(128, 256, stride=2),ImprovedResBlock(256, 512, stride=2),nn.AdaptiveAvgPool1d(1)  # 自适应全局平均池化到长度1)# 分类器self.classifier = nn.Sequential(nn.Linear(512, 256),     # 全连接层nn.Dropout(0.5),         # 强正则化防止过拟合nn.ReLU(),nn.Linear(256, num_classes)  # 最终分类层)def forward(self, x):x = self.initial(x)  # 初始特征提取x = self.blocks(x)   # 通过残差块x = x.view(x.size(0), -1)  # 展平维度 (batch, 512)return self.classifier(x)  # 分类预测

四、数据集类

class SequenceDataset(Dataset):def __init__(self, sequences, labels, scaler=None):self.sequences = sequences  # 原始序列数据self.labels = labels        # 对应标签self.scaler = scaler or StandardScaler()  # 标准化器# 如果未提供scaler,用当前数据拟合新的if scaler is None:flattened = np.concatenate(sequences)  # 展平所有数据点self.scaler.fit(flattened)  # 计算均值和方差# 对每个序列进行标准化self.normalized = [self.scaler.transform(seq) for seq in sequences]def __len__(self):return len(self.sequences)  # 返回数据集大小def __getitem__(self, idx):# 获取单个样本seq = torch.tensor(self.normalized[idx], dtype=torch.float32).permute(1, 0)# permute将形状从(seq_len, features)转为(features, seq_len)符合Conv1d输入要求label = torch.tensor(self.labels[idx], dtype=torch.long)# 数据增强if np.random.rand() > 0.5:  # 50%概率时序翻转seq = seq.flip(-1)  # 沿时间维度翻转if np.random.rand() > 0.3:  # 70%概率添加噪声seq += torch.randn_like(seq) * 0.01  # 高斯噪声(均值0,方差0.01)return seq, label

五、数据加载函数

def load_data(excel_path):df = pd.read_excel(excel_path)  # 读取Excel数据sequences = []labels = []for _, row in df.iterrows():  # 遍历每一行数据try:# 处理可能存在的字符串格式异常loads = list(map(float, str(row['载荷']).split(',')))displacements = list(map(float, str(row['位移']).split(',')))powers = list(map(float, str(row['功率']).split(',')))# 对齐三列数据的长度min_len = min(len(loads), len(displacements), len(powers))# 组合成(时间步长, 3个特征)的数组combined = np.array([loads[:min_len], displacements[:min_len], powers[:min_len]).T  # 转置为(min_len, 3)label = int(float(row['工况结果']))  # 转换标签sequences.append(combined)labels.append(label)except Exception as e:print(f"处理第{_}行时出错: {str(e)}")  # 异常处理# 统计类别分布label_counts = Counter(labels)print("类别分布:", label_counts)# 创建标签映射(将任意标签转换为0~N-1的索引)unique_labels = sorted(list(set(labels)))label_map = {l:i for i,l in enumerate(unique_labels)}config["num_classes"] = len(unique_labels)  # 更新配置labels = [label_map[l] for l in labels]  # 转换所有标签# 分层划分训练/验证集(保持类别比例)return train_test_split(sequences, labels, test_size=0.2, stratify=labels)

六、训练函数

def train_epoch(model, loader, optimizer, criterion, device):model.train()  # 训练模式total_loss = 0for x, y in loader:  # 遍历数据加载器x, y = x.to(device), y.to(device)  # 数据迁移到设备optimizer.zero_grad()  # 清空梯度outputs = model(x)     # 前向传播loss = criterion(outputs, y)  # 计算损失loss.backward()        # 反向传播# 梯度裁剪防止爆炸nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])optimizer.step()       # 参数更新total_loss += loss.item() * x.size(0)  # 累加损失(考虑批次大小)return total_loss / len(loader.dataset)  # 平均损失

七、验证函数

def validate(model, loader, criterion, device):model.eval()  # 评估模式total_loss = 0correct = 0with torch.no_grad():  # 禁用梯度计算for x, y in loader:x, y = x.to(device), y.to(device)outputs = model(x)            loss = criterion(outputs, y)total_loss += loss.item() * x.size(0)# 计算准确率preds = outputs.argmax(dim=1)  # 取最大概率类别correct += preds.eq(y).sum().item()  # 统计正确数return (total_loss / len(loader.dataset),  # 平均损失(correct / len(loader.dataset))    # 准确率

八、检查点管理

def save_checkpoint(epoch, model, optimizer, scheduler, best_acc, scaler, filename="checkpoint.pth"):torch.save({'epoch': epoch,                    # 当前轮数'model_state_dict': model.state_dict(),          # 模型参数'optimizer_state_dict': optimizer.state_dict(),  # 优化器状态'scheduler_state_dict': scheduler.state_dict(),  # 学习率调度器状态'best_acc': best_acc,              # 当前最佳准确率'scaler': scaler                   # 数据标准化参数}, filename)def load_checkpoint(filename, model, optimizer, scheduler):checkpoint = torch.load(filename)model.load_state_dict(checkpoint['model_state_dict'])       # 加载模型optimizer.load_state_dict(checkpoint['optimizer_state_dict']) scheduler.load_state_dict(checkpoint['scheduler_state_dict'])return checkpoint['epoch'], checkpoint['best_acc'], checkpoint['scaler']

九、主函数

def main(resume=False):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 自动选择设备# 生成带时间戳的结果文件名timestamp = time.strftime("%Y%m%d_%H%M%S")results_file = f"training_results_{timestamp}.csv"# 加载并划分数据train_seq, val_seq, train_lb, val_lb = load_data("./dcgt.xls")# 初始化模型(恢复训练时自动获取序列长度)sample_seq = train_seq[0].shape[1] if resume else Nonemodel = EnhancedCNN(input_channels=3, seq_len=sample_seq,  num_classes=config["num_classes"]).to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])# 学习率调度器(根据验证损失调整)scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)# 恢复训练逻辑start_epoch = 0best_acc = 0if resume:checkpoint = torch.load("checkpoint.pth")model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])scheduler.load_state_dict(checkpoint['scheduler_state_dict'])start_epoch = checkpoint['epoch']best_acc = checkpoint['best_acc']train_set = SequenceDataset(train_seq, train_lb, scaler=checkpoint['scaler'])else:train_set = SequenceDataset(train_seq, train_lb)# 验证集使用训练集的scalerval_set = SequenceDataset(val_seq, val_lb, scaler=train_set.scaler)# 持久化标准化参数joblib.dump(train_set.scaler, 'scaler.save')# 创建数据加载器train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True, num_workers=config["num_workers"]  # 多进程加载加速)val_loader = DataLoader(val_set, batch_size=config["batch_size"], num_workers=config["num_workers"])# 训练循环with open(results_file, 'w', newline='') as f:writer = csv.writer(f)writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_acc', 'learning_rate'])for epoch in range(start_epoch, config["epochs"]):# 训练一个epochtrain_loss = train_epoch(model, train_loader, optimizer, criterion, device)# 验证val_loss, val_acc = validate(model, val_loader, criterion, device)current_lr = optimizer.param_groups[0]['lr']  # 获取当前学习率# 更新学习率scheduler.step(val_loss)# 保存检查点save_checkpoint(epoch+1, model, optimizer, scheduler, best_acc, train_set.scaler)# 记录结果writer.writerow([epoch + 1, f"{train_loss:.4f}", f"{val_loss:.4f}", f"{val_acc:.4f}", f"{current_lr:.6f}"])print(f"\nEpoch {epoch+1}/{config['epochs']}")print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")print(f"Val Acc: {val_acc*100:.2f}% | Learning Rate: {current_lr:.6f}")# 早停逻辑(伪代码示意)if val_acc > best_acc + config["min_delta"]:best_acc = val_accpatience_counter = 0else:patience_counter += 1if patience_counter >= config["patience"]:print(f"早停触发于第{epoch+1}轮")break# 保存最终模型torch.save(model.state_dict(), "best_model.pth")

十、执行入口

if __name__ == "__main__":main(resume=False)  # 首次训练# main(resume=True)  # 恢复训练

十一、关键设计亮点总结

1.维度管理

维度管理:通过permute确保数据形状符合Conv1d要求

2.数据标准化

数据标准化:使用全体训练数据计算均值和方差,避免数据泄露

3.动态学习率

动态学习率:ReduceLROnPlateau根据验证损失自动调整

4.梯度剪裁

梯度裁剪:防止梯度爆炸,稳定训练过程

5.检查点系统

检查点系统:完整保存训练状态,支持训练中断恢复

6.结果可追溯

结果可追溯:带时间戳的CSV记录和模型保存

7.工业级健壮性

工业级健壮性:异常捕获、参数持久化、自动类别映射

8.高效数据加载

高效数据加载:多进程并行加速数据预处理

这个实现涵盖了从数据预处理到模型训练的完整流程,适合工业级时间序列分类任务,具有良好的可扩展性和可维护性。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/76843.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

题解:AT_abc241_f [ABC241F] Skate

一道经典的 bfs 题。 提醒:本题解是为小白专做的,不想看的大佬请离开。 这道题首先一看就知道是 bfs,但是数据点不让我们过: 1 ≤ H , W ≤ 1 0 9 1\le H,W\le10^9 1≤H,W≤109。 那么我们就需要优化了,从哪儿下手…

【含文档+PPT+源码】基于微信小程序的乡村振兴民宿管理系统

项目介绍 本课程演示的是一款基于微信小程序的乡村振兴民宿管理系统,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含:项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系统 3.该…

STM32定时器通道1-4(CH1-CH4)的引脚映射关系

以下是 STM32定时器通道1-4(CH1-CH4)的引脚映射关系的详细说明,以常见型号为例。由于不同系列/型号差异较大,请务必结合具体芯片的参考手册确认。 一、STM32F1系列(如STM32F103C8T6) 1. TIM1(高级定时器) 通道默认引脚重映射引脚(部分/完全)备注CH1PA8无互补输出CH1…

bge-m3+deepseek-v2-16b+离线语音能力实现离线文档向量化问答语音版

ollama run deepseek-v2:16b ollama pull bge-m3 1、离线听写效果的大幅度提升。50M 1.3G(每次初始化都会很慢)---优化到首次初始化使用0延迟响应。 2、文档问答历史问题处理与优化,文档问答离线策略讨论与参数暴露。 3、离线大模型答复中断…

前端界面在线excel编辑器 。node编写post接口获取文件流,使用传参替换表格内容展示、前后端一把梭。

首先luckysheet插件是支持在线替换excel内容编辑得但是浏览器无法调用本地文件,如果只是展示,让后端返回文件得二进制文件流就可以了,直接使用luckysheet展示。 这里我们使用xlsx-populate得node简单应用来调用本地文件,自己写一个…

JavaScript学习20-Event事件对象

1.属性 即点击谁就打印出来谁 2.方法 未添加stopPropagatio方法: 添加stopPropagatio方法后:

FreeRTOS 启动过程中 SVC 和 PendSV 的工作流程​

在 FreeRTOS 的启动过程中,SVC(Supervisor Call) 和 PendSV(Pendable Service Call) 是两个关键的系统异常,分别用于 首次任务启动 和 任务上下文切换。它们的协作确保了从内核初始化到多任务调度的平滑过渡。以下是详细的工作流程分析(以 ARM Cortex-M 为例): 1. SVC…

[自制调试工具]构建高效调试利器:Debugger 类详解

一、引言 在软件开发的漫漫征程中,调试就像是一位忠诚的伙伴,时刻陪伴着开发者解决代码里的各类问题。为了能更清晰地了解程序运行时变量的状态,我们常常需要输出各种变量的值。而 Debugger 类就像是一个贴心的调试助手,它能帮我…

foobar2000 VU Meter Visualisation 插件汉化版 VU表

原英文插件点此 界面展示 下载 https://wwtn.lanzout.com/iheI22ssoybi 安装方式 解压安装文件,文件名为:foo_vis_vumeter-0.10.2_CHINIESE.fb2k-component

消息中间件对比与选型指南:Kafka、ActiveMQ、RabbitMQ与RocketMQ

目录 引言 消息中间件的定义与作用 消息中间件在分布式系统中的重要性 对比分析的四种主流消息中间件概述 消息中间件核心特性对比 消息传递模型 Kafka:专注于发布-订阅模型 ActiveMQ:支持点对点和发布-订阅两种模型 RabbitMQ:支持点…

liunx输入法

1安装fcitx5 sudo apt update sudo apt install fcitx fcitx-pinyin 2配置为默认输入法 设置-》系统-》区域和语言 点击系统弹出语言和支持选择键盘输入法系统 3设置设置 fcitx-configtool 如果没显示需要重启电脑 4配置fcitx 把搜狗输入法放到第一位(点击下面…

WindowsPE文件格式入门05.PE加载器LoadPE

https://bpsend.net/thread-316-1-1.html LoadPE - pe 加载器 壳的前身 如果想访问一个程序运行起来的内存,一种方法就是跨进程读写内存,但是跨进程读写内存需要来回调用api,不如直接访问地址来得方便,那么如果我们需要直接访问地址,该怎么做呢?.需要把dll注进程,注进去的代码…

QGIS中第三方POI坐标偏移的快速校正-百度POI

1.百度POI: name,lng,lat,address 龙记黄焖鸡米饭(共享区店),121.908315,30.886636,南汇新城镇沪城环路699弄117号(A1区110室) 好福记黄焖鸡(御桥路店),121.571409,31.162292,沪南路2419弄26号1层B间 御品黄焖鸡米饭(安亭店),121.160322,31.305977,安亭镇新源路792号…

SQL的调优方案

一、前言 SQL调优是提升数据库性能的关键手段。需结合索引优化、SQL语句优化、执行计划分析及数据库架构设计等多方面综合处理。 二、索引优化 创建合适索引 高频查询字段:对WHERE、JOIN、ORDER BY涉及的字段创建索引,尤其是区分度高的字段&#xff08…

【项目管理】第一部分 信息技术 1/2

相关文档,希望互相学习,共同进步 风123456789~-CSDN博客 概要 知识点: 现代化基础设施、数字经济、工业互联网、车联网、智能制造、智慧城市、数字政府、5G、常用数据库类型、数据仓库、信息安全、网络安全态势感知、物联网、大数…

【玩泰山派】1、mac上使用串口连接泰山派

文章目录 前言picocom工具连接泰山派安装picocom工具安装ch340的驱动串口工具接线使用picocom连接泰山派 参考 前言 windows上面有xshell这个好用的工具可以使用串口连接板子,在mac上好像没找到太好的工具,只能使用命令行工具去搞了。 之前查找说mac上…

【C++奇遇记】C++中的进阶知识(继承(一))

🎬 博客主页:博主链接 🎥 本文由 M malloc 原创,首发于 CSDN🙉 🎄 学习专栏推荐:LeetCode刷题集 数据库专栏 初阶数据结构 🏅 欢迎点赞 👍 收藏 ⭐留言 📝 如…

【Scratch编程系列】Scratch编程软件界面

Scratch是一款由麻省理工学院(MIT) 设计开发的少儿编程工具。其特点是:使用者可以不认识英文单词,也可以不使用键盘,就可以进行编程。构成程序的命令和参数通过积木形状的模块来实现。用鼠标拖动指令模块到脚本区就可以了。 这个软…

开篇 - 配置Unlua+VsCode的智能提示、调试以及学习方法

智能提示 为要绑定Lua的蓝图创建模板文件,这会在Content/Script下生成lua文件 然后点击生成智能代码提示,这会在Plugins/Unlua/Intermediate/生成Intenllisense文件夹 打开VSCode,点击文件->将工作区另存为。生成一个空工作区,放置在工程…

QEMU-KVM加SPICE,云电脑诞生了

没错!‌QEMU-KVM SPICE‌ 的组合,本质上就是一套‌轻量级云电脑(云桌面)‌的解决方案。通过虚拟化技术将计算资源池化,再通过SPICE协议提供流畅的远程桌面体验,用户用任意设备(笔记本/平板/瘦客…