lag-llama源码解读(Lag-Llama: Towards Foundation Models for Time Series Forecasting)

Lag-Llama: Towards Foundation Models for Time Series Forecasting
文章内容:
时间序列预测任务,单变量预测单变量,基于Llama大模型,在zero-shot场景下模型表现优异。创新点,引入滞后特征作为协变量来进行预测。

获得不同频率的lag,来自glunoTS库里面的源码

def _make_lags(middle: int, delta: int) -> np.ndarray:"""Create a set of lags around a middle point including +/- delta."""return np.arange(middle - delta, middle + delta + 1).tolist()def get_lags_for_frequency(freq_str: str,lag_ub: int = 1200,num_lags: Optional[int] = None,num_default_lags: int = 7,
) -> List[int]:"""Generates a list of lags that that are appropriate for the given frequencystring.By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7].Remaining lags correspond to the same `season` (+/- `delta`) in previous`k` cycles. Here `delta` and `k` are chosen according to the existing code.Parameters----------freq_strFrequency string of the form [multiple][granularity] such as "12H","5min", "1D" etc.lag_ubThe maximum value for a lag.num_lagsMaximum number of lags; by default all generated lags are returned.num_default_lagsThe number of default lags; by default it is 7."""# Lags are target values at the same `season` (+/- delta) but in the# previous cycle.def _make_lags_for_second(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_minute(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_hour(multiple, num_cycles=7):# We use previous ``num_cycles`` days to generate lagsreturn [_make_lags(k * 24 // multiple, 1) for k in range(1, num_cycles + 1)]def _make_lags_for_day(multiple, num_cycles=4, days_in_week=7, days_in_month=30):# We use previous ``num_cycles`` weeks to generate lags# We use the last month (in addition to 4 weeks) to generate lag.return [_make_lags(k * days_in_week // multiple, 1)for k in range(1, num_cycles + 1)] + [_make_lags(days_in_month // multiple, 1)]def _make_lags_for_week(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lags# Additionally, we use previous 4, 8, 12 weeksreturn [_make_lags(k * 52 // multiple, 1) for k in range(1, num_cycles + 1)] + [[4 // multiple, 8 // multiple, 12 // multiple]]def _make_lags_for_month(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lagsreturn [_make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)]# multiple, granularity = get_granularity(freq_str)offset = to_offset(freq_str)# normalize offset name, so that both `W` and `W-SUN` refer to `W`offset_name = norm_freq_str(offset.name)if offset_name == "A":lags = []elif offset_name == "Q":assert (offset.n == 1), "Only multiple 1 is supported for quarterly. Use x month instead."lags = _make_lags_for_month(offset.n * 3.0)elif offset_name == "M":lags = _make_lags_for_month(offset.n)elif offset_name == "W":lags = _make_lags_for_week(offset.n)elif offset_name == "D":lags = _make_lags_for_day(offset.n) + _make_lags_for_week(offset.n / 7.0)elif offset_name == "B":lags = _make_lags_for_day(offset.n, days_in_week=5, days_in_month=22) + _make_lags_for_week(offset.n / 5.0)elif offset_name == "H":lags = (_make_lags_for_hour(offset.n)+ _make_lags_for_day(offset.n / 24)+ _make_lags_for_week(offset.n / (24 * 7)))# minuteselif offset_name == "T":lags = (_make_lags_for_minute(offset.n)+ _make_lags_for_hour(offset.n / 60)+ _make_lags_for_day(offset.n / (60 * 24))+ _make_lags_for_week(offset.n / (60 * 24 * 7)))# secondelif offset_name == "S":lags = (_make_lags_for_second(offset.n)+ _make_lags_for_minute(offset.n / 60)+ _make_lags_for_hour(offset.n / (60 * 60)))else:raise Exception("invalid frequency")# flatten lags list and filterlags = [int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub]lags = list(range(1, num_default_lags + 1)) + sorted(list(set(lags)))return lags[:num_lags]

第一部分,生成以middle为中心,以delta为半径的区间[middle-delta,middle+delta] ,这很好理解,比如一周的周期是7天,周期大小在7天附近波动很正常。
在这里插入图片描述

第二部分,对于年月日时分秒这些不同的采样频率,采用不同的具体的函数来确定lags,其中有一个参数num_cycle,进一步利用了周期性,我们考虑间隔1、2、3、…num个周期的时间点之间的联系
在这里插入图片描述
原理类似于这张图,这种周期性的重复性体现在邻近的多个周期上

在这里插入图片描述

lag的用途

计算各类窗口大小

计算采样窗口大小

window_size = estimator.context_length + max(estimator.lags_seq) + estimator.prediction_length# Here we make a window slightly bigger so that instance sampler can sample from each window# An alternative is to have exact size and use different instance sampler (e.g. ValidationSplitSampler)
window_size = 10 * window_size
# We change ValidationSplitSampler to add min_pastestimator.validation_sampler = ValidationSplitSampler(min_past=estimator.context_length + max(estimator.lags_seq),min_future=estimator.prediction_length,)
  1. 构建静态特征
lags = lagged_sequence_values(self.lags_seq, prior_input, input, dim=-1)#构建一个包含给定序列的滞后值的数组static_feat = torch.cat((loc.abs().log1p(), scale.log()), dim=-1)
expanded_static_feat = unsqueeze_expand(static_feat, dim=-2, size=lags.shape[-2]
)return torch.cat((lags, expanded_static_feat, time_feat), dim=-1), loc, scale

数据集准备过程

对每个数据集采样,window_size=13500,也挺离谱的

 train_data, val_data = [], []for name in TRAIN_DATASET_NAMES:new_data = create_sliding_window_dataset(name, window_size)train_data.append(new_data)new_data = create_sliding_window_dataset(name, window_size, is_train=False)val_data.append(new_data)

采样的具体过程,这里有个问题,样本数量很小的数据集,实际采样窗口大小小于设定的window_size,后续会如何对齐呢?

文章设置单变量预测单变量,所以样本进行了通道分离,同一样本的不同特征被采样为不同的样本

def create_sliding_window_dataset(name, window_size, is_train=True):#划分非重叠的滑动窗口数据集,window_size是对数据集采样的数量,对每个数据集只取前windowsize个样本# Splits each time series into non-overlapping sliding windowsglobal_id = 0freq = get_dataset(name, path=dataset_path).metadata.freq#从数据集中获取时间频率data = ListDataset([], freq=freq)#创建空数据集dataset = get_dataset(name, path=dataset_path).train if is_train else get_dataset(name, path=dataset_path).test#获取原始数据集for x in dataset:windows = []#划分滑动窗口#target:滑动窗口的目标值#start:滑动窗口的起始位置#item_id,唯一标识符#feat_static_cat:静态特征数组for i in range(0, len(x['target']), window_size):windows.append({'target': x['target'][i:i+window_size],'start': x['start'] + i,'item_id': str(global_id),'feat_static_cat': np.array([0]),})global_id += 1data += ListDataset(windows, freq=freq)return data

合并数据集

# Here weights are proportional to the number of time series (=sliding windows)weights = [len(x) for x in train_data]# Here weights are proportinal to the number of individual points in all time series# weights = [sum([len(x["target"]) for x in d]) for d in train_data]train_data = CombinedDataset(train_data, weights=weights)val_data = CombinedDataset(val_data, weights=weights)
class CombinedDataset:def __init__(self, datasets, seed=None, weights=None):self._seed = seedself._datasets = datasetsself._weights = weightsn_datasets = len(datasets)if weights is None:#如果未提供权重,默认平均分配权重self._weights = [1 / n_datasets] * n_datasetsdef __iter__(self):return CombinedDatasetIterator(self._datasets, self._seed, self._weights)def __len__(self):return sum([len(ds) for ds in self._datasets])

网络结构

lagllama

class LagLlamaModel(nn.Module):def __init__(self,max_context_length: int,scaling: str,input_size: int,n_layer: int,n_embd: int,n_head: int,lags_seq: List[int],rope_scaling=None,distr_output=StudentTOutput(),num_parallel_samples: int = 100,) -> None:super().__init__()self.lags_seq = lags_seqconfig = LTSMConfig(n_layer=n_layer,n_embd=n_embd,n_head=n_head,block_size=max_context_length,feature_size=input_size * (len(self.lags_seq)) + 2 * input_size + 6,rope_scaling=rope_scaling,)self.num_parallel_samples = num_parallel_samplesif scaling == "mean":self.scaler = MeanScaler(keepdim=True, dim=1)elif scaling == "std":self.scaler = StdScaler(keepdim=True, dim=1)else:self.scaler = NOPScaler(keepdim=True, dim=1)self.distr_output = distr_outputself.param_proj = self.distr_output.get_args_proj(config.n_embd)self.transformer = nn.ModuleDict(dict(wte=nn.Linear(config.feature_size, config.n_embd),h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),ln_f=RMSNorm(config.n_embd),))

主要是transformer里面首先是一个线性层,然后加了n_layer个Block,最后是RMSNorm,接下来解析Block的代码

在这里插入图片描述

Block

class Block(nn.Module):def __init__(self, config: LTSMConfig) -> None:super().__init__()self.rms_1 = RMSNorm(config.n_embd)self.attn = CausalSelfAttention(config)self.rms_2 = RMSNorm(config.n_embd)self.mlp = MLP(config)self.y_cache = Nonedef forward(self, x: torch.Tensor, is_test: bool) -> torch.Tensor:if is_test and self.y_cache is not None:# Only use the most recent one, rest is in cachex = x[:, -1:]x = x + self.attn(self.rms_1(x), is_test)y = x + self.mlp(self.rms_2(x))if is_test:if self.y_cache is None:self.y_cache = y  # Build cacheelse:self.y_cache = torch.cat([self.y_cache, y], dim=1)[:, 1:]  # Update cachereturn y

代码看到这里不太想继续看了,太多glunoTS库里面的函数了,我完全不熟悉这个库,看起来太痛苦了,还有很多的困惑,最大的困惑就是数据是怎么对齐的,怎么输入到Llama里面的,慢慢看吧

其他

来源
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/585806.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

爬虫工作量由小到大的思维转变---<第三十五章 Scrapy 的scrapyd+Gerapy 部署爬虫项目>

前言: 项目框架没有问题大家布好了的话,接着我们就开始部署scrapy项目(没搭好架子的话,看我上文爬虫工作量由小到大的思维转变---&#xff1c;第三十四章 Scrapy 的部署scrapydGerapy&#xff1e;-CSDN博客) 正文: 1.创建主机: 首先gerapy的架子,就相当于部署服务器上的;所以…

Ubuntu 18.04搭建RISCV和QEMU环境

前言 因为公司项目代码需要在RISCV环境下测试&#xff0c;因为没有硬件实体&#xff0c;所以在Ubuntu 18.04上搭建了riscv-gnu-toolchain QEMU模拟器环境。 安装riscv-gnu-toolchain riscv-gnu-toolchain可以从GitHub上下载源码编译&#xff0c;地址为&#xff1a;https://…

大华主动注册协议介绍

一、大华主动注册协议介绍 前面写了一篇文章&#xff0c;介绍一些设备通过大华主动注册协议接入到AS-V1000的文章&#xff0c;很多问我关于大华主动注册协议的相关知识。 由于大华主动注册协议是一种私有协议&#xff0c;通常不对外公开详细的协议规范和技术细节。因此…

C++ Primer Plus----第十二章--类和动态内存分布

本章内容包括&#xff1a;对类成员使用动态内存分配&#xff1b;隐式和显式复制构造函数&#xff1b;隐式和显式重载赋值运算符&#xff1b;在构造函数中使用new所必须完成的工作&#xff1b;使用静态类成员&#xff1b;将定位new运算符用于对象&#xff1b;使用指向对象的指针…

ssm基于web的志愿者管理系统的设计与实现+vue论文

摘 要 使用旧方法对志愿者管理系统的信息进行系统化管理已经不再让人们信赖了&#xff0c;把现在的网络信息技术运用在志愿者管理系统的管理上面可以解决许多信息管理上面的难题&#xff0c;比如处理数据时间很长&#xff0c;数据存在错误不能及时纠正等问题。这次开发的志愿者…

main参数传递、反汇编、汇编混合编程

week03 一、main参数传递二、反汇编三、汇编混合编程 一、main参数传递 参考 http://www.cnblogs.com/rocedu/p/6766748.html#SECCLA 在Linux下完成“求命令行传入整数参数的和” 注意C中main: int main(int argc, char *argv[]), 字符串“12” 转为12&#xff0c;可以调用atoi…

两种汇编的实验

week04 一、汇编-1二、汇编-2 一、汇编-1 1 通过输入gcc -S -o main.s main.c -m32 将下面c程序”week0401学号.c“编译成汇编代码 int g(int x){ return x3; } int f(int x){ int i 学号后两位&#xff1b; return g(x)i; } int main(void){ return f(8)1; } 2. 删除汇编代码…

『番外篇六』SwiftUI 取得任意视图全局位置的三种方法

概览 在 SwiftUI 开发中,利用描述性代码我们可以很轻松的构建各种丰富多彩的视图。我们可以设置它们的大小、位置、颜色并应用不计其数的修改器。 但是,小伙伴们是否想过在 SwiftUI 中如何获取一个视图的全局位置坐标呢? 在本篇博文中,您将学到如下内容: 概览1. SwiftU…

守护 C 盘,Python 相关库设置

文章目录 前言Python 相关查看所有 Python 安装位置查看 Python 依赖位置查看 conda 配置查看 env 列表移除指定 env创建 env进入 env删除环境位置目录添加环境位置 (将位置置顶)查看 pip 缓存位置设置 pip 缓存位置 其他进入 Temp修改位置 Python技术资源分享1、Python所有方向…

(001)Unit 编译 UTF8JSON

文章目录 编译 Dll编译报错附录 编译 Dll 新建工程&#xff1a; 注意 UnityEngineDll 的选择&#xff01;2022 版本的太高了&#xff01;&#xff01;&#xff01; 下载包&#xff0c;导入unity : 3. 将 unf8json 的源码拷贝到新建的工程。 4. 编译发布版本&#xff1a; 编译…

竞赛保研 基于卷积神经网络的乳腺癌分类 深度学习 医学图像

文章目录 1 前言2 前言3 数据集3.1 良性样本3.2 病变样本 4 开发环境5 代码实现5.1 实现流程5.2 部分代码实现5.2.1 导入库5.2.2 图像加载5.2.3 标记5.2.4 分组5.2.5 构建模型训练 6 分析指标6.1 精度&#xff0c;召回率和F1度量6.2 混淆矩阵 7 结果和结论8 最后 1 前言 &…

mongoose中http server服务器解决“Access-Control-Allow-Origin mongoose”跨域问题

问题 使用mongoose做http服务器&#xff0c;自己构造的浏览器端jquery在访问server时&#xff0c;会遇到&#xff1a; Access to XMLHttpRequest at http://127.0.0.1:8000/ from origin null has been blocked by CORS policy: No Access-Control-Allow-Origin header is pr…

python+django大自然环境保护宣传网站62r9b

本课题使用Python语言进行开发。基于web,代码层面的操作主要在PyCharm中进行&#xff0c;将系统所使用到的表以及数据存储到MySQL数据库中 本系统由后台管理子系统&#xff0c;登录子系统&#xff0c;按登陆角色及权限划分为管理员:个人中心&#xff0c;用户管理&#xff0c;文…

遇到DDOS怎么办,盾真的可以抗攻击吗

网络在以难以想象的速度发展&#xff0c;黑客们针对网络漏洞发起的攻击也从未停止&#xff0c;但复杂的网络环境让网络安全的维护更为艰难&#xff0c;如果游戏公司没有做好防御措施&#xff0c;黑客发起攻击只是时间问题。在网络攻击愈加多元化的今天&#xff0c;游戏行业可以…

懒加载的el-tree中没有了子节点之后还是有前面icon箭头的展示,如何取消没有子节点之后的箭头显示

没有特别多的数据 <template><el-tree:props"props":load"loadNode"lazyshow-checkbox></el-tree></template><script>export default {data() {return {props: {label: name,children: zones,isLeaf:"leaf",//关…

交互式笔记Jupyter Notebook本地部署并实现公网远程访问内网服务器

最近&#xff0c;我发现了一个超级强大的人工智能学习网站。它以通俗易懂的方式呈现复杂的概念&#xff0c;而且内容风趣幽默。我觉得它对大家可能会有所帮助&#xff0c;所以我在此分享。点击这里跳转到网站。 文章目录 1.前言2.Jupyter Notebook的安装2.1 Jupyter Notebook下…

故障诊断模型 | Maltab实现PSO-BP粒子群算法优化BP神经网络的故障诊断

文章目录 效果一览文章概述模型描述源码设计参考资料效果一览 文章概述 故障诊断模型 | Maltab实现PSO-BP粒子群算法优化BP神经网络的故障诊断 模型描述 在机器学习领域,我们常常需要通过训练数据来学习一个函数模型,以便在未知的数据上进行预测或分类。传统的神经网络模型需…

Java设计模式-外观模式

目录 一、影院管理项目 二、外观模式 &#xff08;一&#xff09;基本介绍 &#xff08;二&#xff09;原理类图 &#xff08;三&#xff09;解决影院管理 &#xff08;四&#xff09;注意事项和细节 &#xff08;五&#xff09;外观模式在MyBatis框架应用的源码分析 一…

Python+OpenGL绘制3D模型(七)制作3dsmax导出插件

系列文章 一、逆向工程 Sketchup 逆向工程&#xff08;一&#xff09;破解.skp文件数据结构 Sketchup 逆向工程&#xff08;二&#xff09;分析三维模型数据结构 Sketchup 逆向工程&#xff08;三&#xff09;软件逆向工程从何处入手 Sketchup 逆向工程&#xff08;四&#xf…

Linux操作系统( YUM软件仓库技术 )

镜像文件的回环挂载&#xff08;把iso镜像文件释放成系统安装光盘&#xff09;foundation0上操作 回环挂载的用法&#xff1a; du -sh 对象名 //估算文件&#xff08;一切对象皆文件&#xff09;大小 !$ //上一条命令的最后一个参数 新创建的挂载点目录是空白目录 挂载&#xf…