时间序列预测实战(十二)DLinear模型实现滚动长期预测并可视化预测结果

官方论文地址->官方论文地址

官方代码地址->官方代码地址

个人修改代码->个人修改的代码已经上传CSDN免费下载

一、本文介绍

本文给大家带来是DLinear模型,DLinear是一种用于时间序列预测(TSF)的简单架构,DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测(值得一提的是DLinear的出现是为了挑战Transformer在实现序列预测中有效性)本文的讲解内容包括:模型原理、数据集介绍、参数讲解、模型训练和预测、结果可视化、训练个人数据集,讲解顺序如下->

预测类型->单元预测、多元预测

适用对象->如果你的配置不是很好这个模型应该很适合你因为参数量很小训练速度很快

二、模型原理

DLinear模型出现是为了调整Transformer的有效性从而存在,Transformer的设计都十分的复杂和需要大量的参数,所以作者提出了一种简单的结构DLinear(参数量我实验过程中确实非常小)

DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测。

具体地,DLinear如何工作的关键点如下

  1. 时间序列分解:DLinear将输入的时间序列分解为两部分——趋势部分和剩余部分。这种分解有助于分别处理时间序列中的长期趋势和短期波动。

  2. 单层线性网络:对于趋势和剩余序列,DLinear分别使用两个单层的线性网络进行建模。这种简单的架构使得DLinear在处理时间序列时既高效又有效。

  3. 预测任务:在进行预测时,DLinear结合这两个网络的输出来生成最终的时间序列预测。

总结->可以看出DLinear的核心结构真的十分简单就包括一个分解和两个线性网络进行建模最后经过一个简单的相加就输出了结果。

模型的网络结构图如下所示->

图片分析->可以看到和我们上面讲的一样,数据从输入进来经过两个分支,一个为趋势性一个为剩余序列,然后分别经过一个线性层处理(这里的提到的线性层就是普通的全连接层),然后将结果进行简单的拼接就完成了结果的输出(这就这样的简单模型结果比过程十分复杂的Transformer模型效果要好->我自己实验效果确实要好,我拿2020年的bestpaper和普通的Transformer都进行了对比效果确实要有提升)。

下面的图片是一个简单的线性层(普通的全连接层)提取数据的过程图->

这里把模型的代码结构放出来方便大家根据讲解和代码进行对比。

class moving_avg(nn.Module):"""Moving average block to highlight the trend of time series"""def __init__(self, kernel_size, stride):super(moving_avg, self).__init__()self.kernel_size = kernel_sizeself.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)def forward(self, x):# padding on the both ends of time seriesfront = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)x = torch.cat([front, x, end], dim=1)x = self.avg(x.permute(0, 2, 1))x = x.permute(0, 2, 1)return xclass series_decomp(nn.Module):"""Series decomposition block"""def __init__(self, kernel_size):super(series_decomp, self).__init__()self.moving_avg = moving_avg(kernel_size, stride=1)def forward(self, x):moving_mean = self.moving_avg(x)res = x - moving_meanreturn res, moving_meanclass Model(nn.Module):"""Decomposition-Linear"""def __init__(self, configs):super(Model, self).__init__()self.seq_len = configs.seq_lenself.pred_len = configs.pred_len# Decompsition Kernel Sizekernel_size = 25self.decompsition = series_decomp(kernel_size)self.individual = configs.individualself.channels = configs.enc_inif self.individual:self.Linear_Seasonal = nn.ModuleList()self.Linear_Trend = nn.ModuleList()for i in range(self.channels):self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))# Use this two lines if you want to visualize the weights# self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))# self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))else:self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)# Use this two lines if you want to visualize the weights# self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))# self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))def forward(self, x):# x: [Batch, Input length, Channel]seasonal_init, trend_init = self.decompsition(x)seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)if self.individual:seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)for i in range(self.channels):seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])else:seasonal_output = self.Linear_Seasonal(seasonal_init)trend_output = self.Linear_Trend(trend_init)x = seasonal_output + trend_outputreturn x.permute(0,2,1) # to [Batch, Output length, Channel]

我看论文的内容大比分都是对比实验,因为DLinear的产生就是为了质疑Transformer所以他和各种Transformer的模型进行对比试验,因为本篇文章就是DLinear的实战案例,对比的部分我就不讲了,大家有兴趣可以看看论文内容在最上面我已经提供了链接。 

三、数据集介绍

所用到的数据集为某公司的业务水平评估和其它参数具体的内容我就介绍了估计大家都是想用自己的数据进行训练模型,这里展示部分图片给大家提供参考。

四、参数讲解

模型的参数如下(大部分都是一些公共参数并不涉及模型)->

parser = argparse.ArgumentParser(description='DLinearNet Multivariate Time Series Forecasting')# basic configparser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')parser.add_argument('--show_results', type=bool, default=True, help='Whether show forecast and real results graph')parser.add_argument('--model', type=str, default='SCINet',help='Model name')# data loaderparser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')parser.add_argument('--features', type=str, default='MS',help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')parser.add_argument('--freq', type=str, default='h',help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')# forecasting taskparser.add_argument('--seq_len', type=int, default=126, help='input sequence length')parser.add_argument('--label_len', type=int, default=64, help='start token length')parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')# modelparser.add_argument('--individual', action='store_true', default=False,help='DLinear: a linear layer for each variate(channel) individually')parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')parser.add_argument('--c_out', type=int, default=1, help='output size')parser.add_argument('--dropout', type=float, default=0.05, help='dropout')parser.add_argument('--embed', type=str, default='timeF',help='time features encoding, options:[timeF, fixed, learned]')parser.add_argument('--activation', type=str, default='gelu', help='activation')# optimizationparser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')parser.add_argument('--loss', type=str, default='mse', help='loss function')parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')# GPUparser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')parser.add_argument('--device', type=int, default=0, help='gpu')

模型的详细参数讲解如下(如果你想训练你自己的数据集可以仔细看看)->

参数名称参数类型参数讲解
0trainbool是否进行训练,如果你单纯只想进行预测设置为False即可,
1rollingforecastbool是否进行滚动预测,如果是则设置为True,如果不进行滚动预测则进行正常的预测
2rolling-data-pathstr如果进行滚动预测则需要添加新的和训练文件相同格式的数据
3show_resultsbool是否保存预测值和真实值的对比
4modelstr定义的模型名称
5root_pathstr这个才是你文件的路径,不要到具体的文件,到目录级别即可。
6data_pathstr这个填写你文件的具体名称。
7featuresstr这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。
8targetstr这个是你数据集中你想要预测那一列数据,假设我预测的是油温OT列就输入OT即可。
9freqstr时间的间隔,你数据集每一条数据之间的时间间隔。
10checkpointsstr训练出来的模型保存路径
11seq_lenint用过去的多少条数据来预测未来的数据
12label_lenint可以理解为更高的权重占比的部分要小于seq_len
13pred_lenint预测未来多少个时间点的数据
14enc_inint你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
15dec_inint同上
16individualbool这个就是我们上面提到的两个线性层,如果为True我们则对每一个通道用单独的线性层处理,False则为所有的通道用一个线性层
17c_outint这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
18dropoutfloat这个应该都理解不说了,丢弃的概率,防止过拟合的。
19embedstr时间特征的编码方式,默认为"timeF"
20activationstr激活函数
21num_workersint线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。
22train_epochsint训练的次数
23batch_sizeint一次往模型力输入多少条数据
24learning_ratefloat学习率。
25lossstr     损失函数,默认为"mse"
26lradjstr     学习率的调整方式,默认为"type1"
27use_gpubool是否使用GPU训练,根据自身来选择
28gpuintGPU的编号

五、模型训练和预测

1.项目目录结构

项目的目录构造如下->

其中data为训练用的数据放的地方,layers为模型结构存放的地方,models为训练保存的训练模型,results为可视化结果保存的图片和滚动预测的结果,util为一些工具。 

2.模型训练

当我们经过上面的参数讲解之后,我们可以开始训练模型了,控制台输出如下->

3.滚动预测 

这里进行滚动预测的控制台输出->

4.结果展示 

运行结果后,结果保存到同级目录下(下图为预测值和真实值的对比)-> 

5.结果分析 

可以看到预测值和真实值之间的差距还可以,但是这个模型的参数量少得可怜,不得不得质疑Transformer模型的有效性~

六、训练你个人数据集

这个模型我在写的过程中为了节省大家训练自己数据集,我基本上把大部分的参数都写好了,需要大家注意的就是如果要进行滚动预测下面的参数要设置为True。

    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')

如果上面的参数设置为True那么下面就要提供一个进行滚动预测的数据集该数据集的格式要和你训练模型的数据集格式完全一致(重要!!!),如果没有可以考虑在自己数据的尾部剪切一部分,不要粘贴否则数据模型已经训练过了的话预测就没有效果了。 

    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')

其它的没什么可以讲的了大部分的修改操作在参数讲解的部分我都详细讲过了,这里的滚动预测可能是大家想看的所以摘出来详细讲讲。 

总结

到此本文已经全部讲解完成了,希望能够帮助到大家,在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98,并且免费阅读。

概念理解 

15种时间序列预测方法总结(包含多种方法代码实现)

数据分析

时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法

机器学习——难度等级(⭐⭐)

时间序列预测实战(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)

深度学习——难度等级(⭐⭐⭐⭐)

时间序列预测实战(五)基于Bi-LSTM横向搭配LSTM进行回归问题解决

时间序列预测实战(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

时间序列预测实战(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

时间序列预测实战(十一)用SCINet实现滚动预测功能(附代码+数据集+原理介绍)

Transformer——难度等级(⭐⭐⭐⭐)

时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解

时间序列预测模型实战案例(一)深度学习华为MTS-Mixers模型

个人创新模型——难度等级(⭐⭐⭐⭐⭐)

时间序列预测实战(十)(CNN-GRU-LSTM)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

传统的时间序列预测模型(⭐⭐)

时间序列预测实战(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)

时间序列预测实战(六)深入理解ARIMA包括差分和相关性分析

融合模型——难度等级(⭐⭐⭐)

时间序列预测实战(九)PyTorch实现融合移动平均和LSTM-ARIMA进行长期预测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/139001.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Aspose.OCR for .NET 2023Crack

Aspose.OCR for .NET 2023Crack 为.NET在图片上播放OCR使所有用户和程序员都可以从特定的图像片段中提取文本和相关的细节,如字体、设计以及书写位置。这一特定属性为OCR的性能及其在扫描遵循排列的记录时的功能提供了动力。OCR的库使用一条线甚至几条线来处理这些特…

19 异步通知

一、异步通知 1. 异步通知简介 阻塞和非阻塞两种方式都是需要应用程序去主动查询设备的使用情况。 异步通知类似于驱动可以主动报告自己可以访问,应用程序获取信号后会从驱动设备中读取或写入数据。 异步通知最核心的就是信号: #define SIGHUP 1 /* 终…

[工业自动化-14]:西门子S7-15xxx编程 - 软件编程 - STEP7 TIA博途是全集成自动化软件TIA portal快速入门

目录 一、TIA博途是全集成自动化软件TIA portal快速入门 1.1 简介 1.2 软件常用界面 1.3 软件安装的电脑硬件要求 1.4 入口 1.5 主界面 二、PLC软件编程包含哪些内容 2.1 概述 2.2 电机运动控制 一、TIA博途是全集成自动化软件TIA portal快速入门 1.1 简介 Siemens …

openssl研发之base64编解码实例

一、base64编码介绍 Base64编码是一种将二进制数据转换成ASCII字符的编码方式。它主要用于在文本协议中传输二进制数据,例如电子邮件的附件、XML文档、JSON数据等。 Base64编码的特点如下: 字符集: Base64编码使用64个字符来表示二进制数据…

SparkSQL语法优化

SparkSQL在整个执行计划处理的过程中,使用了Catalyst 优化器。 1 基于RBO的优化 在Spark 3.0 版本中,Catalyst 总共有 81 条优化规则(Rules),分成 27 组(Batches),其中有些规则会被归…

uniapp——项目02

分类 创建cate分支 渲染分类页面的基本结构 效果页面,包含左右两个滑动区. 利用提供的api获取当前设备的信息。用来计算窗口高度。可食用高度就是屏幕高度减去上下导航栏的高度。 最终效果: 每一个激活项都特殊背景色,又在尾部加了个红条一样的东西。 export d…

python3GUI--QQ音乐By:PyQt5(附下载地址)

文章目录 一.前言二.展示0.播放页1.主界面1.精选2.有声电台3.排行4.歌手5.歌单 2.推荐3.视频1.视频2.分类3.视频分类 4.雷达5.我喜欢1.歌曲2.歌手 6.本地&下载7.最近播放8.歌单1.一般歌单2.自建歌单3.排行榜 9.其他1.搜索词推荐2.搜索结果 三&#x…

ElasticSearch7.x - HTTP 操作 - 文档操作

创建文档(添加数据) 索引已经创建好了,接下来我们来创建文档,并添加数据。这里的文档可以类比为关系型数 据库中的表数据,添加的数据格式为 JSON 格式 向 ES 服务器发 POST 请求 :http://192.168.254.101:9200/shopping/_doc 请求体内容为: {"title":"小…

智慧城市建设解决方案分享【完整】

文章目录 第1章 前言第2章 智慧城市建设的背景2.1 智慧城市的发展现状2.2 智慧城市的发展趋势 第3章 智慧城市“十二五”规划要点3.1 国民经济和社会发展“十二五”规划要点3.2 “十二五”信息化发展规划要点 第4章 大数据:智慧城市的智慧引擎4.1 大数据技术—智慧城…

智慧城市照明为城市节能降耗提供支持继电器开关钡铼S270

智慧城市照明:为城市节能降耗提供支持——以钡铼技术S270继电器开关为例 随着城市化进程的加速,城市照明系统的需求也日益增长。与此同时,能源消耗和环境污染问题日益严重,使得城市照明的节能减排成为重要议题。智慧城市照明系统…

Linux技能篇-yum源搭建(本地源和公网源)

文章目录 前言一、yum源是什么?二、使用镜像搭建本地yum源1.搭建临时仓库第一步:挂载系统ios镜像到虚拟机第二步:在操作系统中挂载镜像第三步:修改yum源配置文件 2.搭建本地仓库第一步:搭建临时yum源来安装httpd并做文…

javaEE案例,前后端交互,计算机和用户登录

加法计算机,前端的代码如下 : 浏览器访问的效果如图 : 后端的代码如下 再在浏览器进行输入点击相加,就能获得结果 开发中程序报错,如何定位问题 1.先定位前端还是后端(通过日志分析) 1)前端 : F12 看控制台 2)后端 : 接口,控制台日志 举个例子: 如果出现了错误,我们就在后端…

如何查看网站的https的数字证书

如题 打开Chrome浏览器,之后输入想要抓取https证书的网址,此处以知乎为例点击浏览器地址栏左侧的锁的按钮,如下图 点击“连接是安全的”选项,如下图 点击“证书有效”选项卡,如下图 查看基本信息和详细信息 点击详细信…

C/C++数字判断 2021年9月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 C/C数字判断 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C数字判断 2021年9月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 输入一个字符,如何输入的字符是数字&#x…

Spark的转换算子和操作算子

1 Transformation转换算子 1.1 Value类型 1)创建包名:com.shangjack.value 1.1.1 map()映射 参数f是一个函数可以写作匿名子类,它可以接收一个参数。当某个RDD执行map方法时,会遍历该RDD中的每一个数据项,并依次应用f函…

Mac下eclipse配置JDK

一、配置JDK,需要电脑下载Java并且配置环境 Mac环境配置(Java)----使用bash_profile进行配置(附下载地址) (1)、左上角找到“Eclipse”-->“Preferences...” (2)、找到“Java”-->“Installde JREs”-->界…

S7-1200PLC和SMART PLC开放式以太网通信(UDP双向通信)

S7-1200PLC的以太网通信UDP通信相关介绍还可以参考下面文章链接: 博途PLC开放式以太网通信TRCV_C指令应用编程(运动传感器UDP通信)-CSDN博客文章浏览阅读2.8k次。博途PLC开放式以太网通信TSENG_C指令应用,请参看下面的文章链接:博途PLC 1200/1500PLC开放式以太网通信TSEND_…

AI:73-结合语法知识的神经机器翻译研究

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

HTTPS的工作流程

. HTTPS是什么? https是应用层中的一个协议,是在http协议的基础上引入的一个加密层。 为什么需要HTTPS 由于http协议内容都是按照文本的方式明文传输的,这就导致传输过程中会出现一些被篡改的情况。运营商劫持事件最开始百度,…

云栖大会丨桑文锋:打造云原生数字化客户经营引擎

近日,2023 云栖大会在杭州举办。今年云栖大会回归了 2015 的主题:「计算,为了无法计算的价值」。神策数据创始人 & CEO 桑文锋受邀出席「生态产品与伙伴赋能」技术主题,并以「打造云原生数字化客户经营引擎」为主题进行演讲。…