使用Transformer 模型进行时间序列预测的Pytorch代码示例

时间序列预测是一个经久不衰的主题,受自然语言处理领域的成功启发,transformer模型也在时间序列预测有了很大的发展。本文可以作为学习使用Transformer 模型的时间序列预测的一个起点。

数据集

这里我们直接使用kaggle中的 Store Sales — Time Series Forecasting作为数据。这个比赛需要预测54家商店中各种产品系列未来16天的销售情况,总共创建1782个时间序列。数据从2013年1月1日至2017年8月15日,目标是预测接下来16天的销售情况。虽然为了简洁起见,我们做了简化处理,作为模型的输入包含20列中的3,029,400条数据,。每行的关键列为’ store_nbr ‘、’ family ‘和’ date '。数据分为三类变量:

1、截止到最后一次训练数据日期(2017年8月15日)之前已知的与时间相关的变量。这些变量包括数字变量,如“销售额”,表示某一产品系列在某家商店的销售额;“transactions”,一家商店的交易总数;’ store_sales ‘,该商店的总销售额;’ family_sales '表示该产品系列的总销售额。

2、训练截止日期(2017年8月31日)之前已知,包括“onpromotion”(产品系列中促销产品的数量)和“dcoilwtico”等变量。这些数字列由’ holiday ‘列补充,它表示假日或事件的存在,并被分类编码为整数。此外,’ time_idx ‘、’ week_day ‘、’ month_day ‘、’ month ‘和’ year '列提供时间上下文,也编码为整数。虽然我们的模型是只有编码器的,但已经添加了16天移动值“onpromotion”和“dcoilwtico”,以便在没有解码器的情况下包含未来的信息。

3、静态协变量随着时间的推移保持不变,包括诸如“store_nbr”、“family”等标识符,以及“city”、“state”、“type”和“cluster”等分类变量(详细说明了商店的特征),所有这些变量都是整数编码的。

我们最后生成的df名为’ data_all ',结构如下:

 categorical_covariates= ['time_idx','week_day','month_day','month','year','holiday']categorical_covariates_num_embeddings= []forcolincategorical_covariates:data_all[col] =data_all[col].astype('category').cat.codescategorical_covariates_num_embeddings.append(data_all[col].nunique())categorical_static= ['store_nbr','city','state','type','cluster','family_int']categorical_static_num_embeddings= []forcolincategorical_static:data_all[col] =data_all[col].astype('category').cat.codescategorical_static_num_embeddings.append(data_all[col].nunique())numeric_covariates= ['sales','dcoilwtico','dcoilwtico_future','onpromotion','onpromotion_future','store_sales','transactions','family_sales']target_idx=np.where(np.array(numeric_covariates)=='sales')[0][0]

在将数据转换为适合我的PyTorch模型的张量之前,需要将其分为训练集和验证集。窗口大小是一个重要的超参数,表示每个训练样本的序列长度。此外,’ num_val '表示使用的验证折数,在此上下文中设置为2。将2013年1月1日至2017年6月28日的观测数据指定为训练数据集,以2017年6月29日至2017年7月14日和2017年7月15日至2017年7月30日作为验证区间。

同时还进行了数据的缩放,完整代码如下:

 defdataframe_to_tensor(series,numeric_covariates,categorical_covariates,categorical_static,target_idx):numeric_cov_arr=np.array(series[numeric_covariates].values.tolist())category_cov_arr=np.array(series[categorical_covariates].values.tolist())static_cov_arr=np.array(series[categorical_static].values.tolist())x_numeric=torch.tensor(numeric_cov_arr,dtype=torch.float32).transpose(2,1)x_numeric=torch.log(x_numeric+1e-5)x_category=torch.tensor(category_cov_arr,dtype=torch.long).transpose(2,1)x_static=torch.tensor(static_cov_arr,dtype=torch.long)y=torch.tensor(numeric_cov_arr[:,target_idx,:],dtype=torch.float32)returnx_numeric, x_category, x_static, ywindow_size=16forecast_length=16num_val=2val_max_date='2017-08-15'train_max_date=str((pd.to_datetime(val_max_date) -pd.Timedelta(days=window_size*num_val+forecast_length)).date())train_final=data_all[data_all['date']<=train_max_date]val_final=data_all[(data_all['date']>train_max_date)&(data_all['date']<=val_max_date)]train_series=train_final.groupby(categorical_static+['family']).agg(list).reset_index()val_series=val_final.groupby(categorical_static+['family']).agg(list).reset_index()x_numeric_train_tensor, x_category_train_tensor, x_static_train_tensor, y_train_tensor=dataframe_to_tensor(train_series,numeric_covariates,categorical_covariates,categorical_static,target_idx)x_numeric_val_tensor, x_category_val_tensor, x_static_val_tensor, y_val_tensor=dataframe_to_tensor(val_series,numeric_covariates,categorical_covariates,categorical_static,target_idx)

数据加载器

在数据加载时,需要将每个时间序列从窗口范围内的随机索引开始划分为时间块,以确保模型暴露于不同的序列段。

为了减少偏差还引入了一个额外的超参数设置,它不是随机打乱数据,而是根据块的开始时间对数据集进行排序。然后数据被分成五部分——反映了我们五年的数据集——每一部分都是内部打乱的,这样最后一批数据将包括去年的观察结果,但还是随机的。模型的最终梯度更新受到最近一年的影响,理论上可以改善最近时期的预测。

 defdivide_shuffle(df,div_num):space=df.shape[0]//div_numdivision=np.arange(0,df.shape[0],space)returnpd.concat([df.iloc[division[i]:division[i]+space,:].sample(frac=1) foriinrange(len(division))])defcreate_time_blocks(time_length,window_size):start_idx=np.random.randint(0,window_size-1)end_idx=time_length-window_size-16-1time_indices=np.arange(start_idx,end_idx+1,window_size)[:-1]time_indices=np.append(time_indices,end_idx)returntime_indicesdefdata_loader(x_numeric_tensor, x_category_tensor, x_static_tensor, y_tensor, batch_size, time_shuffle):num_series=x_numeric_tensor.shape[0]time_length=x_numeric_tensor.shape[1]index_pd=pd.DataFrame({'serie_idx':range(num_series)})index_pd['time_idx'] = [create_time_blocks(time_length,window_size) forninrange(index_pd.shape[0])]iftime_shuffle:index_pd=index_pd.explode('time_idx')index_pd=index_pd.sample(frac=1)else:index_pd=index_pd.explode('time_idx').sort_values('time_idx')index_pd=divide_shuffle(index_pd,5)indices=np.array(index_pd).astype(int)forbatch_idxinnp.arange(0,indices.shape[0],batch_size):cur_indices=indices[batch_idx:batch_idx+batch_size,:]x_numeric=torch.stack([x_numeric_tensor[n[0],n[1]:n[1]+window_size,:] fornincur_indices])x_category=torch.stack([x_category_tensor[n[0],n[1]:n[1]+window_size,:] fornincur_indices])x_static=torch.stack([x_static_tensor[n[0],:] fornincur_indices])y=torch.stack([y_tensor[n[0],n[1]+window_size:n[1]+window_size+forecast_length] fornincur_indices])yieldx_numeric.to(device), x_category.to(device), x_static.to(device), y.to(device)defval_loader(x_numeric_tensor, x_category_tensor, x_static_tensor, y_tensor, batch_size, num_val):num_time_series=x_numeric_tensor.shape[0]foriinrange(num_val):forbatch_idxinnp.arange(0,num_time_series,batch_size):x_numeric=x_numeric_tensor[batch_idx:batch_idx+batch_size,window_size*i:window_size*(i+1),:]x_category=x_category_tensor[batch_idx:batch_idx+batch_size,window_size*i:window_size*(i+1),:]x_static=x_static_tensor[batch_idx:batch_idx+batch_size]y_val=y_tensor[batch_idx:batch_idx+batch_size,window_size*(i+1):window_size*(i+1)+forecast_length]yieldx_numeric.to(device), x_category.to(device), x_static.to(device), y_val.to(device)

模型

我们这里通过Pytorch来简单的实现《Attention is All You Need》(2017)²中描述的Transformer架构。因为是时间序列预测,所以注意力机制中不需要因果关系,也就是没有对注意块应用进行遮蔽。

从输入开始:分类特征通过嵌入层传递,以密集的形式表示它们,然后送到Transformer块。多层感知器(MLP)接受最终编码输入来产生预测。嵌入维数、每个Transformer块中的注意头数和dropout概率是模型的主要超参数。堆叠多个Transformer块由’ num_blocks '超参数控制。

下面是单个Transformer块的实现和整体预测模型:

 classtransformer_block(nn.Module):def__init__(self,embed_size,num_heads):super(transformer_block, self).__init__()self.attention=nn.MultiheadAttention(embed_size, num_heads, batch_first=True)self.fc=nn.Sequential(nn.Linear(embed_size, 4*embed_size),nn.LeakyReLU(),nn.Linear(4*embed_size, embed_size))self.dropout=nn.Dropout(drop_prob)self.ln1=nn.LayerNorm(embed_size, eps=1e-6)self.ln2=nn.LayerNorm(embed_size, eps=1e-6)defforward(self, x):attn_out, _=self.attention(x, x, x, need_weights=False)x=x+self.dropout(attn_out)x=self.ln1(x)fc_out=self.fc(x)x=x+self.dropout(fc_out)x=self.ln2(x)returnxclasstransformer_forecaster(nn.Module):def__init__(self,embed_size,num_heads,num_blocks):super(transformer_forecaster, self).__init__()num_len=len(numeric_covariates)self.embedding_cov=nn.ModuleList([nn.Embedding(n,embed_size-num_len) fornincategorical_covariates_num_embeddings])self.embedding_static=nn.ModuleList([nn.Embedding(n,embed_size-num_len) fornincategorical_static_num_embeddings])self.blocks=nn.ModuleList([transformer_block(embed_size,num_heads) forninrange(num_blocks)])self.forecast_head=nn.Sequential(nn.Linear(embed_size, embed_size*2),nn.LeakyReLU(),nn.Dropout(drop_prob),nn.Linear(embed_size*2, embed_size*4),nn.LeakyReLU(),nn.Linear(embed_size*4, forecast_length),nn.ReLU())defforward(self, x_numeric, x_category, x_static):tmp_list= []fori,embed_layerinenumerate(self.embedding_static):tmp_list.append(embed_layer(x_static[:,i]))categroical_static_embeddings=torch.stack(tmp_list).mean(dim=0).unsqueeze(1)tmp_list= []fori,embed_layerinenumerate(self.embedding_cov):tmp_list.append(embed_layer(x_category[:,:,i]))categroical_covariates_embeddings=torch.stack(tmp_list).mean(dim=0)T=categroical_covariates_embeddings.shape[1]embed_out= (categroical_covariates_embeddings+categroical_static_embeddings.repeat(1,T,1))/2x=torch.concat((x_numeric,embed_out),dim=-1)forblockinself.blocks:x=block(x)x=x.mean(dim=1)x=self.forecast_head(x)returnx

我们修改后的transformer架构如下图所示:

模型接受三个独立的输入张量:数值特征、分类特征和静态特征。对分类和静态特征嵌入进行平均,并与数字特征组合形成具有形状(batch_size, window_size, embedding_size)的张量,为Transformer块做好准备。这个复合张量还包含嵌入的时间变量,提供必要的位置信息。

Transformer块提取顺序信息,然后将结果张量沿着时间维度聚合,将其传递到MLP中以生成最终预测(batch_size, forecast_length)。这个比赛采用均方根对数误差(RMSLE)作为评价指标,公式为:

鉴于预测经过对数转换,预测低于-1的负销售额(这会导致未定义的错误)需要进行处理,所以为了避免负的销售预测和由此产生的NaN损失值,在MLP层以后增加了一层ReLU激活确保非负预测。

 class RMSLELoss(nn.Module):def __init__(self):super().__init__()self.mse = nn.MSELoss()def forward(self, pred, actual):return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1)))

训练和验证

训练模型时需要设置几个超参数:窗口大小、是否打乱时间、嵌入大小、头部数量、块数量、dropout、批大小和学习率。以下配置是有效的,但不保证是最好的:

 num_epoch = 1000min_val_loss = 999num_blocks = 1embed_size = 500num_heads = 50batch_size = 128learning_rate = 3e-4time_shuffle = Falsedrop_prob = 0.1model = transformer_forecaster(embed_size,num_heads,num_blocks).to(device)criterion = RMSLELoss()optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

这里使用adam优化器和学习率调度,以便在训练期间逐步调整学习率。

 for epoch in range(num_epoch):batch_loader = data_loader(x_numeric_train_tensor, x_category_train_tensor, x_static_train_tensor, y_train_tensor, batch_size, time_shuffle)train_loss = 0counter = 0model.train()for x_numeric, x_category, x_static, y in batch_loader:optimizer.zero_grad()preds = model(x_numeric, x_category, x_static)loss = criterion(preds, y)train_loss += loss.item()counter += 1loss.backward()optimizer.step()train_loss = train_loss/counterprint(f'Epoch {epoch} training loss: {train_loss}')model.eval()val_batches = val_loader(x_numeric_val_tensor, x_category_val_tensor, x_static_val_tensor, y_val_tensor, batch_size, num_val)val_loss = 0counter = 0for x_numeric_val, x_category_val, x_static_val, y_val in val_batches:with torch.no_grad():preds = model(x_numeric_val,x_category_val,x_static_val)loss = criterion(preds,y_val).item()val_loss += losscounter += 1val_loss = val_loss/counterprint(f'Epoch {epoch} validation loss: {val_loss}')if val_loss<min_val_loss:print('saved...')torch.save(model,data_folder+'best.model')min_val_loss = val_lossscheduler.step()

结果

训练后,表现最好的模型的训练损失为0.387,验证损失为0.457。当应用于测试集时,该模型的RMSLE为0.416,比赛排名为第89位(前10%)。

更大的嵌入和更多的注意力头似乎可以提高性能,但最好的结果是用一个单独的Transformer 实现的,这表明在有限的数据下,简单是优点。当放弃整体打乱而选择局部打乱时,效果有所改善;引入轻微的时间偏差提高了预测的准确性。

以下是引用

[1]: Alexis Cook, DanB, inversion, Ryan Holbrook. (2021). Store Sales — Time Series Forecasting. Kaggle. https://avoid.overfit.cn/post/960767b198ac4d9f988fc1795aa89e59

[2]: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.

[3]: Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal fusion transformers for interpretable multi-horizon time series forecasting. International Journal of Forecasting, 37(4), 1748–1764.

作者:Kaan Aslan

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/651350.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue3的computed和watch

目录 1、computed 2、computed完整写法 3、watch 4、watch监听对象具体属性 5、watch 监听reactive数据 1、computed 基于现有的数据计算出新的数据 <script setup >import {ref,computed} from vue const numref(1) const doubleNumcomputed(()>{return num.val…

latex表格示例(背景颜色、行距、线粗细、标题、表格长度)

导入库 \usepackage{colortbl} 代码 \begin{table}[H] \begin{center}%表格居中 \tabcolsep1.5cm%表格横向长度 \renewcommand\arraystretch{1.5}%设置表格行间距 \begin{tabular}{cc} \toprule[2pt]%设置线的宽度 %\specialrule{0em}{3pt}{8pt}%添加一条线&#xff0c;第一个…

《Linux高性能服务器编程》笔记08

Linux高性能服务器编程 本文是读书笔记&#xff0c;如有侵权&#xff0c;请联系删除。 参考 Linux高性能服务器编程源码: https://github.com/raichen/LinuxServerCodes 豆瓣: Linux高性能服务器编程 文章目录 Linux高性能服务器编程第08章 高性能服务器程序框架8.1 服务器…

docker拉取镜像时指定其OS及CPU指令集类型

前言 之前在香橙派5上安装的时候碰到过一次指定镜像的OS及cpu指令集类型的问题&#xff0c;但是当时没有记录&#xff0c;现在用到 了又想不起来&#xff0c;干脆就自己记录一下。预防后面忘掉。docker报错截图 上次时在arm的cpu中运行x86镜像&#xff0c;这次时在x86中运行arm…

C语言从入门到入坟

前言 1.初识程序 有穷性 在有限的操作步骤内完成。有穷性是算法的重要特性&#xff0c;任何一个问题的解决不论其采取什么样的算法&#xff0c;其终归是要把问题解决好。如果一种算法的执行时间是无限的&#xff0c;或在期望的时间内没有完成&#xff0c;那么这种算法就是无用…

MYSQL库和表的操作(修改字符集和校验规则,备份和恢复数据库及库和表的增删改查)

文章目录 一、MSYQL库的操作1.连接MYSQL2.查看当前数据库3.创建数据库4.字符集和校验规则5.修改数据库6.删除数据库7.备份和恢复8.查看连接 二、表的操作1.创建表2.查看表结构3.修改表4.删除表 一、MSYQL库的操作 1.连接MYSQL 我们使用下面的语句来连接MSYQL&#xff1a; my…

关于session每次请求都会改变的问题

这几天在部署一个前后端分离的项目&#xff0c;使用docker进行部署&#xff0c;在本地测试没有一点问题没有&#xff0c;前脚刚把后端部署到服务器&#xff0c;后脚测试就出现了问题&#xff01;查看控制台报错提示跨域错误&#xff1f;但是对于静态资源请求&#xff0c;包括登…

【CSS】字体效果展示

测试时使用了Google浏览器。 1.Courier New 2.monospace 3.Franklin Gothic Medium 4.Arial Narrow 5.Arial 6.sans-serif 7.Gill Sans MT 8.Calibri 9.Trebuchet MS 10.Lucida Sans 11.Lucida Grande 12.Lucida Sans Unicode 13.Geneva 14.Verdana 15.Segoe UI 16.Tahoma 17.…

【2024华数杯国际数学建模竞赛】问题B 光伏发电 完整代码+结果分析+论文框架(二)

问题B&#xff08;二&#xff09; 5.2 问题二模型的建立与求解&#xff08;二&#xff09;5.1.4基于LSTM的时间序列预测模型5.1.5 LSTM的时间序列预测结果5.1.6 多元回归模型的预测结果5.1.7 LSTM时间序列模型的性能评价 5.2 问题二模型的建立与求解5.2.1基于皮尔逊系数相关性分…

【C++中STL】set/multiset容器

set/multiset容器 Set基本概念set构造和赋值set的大小和交换set的插入和删除set查找和统计 set和multiset的区别pair对组两种创建方式 set容器排序 Set基本概念 所有元素都会在插入时自动被排序。 set/multist容器属于关联式容器&#xff0c;底层结构属于二叉树。 set不允许容…

架构师的36项修炼-08系统的安全架构设计

本课时讲解系统的安全架构。 本节课主要讲 Web 的攻击与防护、信息的加解密与反垃圾。其中 Web 攻击方式包括 XSS 跨站点脚本攻击、SQL 注入攻击和 CSRF 跨站点请求伪造攻击&#xff1b;防护手段主要有消毒过滤、SQL 参数绑定、验证码和防火墙&#xff1b;加密手段&#xff0c…

java关键字概述——final及常量概述

前言&#xff1a; 打好基础&#xff0c;daydayup! final final概述 final关键字是最终的意思&#xff0c;可以修饰&#xff08;类&#xff0c;方法&#xff0c;变量&#xff09; final作用 修饰类&#xff1a;该类被称为最终类&#xff0c;特点为不能被继承 修饰方法&#xff…

智能GPT图书管理系统(SpringBoot2+Vue2)、接入GPT接口,支持AI智能图书馆

☀️技术栈介绍 ☃️前端主要技术栈 技术作用版本Vue提供前端交互2.6.14Vue-Router路由式编程导航3.5.1Element-UI模块组件库&#xff0c;绘制界面2.4.5Axios发送ajax请求给后端请求数据1.2.1core-js兼容性更强&#xff0c;浏览器适配3.8.3swiper轮播图插件&#xff08;快速实…

【笔试常见编程题01】删除公共字符串、组队竞赛、倒置字符串、排序子序列

1. 删除公共字符串 输入两个字符串&#xff0c;从第一字符串中删除第二个字符串中所有的字符。 例如&#xff0c;输入”They are students.”和”aeiou”&#xff0c;则删除之后的第一个字符串变成”Thy r stdnts.” 输入描述 每个测试输入包含2个字符串 输出描述 输出删除后的…

外包干了8个月,技术退步明显...

先说一下自己的情况&#xff0c;大专生&#xff0c;18年通过校招进入武汉某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了四年的功能测…

v43-47.problems

1.for循环 一般地&#xff0c;三步走&#xff1a; for&#xff08;初始化&#xff1b;表达式判断&#xff1b;递增/递减&#xff09; &#xff5b; ....... &#xff5d; 但是&#xff0c;如果说声明了全局变量&#xff0c;那么第一步初始化阶段可以省略但是要写分号‘ ; ’…

Java后端开发:学籍系统核心逻辑

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

uniapp组件库fullScreen 压窗屏的适用方法

目录 #平台差异说明 #基本使用 #触发压窗屏 #定义压窗屏内容 #注意事项 所谓压窗屏&#xff0c;是指遮罩能盖住原生导航栏和底部tabbar栏的弹窗&#xff0c;一般用于在APP端弹出升级应用弹框&#xff0c;或者其他需要增强型弹窗的场景。 警告 由于uni-app的Bug&#xff0…

DEM高程地形瓦片数据Cesium使用教程

一、简介 从开始写文章到现在&#xff0c;陆续发布了全球90m、30m(包括哥白尼及ALOS)、12.5m全球级瓦片数据&#xff0c;以及中国12.5、日本10m、新西兰8m、等国家级瓦片数据&#xff0c;同时也发布了台湾20m、中国34省区12.5m等地区级瓦片数据。在数据发布的文章中对数据如何…

C#,最小生成树(MST)普里姆(Prim)算法的源代码

Vojtěch Jarnk 一、Prim算法简史 Prim算法&#xff08;普里姆算法&#xff09;&#xff0c;是1930年捷克数学家算法沃伊捷赫亚尔尼克&#xff08;Vojtěch Jarnk&#xff09;最早设计&#xff1b; 1957年&#xff0c;由美国计算机科学家罗伯特普里姆独立实现&#xff1b; 19…