informer 辅助笔记:main_informer.py

运行 informer的主文件

import argparse
import os
import torchfrom exp.exp_informer import Exp_Informer

1 参数

parser.add_argument的这些

参数名称参数描述
model实验模型。可以设置为informer、informerstack、informerlight(TBD)
data数据集名称
root_path数据文件的根路径(默认为./data/ETT/)
data_path数据文件名称(默认为ETTh1.csv)
features

预测任务(默认为M)。可以设置为M、S、MS

(M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量)

targetS或MS任务中的目标特征(默认为OT)
freq

时间特征编码的频率(默认为h)

可以设置为s(秒)、t(分钟)、h(小时)、d(日)、b(工作日)、w(周)、m(月)。也可以使用更详细的频率,如15min或3h

checkpoints模型检查点的位置(默认为./checkpoints/)
seq_lenInformer编码器的输入序列长度(默认为96)
label_lenInformer解码器的起始标记长度(默认为48)
pred_len预测序列长度(默认为24)
enc_in编码器输入大小(默认为7)
dec_in解码器输入大小(默认为7)
c_out输出大小(默认为7)
d_model模型的维度(默认为512)
n_heads头的数量(默认为8)
e_layers编码器层的数量(默认为2)
d_layers解码器层的数量(默认为1)
s_layers堆叠编码器层的数量(默认为3,2,1)
d_fffcn的维度(默认为2048)
factorProbsparse attn因子(默认为5)
padding填充类型(默认为0)
distil是否在编码器中使用提炼,使用此参数表示不使用提炼(默认为True)
dropout丢弃的概率(默认为0.05)
attn编码器中使用的注意力(默认为prob)。可以设置为prob(informer)、full(transformer)
embed时间特征的编码(默认为timeF)。可以设置为timeF、fixed、learned
activation激活函数(默认为gelu)
output_attention是否在编码器中输出注意力,使用此参数表示输出注意力(默认为False)
do_predict是否预测未见的未来数据,使用此参数表示进行预测(默认为False)
mix是否在生成解码器中使用混合注意力,使用此参数表示不使用混合注意力(默认为True)
cols数据文件中作为输入特征的某些列
num_workersData loader的工作数(默认为0)
itr实验次数(默认为2)
train_epochs训练周期(默认为6)
batch_size训练输入数据的批量大小(默认为32)
patience提前停止的耐心(默认为3)
learning_rate优化器学习率(默认为0.0001)
des实验描述(默认为test)
loss损失函数(默认为mse)
lradj调整学习率的方式(默认为type1)
use_amp是否使用自动混合精度训练,使用此参数表示使用amp(默认为False)
inverse是否反转输出数据,使用此参数表示反转输出数据(默认为False)
use_gpu是否使用gpu(默认为True)
gpu用于训练和推理的gpu编号(默认为0)
use_multi_gpu是否使用多个gpu,使用此参数表示使用多个gpu(默认为False)
devices多个gpu的设备ID(默认为0,1,2,3)

2 其他部分

2.1 GPU 相关

args = parser.parse_args()
#解析命令行参数args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
#检查是否可以使用 GPUif args.use_gpu and args.use_multi_gpu:args.devices = args.devices.replace(' ','')device_ids = args.devices.split(',')args.device_ids = [int(id_) for id_ in device_ids]args.gpu = args.device_ids[0]#如果启用了 GPU 且设置了多 GPU 使用,代码会解析 GPU 设备 ID,并准备相应的 GPU 设置。

2.2 数据集相关

data_parser = {'ETTh1':{'data':'ETTh1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},'ETTh2':{'data':'ETTh2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},'ETTm1':{'data':'ETTm1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},'ETTm2':{'data':'ETTm2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},'WTH':{'data':'WTH.csv','T':'WetBulbCelsius','M':[12,12,12],'S':[1,1,1],'MS':[12,12,1]},'ECL':{'data':'ECL.csv','T':'MT_320','M':[321,321,321],'S':[1,1,1],'MS':[321,321,1]},'Solar':{'data':'solar_AL.csv','T':'POWER_136','M':[137,137,137],'S':[1,1,1],'MS':[137,137,1]},
}
'''
这是一个数据解析器字典,包含不同数据集的配置信息如文件名,目标列,输入输出目标维度(M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量)
'''
if args.data in data_parser.keys():data_info = data_parser[args.data]args.data_path = data_info['data']args.target = data_info['T']args.enc_in, args.dec_in, args.c_out = data_info[args.features]
'''
检查输入的数据集是否在 data_parser 中定义,如果是,则从字典中获取相应的配置。数据路径、目标列、encoder输入、decoder输入、decoder输出的维度
'''

2.3 设置参数


args.s_layers = [int(s_l) for s_l in args.s_layers.replace(' ','').split(',')]
#解析并设置堆叠层的数量args.detail_freq = args.freq
#将频率详细信息保存在另一个参数中args.freq = args.freq[-1:]
#频率的最后一个元素(一般是h,s,m这些)print('Args in experiment:')
print(args)

2.4

Exp = Exp_Informerfor ii in range(args.itr):# 对于每次迭代,根据实验参数设置进行训练和测试。setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_fc{}_eb{}_dt{}_mx{}_{}_{}'.format(args.model, args.data, args.features, args.seq_len, args.label_len, args.pred_len,args.d_model, args.n_heads, args.e_layers, args.d_layers, args.d_ff, args.attn, args.factor, args.embed, args.distil, args.mix, args.des, ii)exp = Exp(args) # 使用给定参数实例化实验对象print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))exp.train(setting)print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))exp.test(setting)#分别对模型进行训练和测试if args.do_predict:print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))exp.predict(setting, True)#如果设置为进行预测,那么执行预测torch.cuda.empty_cache()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/184117.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JSP+servlet实现高校社团管理系统

JSPservlet实现的高校社团管理系统 &#xff0c;前后台都有&#xff0c;前台演示地址:高校社团管理系统 后台演示地址:登录 用户名:sys,密码:123456 前台功能&#xff1a;首页&#xff0c;社团列表&#xff0c;社团风采&#xff0c;社团活动&#xff0c;新闻列表&#xff0c…

阿里云新版公共实例从注册账号到创建设备生成参数教程

1 注册阿里云 打开阿里云官网&#xff0c;点击右上角的登录/注册 打开的界面按照图片输入手机号注册 注册成功后&#xff0c;登录返回第一次打开的界面&#xff0c;点击控制台 点击控制台后界面如下 点击左上角的菜单&#xff0c;弹出新窗口&#xff0c;搜索物联网平台 开通物…

在Django中使用Q对象和条件运算符来构建动态查询

示例代码&#xff0c;展示了如何根据cost_min和cost_max的值构建查询条件&#xff1a; from django.db.models import Q# 构建查询条件 query Q() # 创建一个空的Q对象# 添加单价范围查询条件 if cost_min is not None:query & Q(UnitCost__gtecost_min) # 添加大于等于…

springmvc(基础学习整合)

SpringMVC是Spring框架提供的构建Web应用程序的全功能MVC模块。 在SpringMVC的各个组件中&#xff0c;处理器映射器、处理器适配器、视图解析器称为SpringMVC的三大组件。 springMVC基本介绍&#xff1a; http://t.csdnimg.cn/TOzw9 MVC是一种设计思想&#xff0c;将一个应…

AcWing 3555:二叉树(多次询问两个结点之间的最短路径长度) ← DFS

【题目来源】https://www.acwing.com/problem/content/3558/【题目描述】 给定一个 n 个结点&#xff08;编号 1∼n&#xff09;构成的二叉树&#xff0c;其根结点为 1 号点。 进行 m 次询问&#xff0c;每次询问两个结点之间的最短路径长度。 树中所有边长均为 1。【输入格式】…

华为云cce容器管理中的调度策略作用

研究不深&#xff0c;但是这个还是挺重要的&#xff0c;在这里记录一下。 在cce节点集群中&#xff0c;有时候会发现有的节点实例过于饱满&#xff0c;有的又有些空&#xff0c;导致部分节点由于压力过大&#xff0c;存在崩溃的危险&#xff0c;这时候调度策略就有用了。 我这…

图扑参展高交会-全球清洁能源创新博览会

“相聚鹏城深圳&#xff0c;共享能源盛宴” 第二十五届中国国际高新技术成果交易会(简称“高交会”)于 11 月 15-18 日在深圳盛大开幕。高交会由商务部、科学技术部、工业和信息化部、国家发展改革委、农业农村部、国家知识产权局、中国科学院、中国工程院和深圳市人民政府共同…

nvm for windows使用与node/npm/yarn的配置

1 下载 nvm for windows download – github 下拉到Assets, 下载.exe文件 2 安装 安装到如下文件夹中 目录可以自己选, 可以换别的名字, 自己记住即可 新手建议全部看完再进行个人配置, 或者使用与博主一致的路径 D:\DevelopEnvironment\nvm3 配置nvm使用的镜像 node_mir…

值得收藏的免费好用API

短信验证码&#xff1a;可用于登录、注册、找回密码、支付认证等等应用场景。支持三大运营商&#xff0c;3秒可达&#xff0c;99.99&#xff05;到达率&#xff0c;支持大容量高并发。通知短信&#xff1a;当您需要快速通知用户时&#xff0c;通知短信是最快捷有效的方式。短信…

Carbonyl ,一个可以在终端里运行的浏览器

浏览器对于我们的日常来说是使用频率比较高的一个东西。 一般来说&#xff0c;对于桌面的发行版的linux的浏览器&#xff0c;用的比较多的是Firefox浏览器。对于我们日常windows、mac等。常用的有chrome、edge等。 但是&#xff01;在终端里运行一个浏览器&#xff0c;我想大多…

SaaS模式C/S检验科LIS系统源码

适用于医院检验科实际需要的管理系统, 实现检验业务全流程的计算机管理。从检验申请、标本编号、联机采集、中文报告单的生成与打印、质控图的绘制和数据的检索与备份。通过将所有仪器自身提供的端口与科室LIS系统中的工作站点连接,实现与医院HIS系统的对接。 通过门诊医生和住…

HTML-CSS知识速查

HTML/CSS知识速查 文章目录 HTML/CSS知识速查[toc]网页的组成浏览器**为什么需要Web标准&#xff1a;** **web标准的构成&#xff1a;**HTMLHTML语法导读**1.1 HTML语法规则&#xff1a;**1.2 基本结构标签**1.3 标签的关系&#xff1a;**1. **包含关系&#xff08;Parent-Chil…

java并发编程(一)----初识

一、什么是并发 先看“科普中国”给出的官方解释。并发在操作系统中&#xff0c;是指一个时间段中有几个程序都处于已启动运行到运行完毕之间&#xff0c;且这几个程序都是在同一个处理机上运行&#xff0c;但任一个时刻点上只有一个程序在处理机上运行。 通俗来讲&#xff0c…

岩土工程监测新利器——振弦采集仪

岩土工程监测新利器——振弦采集仪 振弦采集仪是一种常用的岩土工程监测仪器&#xff0c;主要用于测量岩土体的振动和应变情况。它采用先进的数字信号处理技术&#xff0c;可以实时采集和处理振弦信号&#xff0c;快速准确地获取岩土体的振动和应变信息。 振弦采集仪具有以下优…

数据结构---树

树概念及结构 1.树的概念 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个具有层次关系的集合。把它叫做树是因 为它看起来像一棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的 有一个特殊的结点&#xff0c…

在很多nlp数据集上超越tinybert 的新架构nlp神经网络模型

在很多nlp数据集上超越tinybert 的新架构nlp神经网络模型 网络结构图测试代码网络结构图 测试代码 import paddle import numpy as np import pandas as pd from tqdm import tqdmclass FeedFroward(paddle.nn.Layer):

java多线程CountDownLatch简单测试

学习java多线程&#xff0c;请同时参阅 Java多线程 信号量和屏障实现控制并发线程数量&#xff0c;主线程等待所有线程执行完毕1 CountDownLatch能够使一个线程在等待另外一些线程完成各自工作之后再继续执行。当所有的线程都已经完成任务&#xff0c;然后在CountDownLatch上…

git报错invalid object xxx和unable to read tree xxxxxx

电脑出问题了&#xff0c;导致git仓库像是被损坏了一样&#xff0c;执行git status就会报错unable to read ree&#xff0c;无法正常提交代码至仓库&#xff0c;原因是本地代码仓库.git文件损坏了&#xff0c;无法找到正确的提交历史和路径。 找到了一个解决办法&#xff1a; …

TCP 基本认识

1&#xff1a;TCP 头格式有哪些&#xff1f; 序列号&#xff1a;用来解决网络包乱序问题。 确认应答号&#xff1a;用来解决丢包的问题。 2&#xff1a;为什么需要 TCP 协议&#xff1f; TCP 工作在哪一层&#xff1f; IP 层是「不可靠」的&#xff0c;它不保证网络包的交付…

python .onnx 转 .engine亲测ok

安装tensorRT&#xff1a; 1、下载与电脑中cuda和cudnn版本对应的tensorRT&#xff08;比如我的是TensorRT-8.2.1.8.Windows10.x86_64.cuda-11.4.cudnn8.2&#xff09; 2、打开目录里面有python文件夹&#xff0c;找到对应python版本的whl文件&#xff08;我的是tensorrt-8.2…