informer辅助笔记:exp/exp_informer.py

0 导入库

from data.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred
from exp.exp_basic import Exp_Basic
from models.model import Informer, InformerStackfrom utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metricimport numpy as npimport torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoaderimport os
import timeimport warnings
warnings.filterwarnings('ignore')

1 Exp_Informer

class Exp_Informer(Exp_Basic):def __init__(self, args):super(Exp_Informer, self).__init__(args)

1.1 build_model

'''
用于构建模型。它根据提供的参数来实例化特定类型的模型
'''
def _build_model(self):model_dict = {'informer':Informer,'informerstack':InformerStack,}if self.args.model=='informer' or self.args.model=='informerstack':e_layers = self.args.e_layers if self.args.model=='informer' else self.args.s_layersmodel = model_dict[self.args.model](self.args.enc_in,self.args.dec_in, self.args.c_out, self.args.seq_len, self.args.label_len,self.args.pred_len, self.args.factor,self.args.d_model, self.args.n_heads, e_layers, # self.args.e_layers,self.args.d_layers, self.args.d_ff,self.args.dropout, self.args.attn,self.args.embed,self.args.freq,self.args.activation,self.args.output_attention,self.args.distil,self.args.mix,self.device).float()#用提供的参数实例化模型if self.args.use_multi_gpu and self.args.use_gpu:model = nn.DataParallel(model, device_ids=self.args.device_ids)#如果设置为使用多 GPU,那么模型将被包装在 nn.DataParallel 中,以便在多个 GPU 上并行运行。return model

1.2 get_data

'''
根据指定的模式(如训练、测试或预测)获取数据
'''
def _get_data(self, flag):args = self.argsdata_dict = {'ETTh1':Dataset_ETT_hour,'ETTh2':Dataset_ETT_hour,'ETTm1':Dataset_ETT_minute,'ETTm2':Dataset_ETT_minute,'WTH':Dataset_Custom,'ECL':Dataset_Custom,'Solar':Dataset_Custom,'custom':Dataset_Custom,}'''定义了一个字典,映射不同的数据集名称到相应的数据集类。例如,'ETTh1' 和 'ETTh2' 映射到 Dataset_ETT_hour 类。'''Data = data_dict[self.args.data]#根据参数中指定的数据集名称选择相应的数据集类timeenc = 0 if args.embed!='timeF' else 1    #设置时间编码标志。如果嵌入类型不是 'timeF',则 timeenc 设置为 0,否则设置为 1。if flag == 'test':shuffle_flag = False; drop_last = True; batch_size = args.batch_size; freq=args.freqelif flag=='pred':shuffle_flag = False; drop_last = False; batch_size = 1; freq=args.detail_freqData = Dataset_Predelse:shuffle_flag = True; drop_last = True; batch_size = args.batch_size; freq=args.freq'''根据 flag 参数(指示数据集用途,如 'test', 'pred', 或其他)设置不同的参数:shuffle_flag:是否打乱数据。drop_last:在数据批次不足时是否丢弃最后一批数据。batch_size:每批数据的大小。freq:数据频率,用于确定数据处理的时间间隔。'''data_set = Data(root_path=args.root_path,data_path=args.data_path,flag=flag,size=[args.seq_len, args.label_len, args.pred_len],features=args.features,target=args.target,inverse=args.inverse,timeenc=timeenc,freq=freq,cols=args.cols)'''使用指定参数实例化数据集。这里包括了数据路径标志(如 'train', 'test')序列长度、标签长度、预测长度特征类型 (M,S,MS)目标列时间编码标志频率需要使用的列'''print(flag, len(data_set))data_loader = DataLoader(data_set,batch_size=batch_size,shuffle=shuffle_flag,num_workers=args.num_workers,drop_last=drop_last)'''使用 DataLoader 创建一个数据加载器,用于批量加载数据同时指定是否打乱、是否丢弃最后一个批次、使用的工作进程数量等。'''return data_set, data_loader#返回数据集和数据加载器的实例

1.3 optimizer & criterion

def _select_optimizer(self):model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)return model_optimdef _select_criterion(self):criterion =  nn.MSELoss()return criterion#选择优化器和损失函数

1.4 vali

'''
在验证集上评估模型
'''
def vali(self, vali_data, vali_loader, criterion):self.model.eval() #将模型设置为评估模式total_loss = []for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(vali_loader):#遍历验证数据加载器中的每个批次pred, true = self._process_one_batch(vali_data, batch_x, batch_y, batch_x_mark, batch_y_mark)#调用 _process_one_batch 方法处理一个批次的数据。这个方法会返回预测值(pred)和真实值(true)loss = criterion(pred.detach().cpu(), true.detach().cpu())#计算预测值和真实值之间的损失total_loss.append(loss)#将计算出的损失添加到 total_loss 列表中total_loss = np.average(total_loss)#计算所有批次损失的平均值。这个平均损失表示在验证数据集上模型的整体性能。self.model.train()#将模型重新设置为训练模式,继续训练模型return total_loss#返回计算出的平均损失值

1.5 train

'''
训练模型
'''
def train(self, setting):train_data, train_loader = self._get_data(flag = 'train')vali_data, vali_loader = self._get_data(flag = 'val')test_data, test_loader = self._get_data(flag = 'test')#使用 _get_data 方法加载训练、验证和测试数据集。path = os.path.join(self.args.checkpoints, setting)if not os.path.exists(path):os.makedirs(path)#创建用于保存模型检查点的目录time_now = time.time()train_steps = len(train_loader)early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)#使用EarlyStopping  检查是否应停止训练model_optim = self._select_optimizer()criterion =  self._select_criterion()if self.args.use_amp:scaler = torch.cuda.amp.GradScaler()'''初始化一些变量:train_steps:训练数据加载器中的批次总数。early_stopping:如果验证损失在一定迭代次数后没有改善,则停止训练。model_optim:选择优化器。criterion:选择损失函数。如果启用了自动混合精度(AMP),则初始化 scaler。'''for epoch in range(self.args.train_epochs):iter_count = 0train_loss = []self.model.train()epoch_time = time.time()for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(train_loader):#遍历训练数据加载器中的所有批次iter_count += 1model_optim.zero_grad() #清除模型优化器的梯度pred, true = self._process_one_batch(train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)#使用 _process_one_batch 处理批次数据,计算损失loss = criterion(pred, true)#计算这一个batch预测值和实际值的差距train_loss.append(loss.item())if (i+1) % 100==0:print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))speed = (time.time()-time_now)/iter_countleft_time = speed*((self.args.train_epochs - epoch)*train_steps - i)print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))iter_count = 0time_now = time.time()#每100次迭代打印损失和预计剩余时间if self.args.use_amp:scaler.scale(loss).backward()scaler.step(model_optim)scaler.update()else:loss.backward()model_optim.step()#损失后向传播和优化器步骤,如果启用了 AMP,则使用 scaler 进行这些步骤print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time))train_loss = np.average(train_loss)vali_loss = self.vali(vali_data, vali_loader, criterion)#对模型进行validationtest_loss = self.vali(test_data, test_loader, criterion)print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(epoch + 1, train_steps, train_loss, vali_loss, test_loss))early_stopping(vali_loss, self.model, path)if early_stopping.early_stop:print("Early stopping")breakadjust_learning_rate(model_optim, epoch+1, self.args)best_model_path = path+'/'+'checkpoint.pth'self.model.load_state_dict(torch.load(best_model_path))#在训练结束后,加载表现最好的模型状态return self.model

1.6 test

'''
在测试集上评估模型
'''
def test(self, setting):test_data, test_loader = self._get_data(flag='test')#加载测试数据集self.model.eval()preds = []trues = []#存储模型的预测和相应的真实值for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(test_loader):pred, true = self._process_one_batch(test_data, batch_x, batch_y, batch_x_mark, batch_y_mark)preds.append(pred.detach().cpu().numpy())trues.append(true.detach().cpu().numpy())'''遍历测试数据加载器中的每个批次。使用 _process_one_batch 方法处理每个批次的数据。将预测值和真实值添加到各自的列表中。'''preds = np.array(preds)trues = np.array(trues)print('test shape:', preds.shape, trues.shape)preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])print('test shape:', preds.shape, trues.shape)# result savefolder_path = './results/' + setting +'/'if not os.path.exists(folder_path):os.makedirs(folder_path)#创建一个文件夹来存储测试结果mae, mse, rmse, mape, mspe = metric(preds, trues)#使用自定义的 metric 函数计算各种性能指标,如 MAE(平均绝对误差)、MSE(均方误差)、RMSE(均方根误差)、MAPE(平均绝对百分比误差)和 MSPE(均方百分比误差)。print('mse:{}, mae:{}'.format(mse, mae))np.save(folder_path+'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))np.save(folder_path+'pred.npy', preds)np.save(folder_path+'true.npy', trues)return

1.7 predict

#在新数据上进行模型预测
def predict(self, setting, load=False):pred_data, pred_loader = self._get_data(flag='pred')#加载预测数据集if load:path = os.path.join(self.args.checkpoints, setting)best_model_path = path+'/'+'checkpoint.pth'self.model.load_state_dict(torch.load(best_model_path))#如果 load 为 True,则从保存的路径加载最佳模型的状态。self.model.eval()preds = []for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(pred_loader):pred, true = self._process_one_batch(pred_data, batch_x, batch_y, batch_x_mark, batch_y_mark)preds.append(pred.detach().cpu().numpy())'''遍历预测数据加载器中的每个批次。使用 _process_one_batch 方法处理每个批次的数据。将预测值添加到 preds 列表中。'''preds = np.array(preds)preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])# result savefolder_path = './results/' + setting +'/'if not os.path.exists(folder_path):os.makedirs(folder_path)np.save(folder_path+'real_prediction.npy', preds)#保存预测结果return

1.8 process_one_batch

'''
处理一个数据批次
'''
def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):batch_x = batch_x.float().to(self.device)batch_y = batch_y.float()batch_x_mark = batch_x_mark.float().to(self.device)batch_y_mark = batch_y_mark.float().to(self.device)# decoder inputif self.args.padding==0:dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()elif self.args.padding==1:dec_inp = torch.ones([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()#根据 self.args.padding 的值创建一个全零或全一的张量作为解码器的初始输入dec_inp = torch.cat([batch_y[:,:self.args.label_len,:], dec_inp], dim=1).float().to(self.device)#将这个张量与 batch_y 的一部分拼接,形成完整的解码器输入# encoder - decoderif self.args.use_amp:with torch.cuda.amp.autocast():if self.args.output_attention:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]else:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)else:if self.args.output_attention:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]else:outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)if self.args.inverse:outputs = dataset_object.inverse_transform(outputs)#encoder-decoder的输出f_dim = -1 if self.args.features=='MS' else 0batch_y = batch_y[:,-self.args.pred_len:,f_dim:].to(self.device)#从 batch_y 中选择与预测长度相对应的部分,并移动到指定设备。#f_dim 变量用于确定特征维度。return outputs, batch_y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/187182.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

huggingface - pipeline - translate 记录

文章目录 #!/usr/bin/env python # -*- encoding: utf-8 -*-import os ,sys from transformers import pipelinemodel_checkpoint "Helsinki-NLP/opus-mt-zh-en" translator pipeline("translation", modelmodel_checkpoint)def translate_arr(arr): r…

Java 数据结构篇-用链表、数组实现栈

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 栈的说明 2.0 用链表来实现栈 2.1 实现栈 - 入栈方法(push) 2.2 实现栈 - 出栈(pop) 2.3 实现栈 - 查看栈顶元素…

mybatis 实现批量更新的三种方式

注&#xff1a;Mybatis实现批量更新有三种方式&#xff0c;分别是使用foreach标签、使用SQL的case when语句和使用动态SQL的choose语句。具体实现方法如下&#xff1a; 1&#xff1a;使用foreach标签 <update id"batchUpdate" parameterType"java.util.Lis…

C 标准库 <errno.h>与 <float.h>

C 标准库 <errno.h> C 标准库的 errno.h 头文件定义了整数变量 errno&#xff0c;它是通过系统调用设置的&#xff0c;在错误事件中的某些库函数表明了什么发生了错误。该宏扩展为类型为 int 的可更改的左值&#xff0c;因此它可以被一个程序读取和修改。 在程序启动时…

国产linux单用户模式破解无密码登陆 (麒麟系统用户登录密码遗忘解决办法)

笔者手里有一批国产linu系统&#xff0c;目前开始用在日常的工作生产环境中&#xff0c;我这个老程序猿勉为其难的充当运维的或网管的角色。 国产linux系统常见的为麒麟Linux&#xff0c;统信UOS等&#xff0c;基本都是基于debian再开发的linux。 问题描述&#xff1a; 因为…

基于AT89C51单片机的倒数计时器设计

1&#xff0e;设计任务 利用AT89C51单片机为核心控制元件,设计一个简易的数字电压表&#xff0c;设计的系统实用性强、操作简单&#xff0c;实现了智能化、数字化。 本设计采用单片机为主控芯片&#xff0c;结合周边电路组成LED彩灯的闪烁控制系统器&#xff0c;用来控制红色…

用于缓存一些固定名称的小组件

项目中&#xff0c;用于缓存姓名、地名、单位名称等一些较固定名称的id-name小组件。用于减少一些表的关连操作和冗余字段。优化代码结构。扩展也方便&#xff0c;写不同的枚举就行了。 具体用法&#xff1a; {NameCacheUser.USER.getName(userId);NameCacheUser.ACCOUNT.getN…

excel合并单元格教程

在表格里&#xff0c;总是会遇到一级表格、二级表格的区别&#xff0c;这时候一级表格会需要合并成一个大格子&#xff0c;那么excel如何合并单元格呢&#xff0c;其实使用快捷键或者功能键就可以了。 excel如何合并单元格&#xff1a; 1、首先我们用鼠标选中所有要合并的单元…

最大公约数的C语言实现xdoj31

时间限制: 1 S 内存限制: 1000 Kb 问题描述: 最大公约数&#xff08;GCD&#xff09;指某几个整数共有因子中最大的一个&#xff0c;最大公约数具有如下性质&#xff0c; gcd(a,0)a gcd(a,1)1 因此当两个数中有一个为0时&#xff0c;gcd是不为0的那个整数&#xff…

Redis编码类型及对应含义

对象类型编码类型(encoding)取值范围Stringintlong长度范围内的数字embstr长度小于40的value值。数字和字符。raw长度大于40的value值Listziplist所有元素长度小于64字节&#xff0c;并且列表元素的个数小于512个linkedlist不满足ziplist的数据Setintset纯数字&#xff0c;列表…

分治法之查找最大值

思路: 定义一个递归函数 findMax&#xff0c;它接受三个参数&#xff1a;数组 arr、起始位置 start 和结束位置 end。如果 start 等于 end&#xff0c;说明数组中只有一个元素&#xff0c;直接返回该元素的值作为最大值。否则&#xff0c;计算数组的中间位置 mid&#xff0c;可…

XXL-Job详解(一):组件架构

目录 XXL-Job特性系统组成架构图调度模块剖析任务 “运行模式” 剖析执行器 XXL-Job XXL-JOB是一个分布式任务调度平台&#xff0c;其核心设计目标是开发迅速、学习简单、轻量级、易扩展。现已开放源代码并接入多家公司线上产品线&#xff0c;开箱即用。 特性 1、简单&#…

java+springboot实验室管理系统的设计与实现ssm+jsp

课题研究内容&#xff1a; &#xff08;1&#xff09; 系统需求分析&#xff08;构成模块&#xff0c;系统流程&#xff0c;功能结构图&#xff0c;系统需求&#xff09; &#xff08;2&#xff09; 实验室课程安排功能模块&#xff08;课程的录入和调补&#xff09; &#xff…

prompt提示

用例生成 # 任务描述 作为一个高级c程序员&#xff0c;需要完成下列功能的gtest测试用例 # 功能描述 给定两个数字型字符串s1和s2,求和&#xff0c;返回值也是字符串 # 接口举例 调用strAdd("123", "132"),输出“255” # 输出要求 - 入参为空串、nu…

wyler水平仪维修WYLER倾角仪维修CH-8405

瑞士WYLER电子水平仪维修&#xff1b;BIueCLINO倾斜度测量仪维修&#xff1b;wyler电子倾角仪维修。 水平仪常见故障及处理方法 1、 仪表通电不工作。 A、检查仪表220V电源端子接线是否正确 B、检查仪表电容是否熔断&#xff1b; C、拧下仪表后的固定螺钉&#xff0c;将表…

王道数据结构课后代码题p40 4.在带头结点的单链表L中删除一个最小值结点的高效算法(假设最小值唯一) (c语言代码实现)

本题代码为 void deletemin(linklist* L)//找到最小值并删除 {lnode* p (*L)->next, * pre *L;lnode* s p,*sprepre;while (p ! NULL)//找到最小值{if (p->data < s->data){s p;spre pre;}p p->next;pre pre->next;}p s->next;spre->next p;…

有IP没有域名可以申请证书吗?

一、IP证书是什么&#xff1f; ip证书是用于公网ip地址的SSL证书&#xff0c;与我们通常所讲的SSL证书并无本质上的区别&#xff0c;但由于SSL证书通常颁发给域名&#xff0c;而组织机构需要公共ip地址的SSL证书&#xff0c;这类SSL证书就是我们所说的ip证书。ip证书具有安全、…

仅仅通过提示词,GPT-4可以被引导成为多个领域的特定专家

The Power of Prompting&#xff1a;提示的力量&#xff0c;仅通过提示&#xff0c;GPT-4可以被引导成为多个领域的特定专家。微软研究院发布了一项研究&#xff0c;展示了在仅使用提策略的情况下让GPT 4在医学基准测试中表现得像一个专家。研究显示&#xff0c;GPT-4在相同的基…

查看代码运行时间

#include<bits/stdc.h> signed main() {clock_t start_time clock();std::cout<<"hello\n";double tot_time double(clock() - start_time) / 1000; std::cout << "\n代码跑了" << tot_time << "秒";return 0; …

上海毅速丨新材料将推动3D打印在压铸行业的应用

压铸是一种应用广泛的制造工艺&#xff0c;它的制造原理是将液态或半液态金属&#xff0c;在高压作用下&#xff0c;以高速度填充压铸模具型腔&#xff0c;并在压力下快速凝固而获得铸件的一种方法。压铸模的设计和制造需要考虑到多方面的因素&#xff0c;如模具材料、结构、冷…