机器学习周记(第三十二周:文献阅读-时空双通路框架)2024.3.25~2024.3.31

目录

摘要

ABSTRACT

1 论文信息

1.1 论文标题

1.2 论文摘要 

1.3 论文模型

1.3.1 Spatial Encoder(空间编码器)

1.3.2 Temporal Encoder(时间编码器)

2 相关代码


摘要

  本周阅读了一篇运用GNN进行时间序列预测的论文。论文主要提出了一种分离空间和时间编码器的双通路框架,用于通过有效的时空表示学习准确预测水温,特别是极端的高水温。框架主要使用Transformer的自注意机制构造空间和时间编码器执行任务,同时采用了各种补丁嵌入方法和空间特征位置嵌入方法的组合。此外,本周还运用GAT模型进行了一个时间序列预测的实验。

ABSTRACT

This week, We read a paper on time series prediction using GNNs. The paper proposes a dual-pathway framework that separates spatial and temporal encoders for accurate prediction of water temperature, particularly extreme high water temperature, through effective spatiotemporal representation learning. The framework primarily utilizes Transformer's self-attention mechanism to construct spatial and temporal encoders for the task, along with various combinations of patch embedding methods and spatial feature positional embedding methods. Additionally, this week, an experiment on time series prediction was conducted using the GAT model.

1 论文信息

1.1 论文标题

Two-pathway spatiotemporal representation learning for extreme water temperature prediction

1.2 论文摘要 

  准确预测极端水温对于了解海洋环境的变化以及减少全球变暖导致的海洋灾害至关重要。在本研究中,提出了一个分离空间和时间编码器的双通路框架,用于通过有效的时空表示学习准确预测水温,特别是极端高水温。基于Transformer自注意机制构造空间和时间编码器网络执行任务,预测朝鲜半岛周围16个沿海位置未来连续七天的水温时间序列,同时采用了各种组合的补丁嵌入方法空间特征的位置嵌入。最后还进行了与传统深度卷积和递归网络的比较实验,通过比较和评估这些结果,所提出的双路径框架能够通过更好地捕获来自开放海洋和区域海域的时空相互关系和长期特征关系,改善对极端沿海水温的可预测性,并进一步确定基于自注意力的空间和时间编码器的最佳架构细节。此外,为了检查所提出的模型的可解释性及其与领域知识的一致性,进行了模型可视化并分析了空间和时间注意力图,展示了与未来预测更相关的时空输入序列的权重。

1.3 论文模型

  双通路框架主要包括空间编码器(Spatial Encoder)时间编码器(Temporal Encoder),如Fig.(a)所示。这两个编码器主要用来学习多尺度时空相互关系特征表示,将时空数据V(t,h,w)的特征表示通过SpatialEncoderTemporalEncoder分为两个空间组件,维度分别为hw,以及一个时间组件,维度为t。第一条路径中,SpatialEncoder捕获给定输入数据的一个连续片段的空间依赖关系。第二条路径中,TemporalEncoder捕获从SpatialEncoder中以时间顺序提供的连续序列空间特征向量之间的时间依赖关系,与时间特征融合。Fig.1(a)中的二维嵌入包含一个将输入数据映射到特征空间的操作,然后将其输出结果送到连续的空间编码器中。对于CNN,它通过3 \times 3卷积核卷积操作执行,而对于基于自注意力的网络,则是将每个数据分成补丁。嵌入特征被馈送到SpatialEncoder,并构建为Feature Vector以学习时空特征表示。Feature Vector的大小是固定的,并且为了比较各种实验模型组合的性能,确定了实现它所需的编码器数量。通过空间编码器压缩的Feature Vector以时间顺序被接收到TemporalEncoder中,以构建集成的时空特征向量。

Fig.1

  输入的网格化时空数据序列记作VV \in \mathbb{R}^{T \times H \times W}。其中,HWT分别代表网格化时空数据的高度、宽度和多个连续时间序列。V被映射为来自时间、高度和宽度维度的补丁嵌入的一系列标记\widetilde{Z}\widetilde{Z} \in \mathbb{R}^{n_{t} \times n_{h} \times n_{w} \times d}。如果使用位置嵌入,\widetilde{Z}将被重塑为\mathbb{R}^{N \times d}d是标记维度,N表示非重叠图像补丁。从每个数据序列中提取n_{h} \times n_{w}个非重叠图像补丁,然后将具有位置嵌入的总共n_{t} \times n_{h} \times n_{w}个标记传到Transformer的核心共同块SpatialEncoder中的多头自注意力(MSA)。从SpatialEncoder中获得的输出F是从输入的网格化时空数据序列V中获得的新表示,具有4 \times 4 \times 512的隐藏特征向量。SpatialEncoder的输出F被馈送到TemporalEncoder的输入中,以学习空间特征的时间依赖性。基于TemporalEncoderTransformer架构,通过缩放点积注意力计算输入F的单个自注意力,如Eq.(1)所述。缩放点积注意力的输入包括维度为d_{k}queries(Qkeys(K,以及维度为d_{v}的值(V)。计算queries与所有keys的点积,每个除以\sqrt{d_{k}},并应用softmax函数来获得值的权重。与具有Fkeysqueries和值的单个注意力函数相比,通过MSA,将querieskeys和值线性投影h次,分别到d_{k}d_{k}d_{v}维度,然后在这些投影版本的querieskeys和值上并行执行注意力函数,产生d_{v}维输出值,更利于结果。这些输出被串联并再次投影,得到最终值。多头注意力使模型能够同时关注不同位置的不同表示子空间的信息。使用单个注意力头,平均会抑制这一特性。通过MSA模块,表示具有时间特征的512维向量,这些向量被转换为连续时空序列的16个目标。

Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V                                                                      (1)

MultiHead(Q,K,V)=Concat(head_{1},...,head_{h})W^{O}                                             (2)

其中\sqrt{d_{k}}是key向量和query向量的维度,head_{i}=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

  SpatialEncoder应用了ViT(Fig.2(a)和(b))ViT变体Swin Transformer(SwinT,Fig.2(c)和(d))全局上下文ViT(GCViT,Fig.2(c)和(e))多路径ViT(MPViT)(Fig.2(f)和(g)),还包括基本的2D CNN(Fig.1(b))。此外,还将LSTM(Fig.1(c))基于多头自注意力的 Transformer(MAT,Fig.1(d))作为TemporalEncoder应用,通过不同方式的SpatialEncoderTemporalEncoder的组合来检验多尺度自注意力时空特征表示的性能,如Table.1所述。同时还研究了10种单独应用SpatialEncoderTemporalEncoder的实验组合。如Fig.2所示,ViTViT 变体SwinTGCViTMPViT具有共同的基于多头自注意力的Transformer架构。然而,位置和补丁嵌入方法存在差异,如Fig2(a)(c)(f)所示。此外,可以根据投影注意力的方法以及是否将其视为分层结构进行分类,如Table2所示。

Fig.2
Table.1
Table.2

1.3.1 Spatial Encoder(空间编码器)

ViT:一个基于自注意力的空间编码器,通常由两部分组成:补丁构建和Transformer块。如Fig.2(a)所示,对于来自输入的补丁嵌入,2D网格数据的序列被划分为固定大小的补丁,然后进行线性嵌入,并添加一维(1D)位置嵌入以将它们重塑为扁平化的2D补丁序列1D位置嵌入也按照网格顺序添加到补丁嵌入中。结果的补丁嵌入向量序列被输入到SpatialEncoder中,它具有与标准MAT Transformer相同的结构(Fig.2(b))。它由LayerNorm(LN)MSAMLP块组成。在每个MLP块之前和之后分别应用了LN残差连接MLP包含用于非线性的高斯误差线性单元(GeLU)MSA是自注意力的扩展,可以并行执行i个自注意力操作并投影它们的串联输出。

SwinT:为了学习更高分辨率的空间特征,SwinT通过逐渐合并深层中相邻的补丁来构建分层特征图,从较小的补丁开始。分层表示是在一种偏移的窗口方案中计算的,它通过将自注意力计算限制在非重叠的局部窗口中来提高效率,同时允许跨窗口连接。对于补丁嵌入(Fig.2(c)),它首先将输入的网格数据划分为非重叠的补丁。然后在这些补丁标记上应用具有修改后自注意力(即SwinT块)Transformer块,配置如Fig.2(d)所示。SwinT块通过依次连接用于窗口内自注意力操作的Window Multi-head Self-Attention (W-MSA)和用于W-MSA的窗口之间补丁的自注意力操作的Shifted Window Multi-head Self-Attention (SW-MSA),然后在两个MLP层之间插入GeLU来配置。自注意力是在本地窗口内计算的,这些窗口被排列成以非重叠方式均匀划分图像。在每个MSA模块和每个MLP块之前都应用了LN层和残差连接,并在每个模块之后应用了残差连接Transformer块保持了补丁的数量(H/2 \times W/2),并且与线性嵌入一起被划分为第1个Transformer层。为了创建分层表示,随着网络的加深,补丁合并层会减少补丁(标记)的数量。Transformer层共同创建具有相同分辨率的特征图的分层表示。

1.3.2 Temporal Encoder(时间编码器)

  TemporalEncoder使用了两种模型:一种是LSTM,它是一种RNN的变种,另一种是基于自注意力的TE,用于通过SpatialEncoder(Fig.2(c)和(d))捕捉压缩的空间特征向量连续序列之间的时间依赖关系。LSTM接收压缩了空间信息的特征向量作为输入,并将其馈送到LSTM单元中以编码时间信息。TemporalEncoder接收压缩了空间信息的特征向量作为输入,并使用自注意力编码时间信息。与此同时,MLPLinear-GeLU-Linear组成。一系列由先前的SpatialEncoder层压缩的空间特征向量被顺序地馈送到TemporalEncoder。然后通过双通道方法将空间和时间特征分别融合为输出特征图,用于空间-时间特征表示学习。

2 相关代码

实验:PyG搭建图神经网络实现多变量输入多变量输出时间序列预测

参考代码:PyG搭建图神经网络实现多变量输入多变量输出时间序列预测_利用pyg库实现时间序列预测-CSDN博客

step1:首先需要构造图结构,在将多元时间序列数据转化成一个图结构数据之前,需要确定各个节点的空间关系(Node Embedding,一个变量为一个节点)。一个很自然的想法就是计算不同的变量序列间的相关系数,然后使用一个阈值进行判断,如果两个节点(变量)它们的序列间的相关系数大于这个阈值,那么两个变量节点间就存在边。

# num_nodes:节点(变量)数量; data:节点特征数据,data的维度为(x, num_nodes)
def create_graph(num_nodes, data):# 将data(feature, num_nodes)序列转换成特征矩阵features(num_nodes, feature)features = torch.transpose(torch.tensor(data), 0, 1)# 创建一个空的邻接矩阵,用于存储图中的边edge_index = [[], []]# 遍历所有节点for i in range(num_nodes):# 遍历当前节点之后的所有节点,以避免重复添加边for j in range(i + 1, num_nodes):# 从输入的data数据中获取两个节点的特征向量x, y = data[:, i], data[:, j]# 计算两个节点之间的相关性corr = calc_corr(x, y)# 如果两个节点之间的相关性大于等于0.4,则将它们之间添加一条边if corr >= 0.4:edge_index[0].append(i)edge_index[1].append(j)# 将邻接矩阵转换为PyTorch的长整型张量edge_index = torch.LongTensor(edge_index)# 创建图对象graph = Data(x=features, edge_index=edge_index)# 将有向图转换为无向图,以确保每一条有向边都有一个相对应的反向边graph.edge_index = to_undirected(graph.edge_index, num_nodes=num_nodes)return graph

本次实验的数据集采用伊斯坦布尔股票交易数据集:ISTANBUL STOCK EXCHANGE - UCI Machine Learning Repository

最终构建出的图结构为:

step2:在接下来的训练、验证以及测试过程中保持图的整体结构不变。也就是使用静态图,即图中的关系是通过训练集中的数据确定的。如果想要实现动态图,一个很自然的想法是在构造数据集时,每次都利用一个大小为(num_nodesseq_len)的矩阵计算出图中的各个参数。这样操作后每一个样本都对应一个图,图中的节点数为num_nodes,节点的初始特征都为长度为seq_len的向量,图中的边通过num_nodes个长度为seq_len的向量间的相关系数来确定。

构造数据集:

def nn_seq(num_nodes, seq_len, B, pred_step_size, data):# 将数据集划分为训练集(60%)、验证集(20%)和测试集(20%)train = data[:int(len(data) * 0.6)]# print(train)val = data[int(len(data) * 0.6):int(len(data) * 0.8)]# print(val)test = data[int(len(data) * 0.8):len(data)]# print(test)# 归一化scaler = MinMaxScaler()train_normalized = scaler.fit_transform(data[:int(len(data) * 0.8)].values)# print(train_normalized)val_normalized = scaler.transform(val.values)# print(val_normalized)test_normalized = scaler.transform(test.values)# print(test_normalized)# 创建训练集(包含测试集)图graph = create_graph(num_nodes, data[:int(len(data) * 0.8)].values)# 数据集处理(生成样本和标签)# step_size:每一步的步长;shuffle:是否打乱数据def process(dataset, batch_size, step_size, shuffle):# 将数据集由DataFrame转化为列表dataset = dataset.tolist()# print(len(dataset), len(dataset[0]))# 创建样本序列seq = []# 遍历训练数据集,直到最后一个滑动窗口和预测步长前for i in tqdm(range(0, len(dataset) - seq_len - pred_step_size, step_size)):# 创建训练序列train_seq = []# 遍历每一个滑动窗口for j in range(i, i + seq_len):# 获取一个滑动窗口的样本x = []for c in range(len(dataset[0])):x.append(dataset[j][c])train_seq.append(x)# print(x)# 获取一个滑动窗口的标签train_labels = []for j in range(len(dataset[0])):train_label = []for k in range(i + seq_len, i + seq_len + pred_step_size):train_label.append(dataset[k][j])train_labels.append(train_label)# 得到每一个滑动窗口的训练样本与对应的标签,转化为tensortrain_seq = torch.FloatTensor(train_seq)train_labels = torch.FloatTensor(train_labels)# print(train_seq.shape, train_labels.shape)seq.append((train_seq, train_labels))seq = MyDataset(seq)seq = DataLoader(dataset=seq, batch_size=batch_size, shuffle=shuffle, num_workers=0, drop_last=False)return seq# 得到每个数据集的DataLoaderDtr = process(train_normalized, B, step_size=1, shuffle=True)Val = process(val_normalized, B, step_size=1, shuffle=True)Dte = process(test_normalized, B, step_size=pred_step_size, shuffle=False)return graph, Dtr, Val, Dte, scaler

运行后得到训练集的图结构,并将其数据与邻接矩阵加入后续的模型训练:

函数最后返回训练集,验证集和测试集的DataLoader(Dtr,Val,Dte)以及得到的归一化参数(scaler),加入后续模型的训练(注意:本次实验进行的是多步预测,预测下一个月的股票交易数据):

step3:定义预测模型。这里使用GAT(图注意力网络),也可以换成GCNGraphSAGE等其他的模型。

class GAT(nn.Module):def __init__(self, in_features, h_features, out_features):super(GAT, self).__init__()self.conv1 = GATConv(in_features, h_features, heads=4, concat=False)self.conv2 = GATConv(h_features, out_features, heads=4, concat=False)def forward(self, x, edge_index):x = F.elu(self.conv1(x, edge_index))x = self.conv2(x, edge_index)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/791176.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据处理包Pandas】分组及相关操作

目录 一、初步认识分组并查看分组信息(一)通过聚合函数查看分组信息(二)转换成列表查看所有组的信息(三)通过循环查看各组的名称和组中的数据信息(四)通过get_group()方法直接获得一…

【蓝桥杯练习】tarjan算法求解LCA

还是一道比较明显的求LCA(最近公共祖先)模型的题目,我们可以使用多种方法来解决该问题,这里我们使用更好写的离线的tarjan算法来解决该问题。 除去tarjan算法必用的基础数组,我们还有一个数组d[],d[i]记录的是每个点的出度,也就是它的延迟时间…

高级IO/多路转接-select/poll(1)

概念背景 IO的本质就是输入输出 刚开始学网络的时候,我们简单的写过一些网络服务,其中用到了read,write这样的接口,当时我们用的就是基础IO,高级IO主要就是效率问题。 我们在应用层调用read&&write的时候&…

YOLOv2

YOLOv2 论文介绍论文改进1. Batch Normalization2. High Resolution Classifier3. Convolutional With Anchor Boxes4. vgg16换成darknet-195. Dimension Clusters(w h的聚类)6 预测坐标7. passthrough8. 多尺度输入训练 损失函数 论文介绍 论文名字&am…

Educational Codeforces Round 133 (Rated for Div. 2) C. Robot in a Hallway

题目 思路&#xff1a; #include <bits/stdc.h> using namespace std; #define int long long #define pb push_back #define fi first #define se second #define lson p << 1 #define rson p << 1 | 1 const int maxn 1e6 5, inf 1e18, maxm 4e4 5; c…

多模态系列-综述Video Understanding with Large Language Models: A Survey

本文是LLM系列文章,针对《Video Understanding with Large Language Models: A Survey》的翻译。 论文链接:https://arxiv.org/pdf/2312.17432v2.pdf 代码链接:https://github.com/yunlong10/Awesome-LLMs-for-Video-Understanding 大型语言模型下的视频理解研究综述 摘要…

人工智能大模型+智能算力,企商在线以新质生产力赋能数字化转型

2024 年3月28 日&#xff0c;由中国互联网协会主办、中国信通院泰尔终端实验室特别支持的 2024 高质量数字化转型创新发展大会暨铸基计划年度会议在京召开。作为新质生产力代表性企业、数算融合领导企业&#xff0c;企商在线受邀出席大会主论坛圆桌对话&#xff0c;与行业专家共…

Lora人机界面开发 3

1 显示原理 液晶的形成&#xff1a;像水一样液晶介于固态和液态之间 偏光原理&#xff1a;两块偏光的栅栏角度相互垂直时光线就完全无法通过 内部结构&#xff1a;利用电场控制液晶分支的旋转 颜色深度 TFT开关的工作原理&#xff1a; 扫描线连接同一列所有TFT栅极电极&…

通过mapreduce程序统计旅游订单(wordcount升级版)

通过mapreduce程序统计旅游订单&#xff08;wordcount升级版&#xff09; 本文将结合一个实际的MapReduce程序案例&#xff0c;探讨如何通过分析旅游产品的预订数据来揭示消费者的偏好。 程序概览 首先&#xff0c;让我们来看一下这个MapReduce程序的核心代码。这个程序的目…

创新视角:探索系统产品可用性测试的前沿分类方法与实践应用

一、可用性测试概念 1、什么是可用性&#xff1f; 任何与人可以发生交互的产品都应该是可用的&#xff0c;就一般产品而言&#xff0c;可用性被定义为目标用户可以轻松使用产品来实现特定目标。 ISO9241/11中的定义是&#xff1a; 一个产品可以被特定的用户在特定的场景中&a…

跨越时空,启迪智慧:奇趣相机重塑儿童摄影与教育体验

【科技观察】近期&#xff0c;奇趣未来公司以其创新之作——“奇趣相机”微信小程序&#xff0c;强势进军儿童AI摄影市场。这款专为亚洲儿童量身定制的应用&#xff0c;凭借精准贴合亚洲儿童面部特征的AIGC大模型&#xff0c;以及丰富的摄影模板与场景设定&#xff0c;正在重新…

Ps:匹配颜色

匹配颜色 Match Color命令可以将一个图像的颜色与另一个图像的颜色相匹配。 Ps菜单&#xff1a;图像/调整/匹配颜色 Adjustments/Match Color 匹配颜色命令可匹配多个图像之间、多个图层之间或者多个选区之间的颜色&#xff0c;还可以通过更改亮度和色彩范围以及中和色痕来调整…

Day17-【Java SE进阶】特殊文本文件、日志技术

一、特殊文本文件 为什么要用这些特殊文件&#xff1f; 存储多个用户的&#xff1a;用户名、密码 存储有关系的数据&#xff0c;做为系统的配置文件做为信息进行传输 日志技术 把程序运行的信息&#xff0c;记录到文件中&#xff0c;方便程序员定位bug、并了解程序的执行情…

Java并发编程基础面试题详细总结

1. 什么是线程和进程? 1.1 何为进程? 进程是程序的一次执行过程&#xff0c;是系统运行程序的基本单位&#xff0c;因此进程是动态的。系统运行一个程序即是一个进程从创建&#xff0c;运行到消亡的过程。 在 Java 中&#xff0c;当我们启动 main 函数时其实就是启动了一个…

Windows进程监视器Process Monitor

文章目录 Process Monitor操作逻辑 Process Monitor Process Monitor是 Windows 的高级监视工具&#xff0c;是Filemon Regmon的整合增强版本&#xff0c;实时显示文件系统&#xff0c;注册表&#xff0c;网络活动&#xff0c;进程或线程活动&#xff0c;资料收集事件&#x…

阿里云弹性计算通用算力型u1实例性能评测,性价比高

阿里云服务器u1是通用算力型云服务器&#xff0c;CPU采用2.5 GHz主频的Intel(R) Xeon(R) Platinum处理器&#xff0c;ECS通用算力型u1云服务器不适用于游戏和高频交易等需要极致性能的应用场景及对业务性能一致性有强诉求的应用场景(比如业务HA场景主备机需要性能一致)&#xf…

记录一次threejs内存泄露问题排查过程

问题描述&#xff1a; 一个有关地图编辑的使用threejs的这样的组件&#xff0c;在多次挂载销毁后&#xff0c;页面开始卡顿。 问题排查&#xff1a; 1. 首先在chrome dev tool中打开performance monitor面板&#xff0c;观察 JS head size、DOME Nodes、Js event listeners数…

【C++】C++11类的新功能

&#x1f440;樊梓慕&#xff1a;个人主页 &#x1f3a5;个人专栏&#xff1a;《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》《算法》 &#x1f31d;每一个不曾起舞的日子&#xff0c;都是对生命的辜负 目录 前言 默认成员函数 类成…

Java基于微信小程序高校体育场管理小程序

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…

跨域问题解决方案之CORS

跨域问题解决方案之CORS 文章目录 跨域问题解决方案之CORS概述浏览器的同源策略同源的判定规则目的同源策略的限制范围 浏览器的同源策略为什么会引发跨域问题&#xff1f;CORS规则CORS解决方案CORS方案将请求分为两类举例简单请求预检请求总结学以致用 概述 浏览器安全的基石…