PyTorch深度学习模型训练流程:(二、回归)

回归的流程与分类基本一致,只需要把评估指标改动一下就行。回归输出的是损失曲线、R^2曲线、训练集预测值与真实值折线图、测试集预测值散点图与真实值折线图。输出效果如下:

 注意:预测值与真实值图像处理为按真实值排序,图中呈现的升序与数据集趋势无关。

代码如下:

from functools import partial
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, r2_scoreimport torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdomfrom typing import Union, Optional
from sklearn.base import TransformerMixin
from torch.optim.optimizer import Optimizerdef regress(data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],model: nn.Module,optimizer: Optimizer,criterion: nn.Module,scaler: Optional[TransformerMixin] = None,batch_size: int = 64,epochs: int = 10,device: Optional[torch.device] = None
) -> nn.Module:"""回归任务的训练函数。:param data: 形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型:param model: 回归模型:param optimizer: 优化器:param criterion: 损失函数:param scaler: 数据标准化器:param batch_size: 批大小:param epochs: 训练轮数:param device: 训练设备:return: 训练好的回归模型"""if isinstance(data[0], np.ndarray):X, y = data# 分离训练集和测试集,指定随机种子以便复现X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化if scaler is not None:X_train = scaler.fit_transform(X_train)X_test = scaler.transform(X_test)# 转换为tensorX_train = torch.from_numpy(X_train.astype(np.float32))X_test = torch.from_numpy(X_test.astype(np.float32))y_train = torch.from_numpy(y_train.astype(np.float32))y_test = torch.from_numpy(y_test.astype(np.float32))# 将X和y封装成TensorDatasettrain_dataset = TensorDataset(X_train, y_train)test_dataset = TensorDataset(X_test, y_test)elif isinstance(data[0], Dataset):train_dataset, test_dataset = dataelse:raise ValueError('Unsupported data type')train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)model.to(device)vis = Visdom()# 训练模型for epoch in range(epochs):for step, (batch_x_train, batch_y_train) in enumerate(train_loader):batch_x_train = batch_x_train.to(device)batch_y_train = batch_y_train.to(device)# 前向传播output = model(batch_x_train)loss = criterion(output, batch_y_train)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()niter = epoch * len(train_loader) + step + 1  # 计算迭代次数if niter % 100 == 0:# 评估模型model.eval()with torch.no_grad():eval_dict = {'test_loss': [],'test_r2': [],'y_test': [],'y_pred': [],}for batch_x_test, batch_y_test in test_loader:batch_x_test = batch_x_test.to(device)batch_y_test = batch_y_test.to(device)test_output = model(batch_x_test)test_predicted_tuple = (batch_y_test.numpy(), test_output.numpy())# 计算并记录损失、R^2、真实值、预测值eval_dict['test_loss'].append(criterion(test_output, batch_y_test))eval_dict['test_r2'].append(r2_score(*test_predicted_tuple))eval_dict['y_test'].append(batch_y_test)eval_dict['y_pred'].append(test_output)# 画出损失曲线vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.stack((loss, torch.mean(torch.tensor(eval_dict['test_loss'])))).unsqueeze(0),win='loss',update='append',opts=dict(title='Loss', legend=['train_loss', 'test_loss']),)# 画出R^2曲线train_r2 = r2_score(batch_y_train.numpy(), output.numpy())vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.tensor((train_r2, np.mean(eval_dict['test_r2']))).unsqueeze(0),win='R^2',update='append',opts=dict(title='R^2', legend=['train_R^2', 'test_R^2'], ytickmin=0, ytickmax=1),)# 画出训练集预测值和真实值折线图sorted_train_idx = torch.argsort(batch_y_train)  # 按真实值排序vis.line(X=torch.arange(batch_size).repeat(2, 1).t(),Y=torch.stack((batch_y_train[sorted_train_idx], output[sorted_train_idx]), dim=1),win='batch_train_line',opts=dict(title='Predicted vs. Actual (Train Set)', legend=['Actual', 'Predicted']),)# 画出测试集预测值散点图和真实值折线图x = list(range(len(y_test)))y_test = torch.cat(eval_dict['y_test'])y_pred = torch.cat(eval_dict['y_pred'])sorted_test_idx = torch.argsort(y_test)vis._send({'data': [{'x': x, 'y': y_test[sorted_test_idx].tolist(), 'type': 'custom', 'mode': 'lines', 'name': 'Actual'},{'x': x, 'y': y_pred[sorted_test_idx].tolist(), 'type': 'custom', 'mode': 'markers', 'name': 'Predicted', 'marker': {'size': 3}}],'win': 'test_line','layout': {'title': 'Predicted vs. Actual (Test Set)'},})return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/878140.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JS】使用MessageChannel实现深度克隆

前言 通常使用简便快捷的JSON 序列化与反序列化实现深克隆,也可以递归实现或者直接使用lodash。 但 JSON 序列化与反序列化 无法处理如下的循环引用: 实现 MessageChannel 内部使用了浏览器内置的结构化克隆算法,该算法可以在不同的浏览器上…

Qt WebAssembly 警告:构建套件中未设置编译器

目录 Qt WebAssembly 警告:构建套件中未设置编译器问题解决方法 参考资料 Qt WebAssembly 警告:构建套件中未设置编译器 问题 安装好QT之后构建套件中出现黄色感叹号Qt WebAssembly 警告:构建套件中未设置编译器。 原因是现在你只安装了qt for webassembly的qt的库&#xff…

Task-Embedded Control Networks for Few-Shot Imitation Learning

发表时间:CoRL 2018 论文链接:https://readpaper.com/pdf-annotate/note?pdfId4500197057754718210&noteId2424798567891365120 作者单位:Imperial College London Motivation:就像人类一样,机器人应该能够利用来…

JVM上篇:内存与垃圾回收篇-07-方法区

笔记来源:尚硅谷 JVM 全套教程,百万播放,全网巅峰(宋红康详解 java 虚拟机) 文章目录 7. 方法区7.1. 栈、堆、方法区的交互关系7.2. 方法区的理解7.2.1. 方法区在哪里?7.2.2. 方法区的基本理解7.2.3. HotSp…

fastapi 下怎么正确使用 async和await?

fastapi操作异步和同步请求 声明:异步“请求” 和 异步“方法调用” 的区别【关键点】 1、同步、异步方法 同步阻塞1.1 仅同步请求的并发测试 1.2 仅异步请求的并发测试 1.3 同步请求 和 异步请求 的并发 2、异步方法阻塞的解决方案2.1 使用线程池执行同步阻塞2.2 …

无人机之基本结构篇

无人机(Unmanned Aerial Vehicle, UAV)作为一种无人驾驶的飞行器,其基本结构涵盖了多个关键组件,这些组件共同协作以实现无人机的自主飞行和执行各种任务。以下是无人机基本结构的详细解析: 一、飞机平台系统 机身&am…

C++中 inline 的含义是什么?

在C中,inline是一个关键字,它向编译器发出一个请求(注意,这是一个请求而不是命令),请求编译器尝试将函数的调用替换为函数体本身的代码。这样做的目的是减少函数调用的开销,特别是对于那些体积小…

vue2表单校验:添加自定义el-form表单校验规则

前言 在vue2表单校验:el-form表单绑定数组并使用rules进行校验_vue2 rules校验-CSDN博客中,使用form原生的rules对表单中每个控件的必填、格式等做了校验。但是保存时,除了验证每一个控件的输入合乎要求外,还需要验证控件之间的数…

SpringBoot集成kafka-生产者发送消息

springboot集成kafka发送消息 1、kafkaTemplate.send()方法1.1、springboot集成kafka发送消息Message对象消息1.2、springboot集成kafka发送ProducerRecord对象消息1.3、springboot集成kafka发送指定分区消息 2、kafkaTemplate.sendDefault()方法3、kafkaTemplate.send(...)和k…

WIN/MAC 图像处理软件Adobe Photoshop PS2024软件下载安装

目录 一、软件概述 1.1 基本信息 1.2 主要功能 二、系统要求 2.1 Windows 系统要求 2.2 macOS 系统要求 三、下载 四、使用教程 4.1 基本界面介绍 4.2 常用工具使用 4.3 进阶操作 一、软件概述 1.1 基本信息 Adobe Photoshop(简称PS)是一款…

springboot嵌入式数据库实践-H2内嵌数据库(文件、内存)

本文章记录笔者的嵌入式数据库简单实现, 记录简要的配置过程。自用文章,仅作参考。 目录 本文章记录笔者的嵌入式数据库简单实现, 记录简要的配置过程。自用文章,仅作参考。 嵌入式数据库 -------------------------------具…

前端手写源码系列(三)——手写_deepClone深浅拷贝

目录 一、基本类型和引用类型二、深浅拷贝概念三、浅拷贝实现方式1、Object.assign()2、Array.prototype.concat() 修改新对象会改到原对象3、Array.prototype.slice() 四、深拷贝实现方式1、JSON.parse(JSON.stringify())2、手写递归方法3、函数库lodash 五、手写深拷贝 一、基…

Linux系统(centos7)增加一个开机自启任务

任务背景 已经上线了一个java的springboot项目,使用start.sh脚本进行启动,脚本内容为: #!/bin/bashnohup java -jar /opt/javaProject/PracticeSpring-0.0.1-SNAPSHOT.jar > /opt/javaProject/run.log 2>&1 & 现在&#xff…

16岁激活交学费银行卡需要本人实名电话卡,线下营业厅不给办,怎么办?

16岁激活交学费银行卡需要本人实名电话卡,线下营业厅不给办,怎么办? 话卡办理规定: 根据《民法典》和《电话用户真实身份信息登记规定》的相关要求,未满16周岁的用户通常需要在监护人的陪同下办理电话卡,并…

uniapp微信小程序 分享功能

uniapp https://zh.uniapp.dcloud.io/api/plugins/share.html#onshareappmessage export default {onShareAppMessage(res) {if (res.from button) {// 来自页面内分享按钮console.log(res.target)}return {title: 自定义分享标题,path: /pages/test/test?id123}} }需要再真机…

IntelliJ IDEA智能代码补全​和集成AI助手说明及操作

IntelliJ IDEA的智能代码补全和集成AI助手是开发者提高编码效率和代码质量的重要工具。以下是对这些功能的详细说明及操作指南: 一、智能代码补全 1. 功能说明 IntelliJ IDEA的智能代码补全功能利用先进的算法和上下文分析,为开发者提供准确、快速的代…

程序猿必备技能-Bat脚本

Batch 脚本(批处理脚本)是在 Windows 操作系统中使用的一种脚本语言,用于自动化执行一系列命令。Batch 脚本是由 .bat 或 .cmd 文件扩展名标识的文本文件,这些文件可以被 Windows 的命令行解释器(如 cmd.exe&#xff0…

衡石科技BI的API如何授权文档解析

授权说明​ 授权模式​ 使用凭证式(client credentials)授权模式。 授权模式流程说明​ 第一步,A 应用在命令行向 B 发出请求。 第二步,B 网站验证通过以后,直接返回令牌。 授权模式结构说明​ 接口说明​ 获取a…

shell之getopts

getopts 是一个常用于解析命令行选项的bash内建命令。它的基本语法是: getopts optstring name [arg...]optstring列出了对应的Shell Script可以识别的所有参数。比如: 如果 Shell Script可以识别-a,-f以及-s参数,则optstring就是…

【贪心 决策包容性 】757. 设置交集大小至少为2

本文涉及知识点 贪心 决策包容性 LeetCode757. 设置交集大小至少为2 给你一个二维整数数组 intervals ,其中 intervals[i] [starti, endi] 表示从 starti 到 endi 的所有整数,包括 starti 和 endi 。 包含集合 是一个名为 nums 的数组,并…