STGCN解读(论文+代码)

一、引言

        引言部分不是论文的重点,主要讲述了交通预测的重要性以及一些传统方法的不足之处。进而推出了自己的模型——STGCN。

二、交通预测与图卷积

        

        第二部分讲述了交通预测中路图和图卷积的概念。

        首先理解道路图,交通预测被定义为典型的时间序列预测问题,即根据历史数据预测未来的交通情况。在该工作中,交通网络被建模为图结构,节点表示监控站,边表示监控站之间的连接。

        其次是图卷积,由于传统的卷积操作无法应用于图数据,作者介绍了两种扩展卷积到图数据的方法,即扩展卷积的空间定义、谱图卷积。这里需要着重注意谱图卷积,因为STGCN就是采用了傅里叶变换的方法,这里的公式后面也要提到。

三、STGCN结构与原理(重点)

        先看整体网络结构:

1. 空域卷积块(图卷积块)

1.1 论文

        这里提到了两种方法,即切比雪夫多项式和一阶近似。对应于代码中的 ChebGraphConv 和 GraphConv。

        为什么要用切比雪夫多项式计算?因为传统的拉普拉斯计算起来非常复杂,使用了切比雪夫多项式以后,通过使用多项式来表示卷积核,图卷积操作可以在局部化的节点邻域内进行,从而避免了全图的计算。

        一阶近似进一步简化了图卷积的计算,将图拉普拉斯算子(公式中的L)的高阶近似简化为一阶。这使得图卷积操作只依赖于当前节点及其直接相邻节点,计算复杂度大幅降低。

1.2 代码(GraphConvLayer)

        这里有我加了注释的代码:OracleRay/STGCN_pytorch: The PyTorch implementation of STGCN. (github.com)

        时空卷积块和输出层的定义都在 layers.py 中。

        ChebGraphConv 类的前向传播方法:

def forward(self, x):# bs, c_in, ts, n_vertex = x.shapex = torch.permute(x, (0, 2, 3, 1))  # 将时序维度和顶点维度排列到一起,方便后续的图卷积计算if self.Ks - 1 < 0:raise ValueError(f'ERROR: the graph convolution kernel size Ks has to be a positive integer, but received {self.Ks}.')elif self.Ks - 1 == 0:x_0 = xx_list = [x_0]  # 只使用第 0 阶,即不考虑邻居节点的影响,直接使用输入特征elif self.Ks - 1 == 1:x_0 = x# hi 是邻接矩阵的索引,btij 是输入张量 x 的索引,bthj为更新后的特征表示x_1 = torch.einsum('hi,btij->bthj', self.gso, x)  # 邻接矩阵gso和输入特征x进行相乘x_list = [x_0, x_1]  # 使用第 0 阶和第 1 阶的节点信息elif self.Ks - 1 >= 2:x_0 = xx_1 = torch.einsum('hi,btij->bthj', self.gso, x)x_list = [x_0, x_1]for k in range(2, self.Ks):  # 根据切比雪夫多项式的定义来计算,利用前两阶多项式来构建第 k 阶多项式x_list.append(torch.einsum('hi,btij->bthj', 2 * self.gso, x_list[k - 1]) - x_list[k - 2])x = torch.stack(x_list, dim=2)  # 将所有阶的节点特征堆叠在一起,形成一个新的张量cheb_graph_conv = torch.einsum('btkhi,kij->bthj', x, self.weight)if self.bias is not None:cheb_graph_conv = torch.add(cheb_graph_conv, self.bias)  # 添加偏置项else:cheb_graph_conv = cheb_graph_conv  # 强迫症return cheb_graph_conv

        Ks表示空间卷积核大小,gso代表交通预测图的邻接矩阵。在前向传播的过程中,判断 Ks - 1 的值的大小。等于0和等于1分别代表着切比雪夫多项式的第 0 阶和第 1 阶,当 Ks - 1 大于等于 2 时,不仅需要考虑第 0 阶和第 1 阶的邻居节点,还会通过递归关系计算更高阶的邻居节点信息。

        构建 k 阶多项式时,需要邻接矩阵 gso 和 (k - 1) 阶相乘,然后与 (k - 2) 阶相减。这样就满足了切比雪夫多项式的要求,即:

         最后把得到的值乘以权重(weight),如果有偏置项,再加偏置值(bias)即可。

        GraphConv 类与此类似。


        然后在 GraphConvLayer 中就可选择到底是使用切比雪夫多项式(ChebGraphConv)还是一阶近似(GraphConv)。

def forward(self, x):x_gc_in = self.align(x)if self.graph_conv_type == 'cheb_graph_conv':x_gc = self.cheb_graph_conv(x_gc_in)elif self.graph_conv_type == 'graph_conv':x_gc = self.graph_conv(x_gc_in)x_gc = x_gc.permute(0, 3, 1, 2)x_gc_out = torch.add(x_gc, x_gc_in)  # 残差连接return x_gc_out

2. 时域卷积块

2.1 论文

       时域卷积块最关键的两个内容就是因果卷积和门控机制GLU,不过这两个并不是这篇论文里提出来的。

        因果卷积(代码中 CausalConv 类)是时域卷积网络模型(TCN)中的重要内容。在本篇论文中,使用了1-D因果卷积来确保当前时刻只依赖于过去的输入数据,这样未来的信息在当前时刻就不会被使用。

        GLU是2016年由Yann N. Dauphin在论文《Language Modeling with Gated Convolutional Networks》中提出的。时域卷积块采用门控线性单元(GLU)作为非线性激活函数可以控制哪些输入是重要的,而不是所有信息都平等对待,这样有助于在时间序列中提取关键特征。

        残差连接、瓶颈策略、并行训练等作者只是提了一嘴,不是重点。

2.2 代码(TemporalConvLayer)

        首先看因果卷积的代码:

class CausalConv2d(nn.Conv2d):def __init__(self, in_channels, out_channels, kernel_size, stride=1, enable_padding=False, dilation=1, groups=1,bias=True):kernel_size = nn.modules.utils._pair(kernel_size)  # 卷积核大小,表示对多少个像素(或特征)进行卷积。stride = nn.modules.utils._pair(stride)  # 步长,控制卷积核滑动的步幅dilation = nn.modules.utils._pair(dilation)  # dilation:膨胀系数,控制采样间隔,用来扩大卷积核的感受野if enable_padding == True:  # 启用零填充self.__padding = [int((kernel_size[i] - 1) * dilation[i]) for i in range(len(kernel_size))]else:self.__padding = 0self.left_padding = nn.modules.utils._pair(self.__padding)super(CausalConv2d, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,dilation=dilation, groups=groups, bias=bias)def forward(self, input):if self.__padding != 0:# F.pad() 函数用于在高度和宽度方向上添加填充input = F.pad(input, (self.left_padding[1], 0, self.left_padding[0], 0))result = super(CausalConv2d, self).forward(input)return result

        在初始化部分,如果需要零填充,则要先计算填充量。填充量的计算公式为:(kernel_size[i] - 1) * dilation[i] 。

        其中 kernel_size[i] - 1 表示在每个维度(高度和宽度)上,卷积核“超出”当前步长的部分。例如一个 3 × 3 大小的卷积核,就会有2个位置影响当前位置(左一个右一个)。而膨胀系数 dilation[i] 则意味着卷积核中元素之间有多少间隔。最后相乘即可求出需要填充几个位置。

        最后在前向传播过程中用 F.pad() 函数就可以实现左填充。


        然后再看时域卷积块(TemporalConvLayer)的前向传播代码:

def forward(self, x):x_in = self.align(x)[:, :, self.Kt - 1:, :]  # 对其输入通道数x_causal_conv = self.causal_conv(x)  # 进行因果卷积if self.act_func == 'glu' or self.act_func == 'gtu':x_p = x_causal_conv[:, : self.c_out, :, :]  # 分割出前半部分x_q = x_causal_conv[:, -self.c_out:, :, :]  # 分割出后半部分if self.act_func == 'glu':# 通过门控机制选择性保留某些时间步的特征,这对时间序列建模非常有效x = torch.mul((x_p + x_in), torch.sigmoid(x_q))  # 对 x_p 和输入的对齐结果 x_in 进行线性加和,并与 x_q 的 sigmoid 值进行点乘else:# tanh(x_p + x_in) ⊙ sigmoid(x_q)x = torch.mul(torch.tanh(x_p + x_in), torch.sigmoid(x_q))  # 使用 tanh 代替线性加和,具有非线性变换的特性elif self.act_func == 'relu':x = self.relu(x_causal_conv + x_in)elif self.act_func == 'silu':x = self.silu(x_causal_conv + x_in)else:raise NotImplementedError(f'ERROR: The activation function {self.act_func} is not implemented.')return x

        这里的代码只干了两件事:因果卷积和选择激活函数,与论文中的时域卷积块的思想大致相同。这里着重理解这行代码:

x = torch.mul((x_p + x_in), torch.sigmoid(x_q))

        (x_p + x_in) 将前半部分 x_p 与输入 x_in 的对齐结果进行线性加和,表示对主要特征的组合。sigmoid(x_q) 将后半部分 x_q 通过 sigmoid 函数转化为 0 到 1 之间的值,作为控制门。最后用⊙符号逐元素相乘,GLU就能决定 x_p 能通过多少信息。

        默认使用GLU激活函数,其他 if-else 语句中的激活函数不使用。

3. 时空卷积块

3.1 论文

        前面知道,两个时域卷积块 + 一个空域卷积块 = 一个时空卷积块。而且是两个时域卷积块夹着一个空域卷积块的三明治结构。这种设计可以同时处理交通网络中的 时间依赖空间依赖,即模型可以同时从时序信息和图结构中提取重要特征。

        中间的图卷积层负责从图结构(如道路网络)中提取空间特征。通过使用前面提到的图卷积方法(如切比雪夫多项式近似或一阶近似),可以高效地捕捉交通站点之间的连接关系​。

        上下两个时间卷积层负责提取时间依赖特征。通过因果卷积的方式,可以确保预测时只使用当前时刻及之前的交通数据,避免未来信息泄露​。

3.2 代码(STConvBlock)
class STConvBlock(nn.Module):def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, act_func, graph_conv_type, gso, bias, droprate):super(STConvBlock, self).__init__()# “三明治”结构:两个时域卷积块,一个空域卷积块self.tmp_conv1 = TemporalConvLayer(Kt, last_block_channel, channels[0], n_vertex, act_func)self.graph_conv = GraphConvLayer(graph_conv_type, channels[0], channels[1], Ks, gso, bias)self.tmp_conv2 = TemporalConvLayer(Kt, channels[1], channels[2], n_vertex, act_func)self.tc2_ln = nn.LayerNorm([n_vertex, channels[2]], eps=1e-12)  # 归一化:缓解梯度消失或梯度爆炸问题self.relu = nn.ReLU()self.dropout = nn.Dropout(p=droprate)  # 正则化:dropout率为0.5def forward(self, x):x = self.tmp_conv1(x)x = self.graph_conv(x)x = self.relu(x)x = self.tmp_conv2(x)x = self.tc2_ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)x = self.dropout(x)return x

        在初始化函数中,先定义“三明治”结构,使用之前写好的时域卷积块和图卷积块。然后确定归一化方法,激活函数,正则化方法。

4. 输出层

4.1 论文

        论文中的输出层是由一个时域卷积块和一个全连接层组成的。

        当最后一个时空卷积块处理完数据之后,输出的是一个三维张量(M × n × C),其中M是时间步数(例如,过去60分钟的交通数据),n是交通网络中的节点数(即监测站或路段数),C是特征通道数。而论文使用了一个时域卷积层将这些特征进一步压缩成一个单步的时间预测输出。这意味着,时间卷积会提取多时间步数据中的关键信息,并最终输出一个代表未来某一时刻(如未来15分钟)的预测结果。

        接下来,在时间卷积之后,论文使用了一个全连接层将卷积层输出的特征张量映射到一个单一的输出值,通常是每个节点的交通状态(如车速或流量),最后生成预测结果 v。

4.2 代码(OutputBlock)
class OutputBlock(nn.Module):def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex, act_func, bias, droprate):super(OutputBlock, self).__init__()self.tmp_conv1 = TemporalConvLayer(Ko, last_block_channel, channels[0], n_vertex, act_func)self.fc1 = nn.Linear(in_features=channels[0], out_features=channels[1], bias=bias)self.fc2 = nn.Linear(in_features=channels[1], out_features=end_channel, bias=bias)self.tc1_ln = nn.LayerNorm([n_vertex, channels[0]], eps=1e-12)  # 归一化self.relu = nn.ReLU()self.dropout = nn.Dropout(p=droprate)  # 正则化def forward(self, x):x = self.tmp_conv1(x)x = self.tc1_ln(x.permute(0, 2, 3, 1))x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x).permute(0, 3, 1, 2)  # 负责将时空特征映射为最终的预测值return x

         代码部分使用了一个时域卷积块和两个全连接层。这是为什么?这样的设计虽然与论文描述的输出层结构有所不同,但增加了额外的全连接层是为了增强模型的表达能力和预测精度。

        第一个全连接层用于对时域卷积输出的特征进行降维或变换。通过这个全连接层,模型可以将高维度的时空特征压缩或转化为新的特征表示,使得模型能够更好地抽象复杂的关系。第二个全连接层才用于最终的输出,即生成最后的预测结果。

5. 其他代码

5.1 models.py

        这个python文件里主要是对整个STGCN模型进行整合,一共有两个类,分别是 STGCNChebGraphConv 和 STGCNGraphConv 。这分别代表着使用切比雪夫多项式还是一阶近似。

        其中的大多数代码都是对 layers.py 中的函数方法调用,传参。有一行代码需要理解:

Ko = args.n_his - (len(blocks) - 3) * 2 * (args.Kt - 1)

        这句代码的作用是计算经过多个时空卷积块处理后,保留下来的时间维度的大小。

  • args.n_his:是输入数据的时间维度大小,通常指输入的历史时间步数。
  • len(blocks) - 3:blocks表示 STGCN 模型中不同层的配置的列表。它的长度再减3是为了去掉输出层相关的三层结构(TNFF,即两个全连接层和最后的时序处理层),仅关注时空卷积部分。
  • 2 * (args.Kt - 1):Kt 是每个时空卷积块中时间卷积核的大小。它的大小再减1是因为每次卷积操作后时间维度会 - 1(时间卷积是滑动窗口的形式)。而每个时空卷积块中有两个时间卷积层,所以总的时间维度减少量会再乘2。

        因此,对输入的时间维度 n_his,每经过一个时空卷积块,时间维度会减少 2 * (args.Kt - 1)。有 len(blocks) - 3 个这样的时空卷积块,因此总的时间维度减少的量是 (len(blocks) - 3) * 2 * (args.Kt - 1)。

5.2 main.py

        main 函数是对整个代码运行环境的配置,包括配置环境变量,设置命令行参数,数据类型转换,确定模型、优化器等等。这些代码在其他网络模型同样受用,大差不差。

四、实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/55466.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Axure重要元件一——动态面板

亲爱的小伙伴&#xff0c;在您浏览之前&#xff0c;烦请关注一下&#xff0c;在此深表感谢&#xff01; 本节课&#xff1a;动态面板 课程内容&#xff1a;认识动态面板、动态面板基本操作 应用场景&#xff1a;特定窗口、重要交互、长页面、容器等 一、认识动态面板 动态…

DeBiFormer:带有可变形代理双层路由注意力的视觉Transformer

https://arxiv.org/pdf/2410.08582v1 摘要 带有各种注意力模块的视觉Transformer在视觉任务上已表现出卓越的性能。虽然使用稀疏自适应注意力&#xff08;如在DAT中&#xff09;在图像分类任务中取得了显著成果&#xff0c;但在对语义分割任务进行微调时&#xff0c;由可变形…

软件测试面试题600多条及答案

这些问题都是软件测试领域常见的面试问题&#xff0c;以下是一些可能的答案&#xff1a; 什么是软件测试&#xff1f; 软件测试是一系列活动&#xff0c;旨在评估软件产品的质量和性能&#xff0c;以确保它符合规定的需求和标准。它包括执行程序或系统以验证其满足规定需求的过…

“探索Adobe Photoshop 2024:订阅方案、成本效益分析及在线替代品“

设计师们对Adobe Photoshop这款业界领先的图像编辑软件肯定不会陌生。如果你正考虑加入Photoshop的用户行列&#xff0c;可能会对其价格感到好奇。Photoshop的价值在于其强大的功能&#xff0c;而它的价格也反映了这一点。下面&#xff0c;我们就来详细了解一下Adobe Photoshop…

数据结构(8.2_1)——插入排序

插入排序 算法思想&#xff1a;每次将一个待排序的记录按其关键字大小插入到前面已排序好的子序列中&#xff0c;直到全部记录插入完成。 代码实现 #include <stdio.h>void InsertSort(int A[], int n) {int i, j.temp;for (i 1; i < n; i) {//将各元素插入已排好…

Axure重要元件二——内联框架

亲爱的小伙伴&#xff0c;在您浏览之前&#xff0c;烦请关注一下&#xff0c;在此深表感谢&#xff01; 课程主题&#xff1a;内联框架 课程内容&#xff1a;认识内联框架、基本嵌入 应用场景&#xff1a;表单、图片、文字嵌入式场景、交互应用 一、认识内联框架 内联框架的…

如何安全擦除 iPhone 上的所有数据,避免隐私泄露?

在当今的数字时代&#xff0c;隐私安全尤为重要。特别是在转让或出售 iPhone 之前&#xff0c;擦除设备上的所有内容是每位用户都应注意的关键步骤。尽管苹果自带了删除数据的功能&#xff0c;但有时这并不足以保证数据完全无法恢复。本文将结合 iPhone 自带的"抹掉所有内…

软考(中级-软件设计师)计算机系统篇(1018)

十、存储系统 10.1 层次结构主存–辅存&#xff1a;实现虚拟存储系统&#xff0c;解决了主存容量不够的问题。 Cache–主存&#xff1a;解决了主存与CPU速度不匹配的问题。 10.2 分类 1、按位置分类&#xff1a;可分为内存和外存。 内存&#xff08;主存&#xff09;&#…

【从零开发Mybatis】引入XNode和XPathParser

引言 在上文&#xff0c;我们发现直接使用 DOM库去解析XML 配置文件&#xff0c;非常复杂&#xff0c;也很不方便&#xff0c;需要编写大量的重复代码来处理 XML 文件的读取和解析&#xff0c;代码可读性以及可维护性相当差&#xff0c;使用起来非常不灵活。 因此&#xff0c…

o1快慢思考的风又吹到了Agent!

智能体&#xff08;Agent&#xff09;通过自然对话与用户互动有两个任务&#xff1a;交谈和规划/推理。对话回应必须基于所有可用信息&#xff0c;行动必须有助于实现目标。与用户交谈和进行多步推理和规划之间的二分法&#xff0c;类似卡尼曼引入的人类快速思考和慢速思考系统…

库卡ForceTorqueControl(二)

1. 基准坐标系RCS 基准坐标系 RCS 是力 / 力矩控制的参考系。基准坐标系的原点始终是当前的TCP。 1.1 BASE 的 RCS 姿态 基准坐标系的姿态与当前基础坐标系&#xff08;基座坐标系&#xff09;的姿态一致。它不取决于刀具的姿态。基准坐标系的原点是当前的 TCP。 示例&#xff…

【数据库设计】概念结构设计

引入——整体解释 上次我们讲完了关系模型&#xff0c;这次我们来讲新的章节&#xff1a;数据库设计 该怎样有效地管理和存储现实中的数据&#xff1f;答案是设计一个优秀的数据库。现实中的数据转化成关系表中的数据需要经过四个主要的设计步骤。 现实世界需求分析——>…

java常用工具包

Java标准库&#xff08;Java Standard Library&#xff09; 比喻&#xff1a;就像你厨房里的基础调料&#xff0c;没有它们&#xff0c;你很难做出美味的菜肴。Java标准库包含了进行基本编程所需的所有核心类和方法&#xff0c;如字符串处理、集合框架、输入输出操作等。 关键…

C++ 内存布局 - Part6: 虚继承

1. 关于虚继承 虚继承可以在菱形继承体系中&#xff0c;防止派生类中有多份重复祖基类内容。如下图所示&#xff0c;如果是常规继承&#xff0c;Class Final中会有两份Class Base的内容。通过虚继承&#xff0c;即Derived1 虚继承自Base, Derived2 也虚继承自Base, 那么Final中…

003_ipc概述及信号

【背景】 程序运行起来后&#xff0c;每个模块都有自己的进程&#xff0c;那么不同的模块如何进行通讯或者数据交换呢&#xff1f; 上面这张图说明了linux的ipc是继承最初的Unix 的IPC逻辑的&#xff0c;那么具体关系和概述讲解&#xff0c;请参考此链接的原文&#xff1a;htt…

mac 桌面版docker no space left on device

报错信息 docker pull镜像时报&#xff1a; failed to register layer: Error processing tar file(exit status 1): write /home/admin/oceanbase_bak/bin/observer: no space left on device 解决 增加 docker 虚拟磁盘大小。 调整完点击重启即可。

助力语音技术发展,景联文科技提供语音数据采集服务

语音数据采集是语音识别技术、语音合成技术以及其他语音相关应用的重要基础。采集高质量的语音数据有助于提高语音识别的准确性&#xff0c;同时也能够促进语音技术的发展。 景联文科技作为专业的数据采集标注公司&#xff0c;支持语音数据采集。可通过手机、专业麦克风阵列、专…

两个案例全面阐述全链路测试怎么做

首先我们先针对全链路功能测试部分进行一下讲解。去年的时候&#xff0c;有一家电商公司可能知道我一直在帮银行做相关的测试&#xff0c;就请我帮他们去做一些规划。这个平台有虚拟订单&#xff0c;也有实体订单&#xff0c;方式不太一样。 还涉及到分账分佣以及跟银行的对接…

大数据-174 Elasticsearch Query DSL - 全文检索 full-text query 匹配、短语、多字段 详细操作

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&#xff09;MapReduce&#xff08;已更完&am…

计算机网络基础(1)

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 计算机网络基础 收录于专栏【计算机网络】 本专栏旨在分享学习计算机网络的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 目录 1. 计算机网…