Lag-Llama:第一个时间序列预测的开源基础模型介绍和性能测试

2023年10月,我们发表了一篇关于TimeGPT的文章,TimeGPT是时间序列预测的第一个基础模型之一,具有零样本推理、异常检测和共形预测能力。

虽然TimeGPT是一个专有模型,只能通过API访问。但是它还是引发了对时间序列基础模型的更多研究。到了2024年2月,已经有了一个用于时间序列预测的开源基础模型:laglllama。

在原论文《Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting》中,模型作为单变量概率预测的通用基础模型提出。它是由来自不同机构的大型团队开发的,这些机构包括Morgan Stanley, ServiceNow, Université de Montréal, Mila-Quebec, 和McGill University.

在本文中,我们将探讨Lag-Llama的架构、功能以及训练方式。还会将lagllama应用于一个预测项目中,并将其与其他深度学习方法Temporal Fusion Transformer (TFT) 和DeepAR进行性能比较。

Lag-Llama

lagllama是为单变量概率预测而构建的。它使用不依赖于频率的通用方法来标记时间序列数据。这样模型可以很好地推广到不可见的频率。

它利用Transformer体系结构和分布头来解析输入令牌,并将它们映射到具有置信区间的未来预测。

1、具有滞后特征的标记

laglllama的标记策略是使用一组指定的滞后来构造序列的滞后特征。

它将从这个列表中为给定的数据集选择所有合适的频率:

季度、月、周、天、小时、秒

也就是说,如果以每日频率提供数据集,lag - llama将尝试使用每日滞后(t-1),每周滞后(t-7),每月滞后(t-30)等构建特征。

策略如下图所示。

从上图中,我们还可以看到模型构建了其他静态协变量,例如秒/分、小时/天等等,直到季度/年。虽然这可以很好地推广到所有类型的时间序列,但它有一个致命的缺点:由于固定的滞后指数列表,输入令牌可能会变得非常大。

例如,查看每小时数据的每月频率需要730个时间步。这意味着除了所有静态协变量之外,输入令牌的长度至少为730。

2、Lag-Llama架构

Lag-Llama是一个基于transformer的纯解码器模型,其灵感来自大型语言模型LLaMA的体系结构。

从图中可以看到输入标记是滞后时间步长和静态协变量的拼接。输入序列通过线性投影层将特征映射到解码器内部注意力模块的隐藏维度。另外就是在最后的输出,序列被发送到一个分布头负责输出一个概率分布。

在推理过程中,输入序列生成下一个时间点的分布。然后通过自回归,模型逐个生成剩余的预测序列,直到达到设置的长度。

生成预测的自回归过程有效地允许模型为其预测生成不确定性区间。但是这里的问题就是如果序列很长,自回归的方式会将错误扩大。

3、Lag-Llama分布头

Lag-Llama的分布头负责输出概率分布。这样模型就能够生成预测区间。

在模型的迭代中,最后一层使用Student 's t分布来构造不确定性区间。从理论上讲不同的分布头可以组合在一起,但是论文并没有做这样的实验,可能是想在以后在做吧。

4、Lag-Llama的训练

作为一个基础模型,Lag-Llama显然是在大量的时间序列数据语料库上训练的,因此该模型可以很好地泛化未见过的时间序列并进行零样本预测。

论文中说:Lag-Llama在来自不同领域的27个时间序列数据集上进行了训练,如能源、交通、经济等。

数据包含7965个单变量时间序列,总计约3.52亿个令牌。

所有数据集都是开源的,包括ethth, Exchange和Weather等。

Lag-Llama测试

因为代码已经开源,所以我们可以直接测试,我们首先使用Lag-Llama的零样本预测能力,并将其性能与特定数据模型(如TFT和DeepAR)进行比较。

Lag-Llama的实现是建立在GluonTS之上的,所以我们还需要安装这个库。实验使用了澳大利亚电力需求数据集,该数据集包含五个单变量时间序列,以半小时的频率跟踪能源需求。

这里有个说明:Lag-Llama目前的实现是初期阶段。并且存还在积极开发中,后面可能还会有很大的调整,因为目前还没加入微调的功能。

1、环境设置

 !git clone https://github.com/time-series-foundation-models/lag-llama/ cd lag-llama pip install -r requirements.txt --quiet

然后需要我们从HuggingFace下载模型的权重。

 !huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama

2、加载数据集

 import pandas as pd import matplotlib.pyplot as plt import matplotlib.dates as mdates import torchfrom itertools import islicefrom gluonts.evaluation import make_evaluation_predictions, Evaluator from gluonts.dataset.repository.datasets import get_dataset from lag_llama.gluon.estimator import LagLlamaEstimator

可以直接从GluonTS加载数据集。

 dataset = get_dataset("australian_electricity_demand") backtest_dataset = dataset.test prediction_length = dataset.metadata.prediction_length context_length = 3 * prediction_length

3、使用Lag-Llama预测

简单地初始化模型并使用LagLlamaEstimator对象。

 ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0')) estimator_args = ckpt["hyper_parameters"]["model_kwargs"] estimator = LagLlamaEstimator( ckpt_path="lag-llama.ckpt", prediction_length=prediction_length, context_length=context_length, input_size=estimator_args["input_size"], n_layer=estimator_args["n_layer"], n_embd_per_head=estimator_args["n_embd_per_head"], n_head=estimator_args["n_head"], scaling=estimator_args["scaling"], time_feat=estimator_args["time_feat"]) lightning_module = estimator.create_lightning_module() transformation = estimator.create_transformation() predictor = estimator.create_predictor(transformation, lightning_module)

使用make_evaluation_predictions函数生成零样本的预测。

 forecast_it, ts_it = make_evaluation_predictions(dataset=backtest_dataset, predictor=predictor)

这个函数返回生成器。我们需要把它们转换成列表。

 forecasts = list(forecast_it) tss = list(ts_it)

4、评估

GluonTS可以使用Evaluator对象方便地计算不同的性能指标。

 evaluator = Evaluator() agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))

RMSE为481.57。

我们还可以随意地将预测可视化。

 plt.figure(figsize=(20, 15)) date_formater = mdates.DateFormatter('%b, %d') plt.rcParams.update({'font.size': 15}) for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 4): ax = plt.subplot(2, 2, idx+1) plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target") forecast.plot( color='g')plt.xticks(rotation=60) ax.xaxis.set_major_formatter(date_formater) ax.set_title(forecast.item_id) plt.gcf().tight_layout() plt.legend() plt.show()

上图可以看到模型对数据做出了合理的预测,尽管它在第四个序列(图的右下角)上确实存在问题。

另外由于 Lag-Llama实现了概率预测,可以得到预测的不确定性区间。

5、与TFT和DeepAR相比

我们在数据集上训练TFT和DeepAR模型,看看它们是否能表现得更好。

为了节省时间,我们将训练设置为5个epoch。

 from gluonts.torch import TemporalFusionTransformerEstimator, DeepAREstimator tft_estimator = TemporalFusionTransformerEstimator(prediction_length=prediction_length, context_length=context_length, freq="30min", trainer_kwargs={"max_epochs": 5}) deepar_estimator = DeepAREstimator(prediction_length=prediction_length, context_length=context_length, freq="30min", trainer_kwargs={"max_epochs": 5})

训练过程。

 tft_predictor = tft_estimator.train(dataset.train) deepar_predictor = deepar_estimator.train(dataset.train)

训练完成后,生成预测并计算RMSE。

 tft_forecast_it, tft_ts_it = make_evaluation_predictions(dataset=backtest_dataset, predictor=tft_predictor) deepar_forecast_it, deepar_ts_it = make_evaluation_predictions(dataset=backtest_dataset, predictor=deepar_predictor) tft_forecasts = list(tft_forecast_it) tft_tss = list(tft_ts_it) deepar_forecasts = list(deepar_forecast_it) deepar_tss = list(deepar_ts_it) # Get evaluation metricstft_agg_metrics, tft_ts_metrics = evaluator(iter(tft_tss), iter(tft_forecasts)) deepar_agg_metrics, deepar_ts_metrics = evaluator(iter(deepar_tss), iter(deepar_forecasts))

下表突出显示了性能最好的模型。

可以看到TFT是目前表现最好的模型,DeepAR的表现也优于laglama。

虽然laglllama的表现似乎不尽如人意,但该模型没有经过微调,而且零样本测本身就比较困难。

有趣的是,只训练了5个epoch这两个模型都取得了比Lag-Llama更好的结果。虽然样本预测可以节省时间,但训练五个epoch在时间和计算能力方面的要求应该不是很苛刻。所以目前可能零样本学习方面还需要很大的提升。

总结

在尝试了TimeGPT和Lag-Llama之后,Lag-Llama算是构建开源预测模型的第一步,但与TimeGPT相比,它在功能方面存在不足。

TimeGPT可以处理多变量时间序列、不规则时间戳,并实现共形预测,与使用laglama等固定分布相比,这是一种更稳健的量化不确定性的方式。

laglllama是一个开源的基础模型,只用于单变量概率预测,并且我觉得它训练的数据有点少了。我相信在不久的将来会看到更多的开源预测模型出现。他们的表现可能会得到改善,这代表了该领域的一个重大转变。

最后论文地址:

Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting by K. Rasul, A. Ashok, A. Williams, H. Ghonia, R. Bhagwatkar, A. Khorasani, M. Bayazi, G. Adamopoulos, R. Riachi, N. Hassen, M. Bilos, S. Garg, A. Schneider, N. Chapados, A. Drouin, V. Zantedeschi, Y. Nevmyvaka, I. Rish

https://avoid.overfit.cn/post/8a9120d3cf074c1ba0de0a7a247993c9

作者:Marco Peixeiro

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/683861.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity如何修改预制体(预制件)?

文章目录 19 复制复制复制,预制体与变体 19 复制复制复制,预制体与变体 【预制件】 预制件作用:方便复用 【预制件】的制作 直接拖拽,从层级面板 -> 项目面板。层级面板中当前图标会变蓝,子物体名字变蓝色。预制件…

[经验] 做完腺样体手术打呼噜很严重怎么办 #媒体#笔记#经验分享

做完腺样体手术打呼噜很严重怎么办 1、打呼噜很严重怎么办 打呼噜是一种常见的睡眠障碍,不仅让睡眠质量变得很糟糕,也会影响室友或家人的睡眠质量。幸运的是,有许多方法可以减少打呼噜的发生率,从而让睡眠变得更好。 保持良好的…

具有集中目录服务器的 P2P 工作方式

P2P 工作方式概述 在 P2P 工作方式下,所有的音频/视频文件都是在普通的互联网用户之间传输。 具有集中目录服务器的 P2P 工作方式 Napster 最早使用 P2P 技术,提供免费下载 MP3 音乐。 Napster 将所有音乐文件的索引信息都集中存放在 Napster 目录服务…

ng : 无法加载文件 C:\Program Files\nodejs\node_global\ng.ps1, 因为在此系统上禁止运行脚本

ng : 无法加载文件 C:\Program Files\nodejs\node_global\ng.ps1,因为在此系统上禁止运行脚本 今天在VSCode中运行ng serve --port 8081运行基于Angular的项目时,报错了,错误如下图所示: 解决方法: 按照下图的5步即…

算法沉淀——哈希算法(leetcode真题剖析)

算法沉淀——哈希算法 01.两数之和02.判定是否互为字符重排03.存在重复元素04.存在重复元素 II05.字母异位词分组 哈希算法(Hash Algorithm)是一种将任意长度的输入(也称为消息)映射为固定长度的输出的算法。这个输出通常称为哈希…

七、Mybatis缓存

缓存就是内存中的数据,常常来自对数据库查询结果的保存,使用缓存、可以避免频繁的与数据库进行交互,进而提高响应速度一级缓存是sqlSession级别的缓存,在操作数据库时需要构造sqlsession对象,在对象中有一个数据结构&a…

【智能家居入门3】(MQTT服务器、MQTT协议、微信小程序、STM32)

前面已经写了三篇博客关于智能家居的,服务器全都是使用ONENET中国移动,他最大的优点就是作为数据收发的中转站是免费的。本篇使用专门适配MQTT协议的MQTT服务器,有公用的,也可以自己搭建(应该要钱)&#xf…

【Java程序员面试专栏 分布式中间件】ElasticSearch 核心面试指引

关于ElasticSearch 部分的核心知识进行一网打尽,包括ElasticSearch 的基本概念,基本架构,工作流程,存储机制等,通过一篇文章串联面试重点,并且帮助加强日常基础知识的理解,全局思维导图如下所示 基础概念 从数据分类入手,考察全文索引的基本概念 现实世界中数据有哪…

量子算法入门——2.线性代数与复数

参考资料: 【【零基础入门量子计算-第03讲】线性代数初步与复数】 来自b站up:溴锑锑跃迁 建议关注他的更多高质量文章:CSDN:【溴锑锑跃迁】 0. 前言 强烈建议搭配b站原视频进行观看,这只是我当时看的笔记&#xff0c…

【机器学习笔记】4 朴素贝叶斯

贝叶斯方法 贝叶斯分类 贝叶斯分类是一类分类算法的总称,这类算法均以贝叶斯定理为基础,故统称为贝叶斯分类。 朴素贝叶斯分类是这一类算法中最简单的较为常见的算法。 先验概率 根据以往经验和分析得到的概率。我们用𝑃(𝑌)来代…

FL Studio 21.2.3.4004 All Plugins Edition Win/Mac音乐软件

FL Studio 21.2.3.4004 All Plugins Edition 是一款功能强大的音乐制作软件,提供了丰富的音频处理工具和插件,适用于专业音乐制作人和爱好者。该软件具有直观的用户界面,支持多轨道录音、混音和编辑,以及各种音频效果和虚拟乐器。…

华清远见嵌入式学习——春节作业——2.15日

作业要求&#xff1a; 编写led驱动&#xff0c;通过应用程序控制三盏灯亮灭 作业答案&#xff1a; 作业效果&#xff1a; mychrdev.c #include <linux/init.h> #include <linux/module.h> #include <linux/fs.h> #include <linux/uaccess.h> #incl…

基于GPT-4一键完成数据分析全流程的AI Agent: Streamline Analyst

大型语言模型&#xff08;LLM&#xff09;的兴起不仅为获取知识和解决问题开辟了新的可能性&#xff0c;而且催生了一些新型智能系统&#xff0c;例如旨在辅助用户完成特定任务的AI Copilot以及旨在自动化和自主执行复杂任务的AI Agent&#xff0c;使得编程、创作等任务变得高效…

医卫答案在哪搜?九个公众号和软件推荐清单! #笔记#笔记#微信

在这个信息爆炸的时代&#xff0c;合理利用学习工具可以帮助我们过滤和获取有用的知识。 1.粉鹿搜题 这是一个公众号 题库包括四六级答案、各学校往期课后答案、期末考试题等&#xff0c;使用比较简单。 下方附上一些测试的试题及答案 1、最有可能担任债券发行受托人的个人…

装饰工程|装饰工程管理系统-项目立项子系统的设计与实现|基于Springboot的装饰工程管理系统设计与实现(源码+数据库+文档)

装饰工程管理系统-项目立项子系统目录 目录 基于Springboot的装饰工程管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员功能实现 &#xff08;2&#xff09;合同报价管理 &#xff08;3&#xff09;装饰材料总计划管理 &#xff08;4&#xff0…

Java与JavaScript的区别与联系

Java是目前编程领域使用非常广泛的编程语言&#xff0c;相较于JavaScript&#xff0c;Java更被人们熟知。很多Java程序员想学门脚本语言&#xff0c;一看JavaScript和Java这么像&#xff0c;很有亲切感&#xff0c;那干脆就学它了&#xff0c;这也间接的帮助了JavaScript的发展…

OLED显示红外遥控键码

基本原理 本遥控器的编码是NEC编码&#xff0c;为PWM&#xff08;脉冲宽度调制&#xff09;。 发射红外载波的时间固定&#xff0c;通过改变不发射载波的时间来改变占空比。 逻辑“0”是由0.56ms的38KHZ载波和0.560ms的无载波间隔组成&#xff1b;逻辑“1”是由0.56ms的38KHZ…

LabVIEW高效电磁阀性能测试

LabVIEW高效电磁阀性能测试 在核电站的安全运营中&#xff0c;电磁阀作为关键组件&#xff0c;其性能的可靠性至关重要。设计一套基于LabVIEW的电磁阀测试平台&#xff0c;既能精准测试电磁阀的多项性能指标&#xff0c;又能提高检修效率与准确性&#xff0c;进而保障核电站的…

接口测试全流程扫盲

扫盲内容&#xff1a; 1.什么是接口&#xff1f; 2.接口都有哪些类型&#xff1f; 3.接口的本质是什么&#xff1f; 4.什么是接口测试&#xff1f; 5.问什么要做接口测试&#xff1f; 6.怎样做接口测试&#xff1f; 7.接口测测试点是什么&#xff1f; 8.接口测试都要掌…

​StableSwarmUI#超越文本的prompt

今天看到一个新的webui方案&#xff0c;是Stability-AI开源的&#xff1a; StableSwarmUI 是一个模块化的稳定扩散web用户界面&#xff0c;着重于使强大的工具易于访问、高性能和可扩展性。 由于项目还在开发中&#xff0c;我们可以先了解下&#xff0c;翻看了它的特点&#xf…