一、引言
引言部分不是论文的重点,主要讲述了交通预测的重要性以及一些传统方法的不足之处。进而推出了自己的模型——STGCN。
二、交通预测与图卷积
第二部分讲述了交通预测中路图和图卷积的概念。
首先理解道路图,交通预测被定义为典型的时间序列预测问题,即根据历史数据预测未来的交通情况。在该工作中,交通网络被建模为图结构,节点表示监控站,边表示监控站之间的连接。
其次是图卷积,由于传统的卷积操作无法应用于图数据,作者介绍了两种扩展卷积到图数据的方法,即扩展卷积的空间定义、谱图卷积。这里需要着重注意谱图卷积,因为STGCN就是采用了傅里叶变换的方法,这里的公式后面也要提到。
三、STGCN结构与原理(重点)
先看整体网络结构:
1. 空域卷积块(图卷积块)
1.1 论文
这里提到了两种方法,即切比雪夫多项式和一阶近似。对应于代码中的 ChebGraphConv 和 GraphConv。
为什么要用切比雪夫多项式计算?因为传统的拉普拉斯计算起来非常复杂,使用了切比雪夫多项式以后,通过使用多项式来表示卷积核,图卷积操作可以在局部化的节点邻域内进行,从而避免了全图的计算。
一阶近似进一步简化了图卷积的计算,将图拉普拉斯算子(公式中的L)的高阶近似简化为一阶。这使得图卷积操作只依赖于当前节点及其直接相邻节点,计算复杂度大幅降低。
1.2 代码(GraphConvLayer)
这里有我加了注释的代码:OracleRay/STGCN_pytorch: The PyTorch implementation of STGCN. (github.com)
时空卷积块和输出层的定义都在 layers.py 中。
ChebGraphConv 类的前向传播方法:
def forward(self, x):# bs, c_in, ts, n_vertex = x.shapex = torch.permute(x, (0, 2, 3, 1)) # 将时序维度和顶点维度排列到一起,方便后续的图卷积计算if self.Ks - 1 < 0:raise ValueError(f'ERROR: the graph convolution kernel size Ks has to be a positive integer, but received {self.Ks}.')elif self.Ks - 1 == 0:x_0 = xx_list = [x_0] # 只使用第 0 阶,即不考虑邻居节点的影响,直接使用输入特征elif self.Ks - 1 == 1:x_0 = x# hi 是邻接矩阵的索引,btij 是输入张量 x 的索引,bthj为更新后的特征表示x_1 = torch.einsum('hi,btij->bthj', self.gso, x) # 邻接矩阵gso和输入特征x进行相乘x_list = [x_0, x_1] # 使用第 0 阶和第 1 阶的节点信息elif self.Ks - 1 >= 2:x_0 = xx_1 = torch.einsum('hi,btij->bthj', self.gso, x)x_list = [x_0, x_1]for k in range(2, self.Ks): # 根据切比雪夫多项式的定义来计算,利用前两阶多项式来构建第 k 阶多项式x_list.append(torch.einsum('hi,btij->bthj', 2 * self.gso, x_list[k - 1]) - x_list[k - 2])x = torch.stack(x_list, dim=2) # 将所有阶的节点特征堆叠在一起,形成一个新的张量cheb_graph_conv = torch.einsum('btkhi,kij->bthj', x, self.weight)if self.bias is not None:cheb_graph_conv = torch.add(cheb_graph_conv, self.bias) # 添加偏置项else:cheb_graph_conv = cheb_graph_conv # 强迫症return cheb_graph_conv
Ks表示空间卷积核大小,gso代表交通预测图的邻接矩阵。在前向传播的过程中,判断 Ks - 1 的值的大小。等于0和等于1分别代表着切比雪夫多项式的第 0 阶和第 1 阶,当 Ks - 1 大于等于 2 时,不仅需要考虑第 0 阶和第 1 阶的邻居节点,还会通过递归关系计算更高阶的邻居节点信息。
构建 k 阶多项式时,需要邻接矩阵 gso 和 (k - 1) 阶相乘,然后与 (k - 2) 阶相减。这样就满足了切比雪夫多项式的要求,即:
最后把得到的值乘以权重(weight),如果有偏置项,再加偏置值(bias)即可。
GraphConv 类与此类似。
然后在 GraphConvLayer 中就可选择到底是使用切比雪夫多项式(ChebGraphConv)还是一阶近似(GraphConv)。
def forward(self, x):x_gc_in = self.align(x)if self.graph_conv_type == 'cheb_graph_conv':x_gc = self.cheb_graph_conv(x_gc_in)elif self.graph_conv_type == 'graph_conv':x_gc = self.graph_conv(x_gc_in)x_gc = x_gc.permute(0, 3, 1, 2)x_gc_out = torch.add(x_gc, x_gc_in) # 残差连接return x_gc_out
2. 时域卷积块
2.1 论文
时域卷积块最关键的两个内容就是因果卷积和门控机制GLU,不过这两个并不是这篇论文里提出来的。
因果卷积(代码中 CausalConv 类)是时域卷积网络模型(TCN)中的重要内容。在本篇论文中,使用了1-D因果卷积来确保当前时刻只依赖于过去的输入数据,这样未来的信息在当前时刻就不会被使用。
GLU是2016年由Yann N. Dauphin在论文《Language Modeling with Gated Convolutional Networks》中提出的。时域卷积块采用门控线性单元(GLU)作为非线性激活函数可以控制哪些输入是重要的,而不是所有信息都平等对待,这样有助于在时间序列中提取关键特征。
残差连接、瓶颈策略、并行训练等作者只是提了一嘴,不是重点。
2.2 代码(TemporalConvLayer)
首先看因果卷积的代码:
class CausalConv2d(nn.Conv2d):def __init__(self, in_channels, out_channels, kernel_size, stride=1, enable_padding=False, dilation=1, groups=1,bias=True):kernel_size = nn.modules.utils._pair(kernel_size) # 卷积核大小,表示对多少个像素(或特征)进行卷积。stride = nn.modules.utils._pair(stride) # 步长,控制卷积核滑动的步幅dilation = nn.modules.utils._pair(dilation) # dilation:膨胀系数,控制采样间隔,用来扩大卷积核的感受野if enable_padding == True: # 启用零填充self.__padding = [int((kernel_size[i] - 1) * dilation[i]) for i in range(len(kernel_size))]else:self.__padding = 0self.left_padding = nn.modules.utils._pair(self.__padding)super(CausalConv2d, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,dilation=dilation, groups=groups, bias=bias)def forward(self, input):if self.__padding != 0:# F.pad() 函数用于在高度和宽度方向上添加填充input = F.pad(input, (self.left_padding[1], 0, self.left_padding[0], 0))result = super(CausalConv2d, self).forward(input)return result
在初始化部分,如果需要零填充,则要先计算填充量。填充量的计算公式为:(kernel_size[i] - 1) * dilation[i] 。
其中 kernel_size[i] - 1 表示在每个维度(高度和宽度)上,卷积核“超出”当前步长的部分。例如一个 3 × 3 大小的卷积核,就会有2个位置影响当前位置(左一个右一个)。而膨胀系数 dilation[i] 则意味着卷积核中元素之间有多少间隔。最后相乘即可求出需要填充几个位置。
最后在前向传播过程中用 F.pad() 函数就可以实现左填充。
然后再看时域卷积块(TemporalConvLayer)的前向传播代码:
def forward(self, x):x_in = self.align(x)[:, :, self.Kt - 1:, :] # 对其输入通道数x_causal_conv = self.causal_conv(x) # 进行因果卷积if self.act_func == 'glu' or self.act_func == 'gtu':x_p = x_causal_conv[:, : self.c_out, :, :] # 分割出前半部分x_q = x_causal_conv[:, -self.c_out:, :, :] # 分割出后半部分if self.act_func == 'glu':# 通过门控机制选择性保留某些时间步的特征,这对时间序列建模非常有效x = torch.mul((x_p + x_in), torch.sigmoid(x_q)) # 对 x_p 和输入的对齐结果 x_in 进行线性加和,并与 x_q 的 sigmoid 值进行点乘else:# tanh(x_p + x_in) ⊙ sigmoid(x_q)x = torch.mul(torch.tanh(x_p + x_in), torch.sigmoid(x_q)) # 使用 tanh 代替线性加和,具有非线性变换的特性elif self.act_func == 'relu':x = self.relu(x_causal_conv + x_in)elif self.act_func == 'silu':x = self.silu(x_causal_conv + x_in)else:raise NotImplementedError(f'ERROR: The activation function {self.act_func} is not implemented.')return x
这里的代码只干了两件事:因果卷积和选择激活函数,与论文中的时域卷积块的思想大致相同。这里着重理解这行代码:
x = torch.mul((x_p + x_in), torch.sigmoid(x_q))
(x_p + x_in) 将前半部分 x_p 与输入 x_in 的对齐结果进行线性加和,表示对主要特征的组合。sigmoid(x_q) 将后半部分 x_q 通过 sigmoid 函数转化为 0 到 1 之间的值,作为控制门。最后用⊙符号逐元素相乘,GLU就能决定 x_p 能通过多少信息。
默认使用GLU激活函数,其他 if-else 语句中的激活函数不使用。
3. 时空卷积块
3.1 论文
前面知道,两个时域卷积块 + 一个空域卷积块 = 一个时空卷积块。而且是两个时域卷积块夹着一个空域卷积块的三明治结构。这种设计可以同时处理交通网络中的 时间依赖 和 空间依赖,即模型可以同时从时序信息和图结构中提取重要特征。
中间的图卷积层负责从图结构(如道路网络)中提取空间特征。通过使用前面提到的图卷积方法(如切比雪夫多项式近似或一阶近似),可以高效地捕捉交通站点之间的连接关系。
上下两个时间卷积层负责提取时间依赖特征。通过因果卷积的方式,可以确保预测时只使用当前时刻及之前的交通数据,避免未来信息泄露。
3.2 代码(STConvBlock)
class STConvBlock(nn.Module):def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, act_func, graph_conv_type, gso, bias, droprate):super(STConvBlock, self).__init__()# “三明治”结构:两个时域卷积块,一个空域卷积块self.tmp_conv1 = TemporalConvLayer(Kt, last_block_channel, channels[0], n_vertex, act_func)self.graph_conv = GraphConvLayer(graph_conv_type, channels[0], channels[1], Ks, gso, bias)self.tmp_conv2 = TemporalConvLayer(Kt, channels[1], channels[2], n_vertex, act_func)self.tc2_ln = nn.LayerNorm([n_vertex, channels[2]], eps=1e-12) # 归一化:缓解梯度消失或梯度爆炸问题self.relu = nn.ReLU()self.dropout = nn.Dropout(p=droprate) # 正则化:dropout率为0.5def forward(self, x):x = self.tmp_conv1(x)x = self.graph_conv(x)x = self.relu(x)x = self.tmp_conv2(x)x = self.tc2_ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)x = self.dropout(x)return x
在初始化函数中,先定义“三明治”结构,使用之前写好的时域卷积块和图卷积块。然后确定归一化方法,激活函数,正则化方法。
4. 输出层
4.1 论文
论文中的输出层是由一个时域卷积块和一个全连接层组成的。
当最后一个时空卷积块处理完数据之后,输出的是一个三维张量(M × n × C),其中M是时间步数(例如,过去60分钟的交通数据),n是交通网络中的节点数(即监测站或路段数),C是特征通道数。而论文使用了一个时域卷积层将这些特征进一步压缩成一个单步的时间预测输出。这意味着,时间卷积会提取多时间步数据中的关键信息,并最终输出一个代表未来某一时刻(如未来15分钟)的预测结果。
接下来,在时间卷积之后,论文使用了一个全连接层将卷积层输出的特征张量映射到一个单一的输出值,通常是每个节点的交通状态(如车速或流量),最后生成预测结果 v。
4.2 代码(OutputBlock)
class OutputBlock(nn.Module):def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex, act_func, bias, droprate):super(OutputBlock, self).__init__()self.tmp_conv1 = TemporalConvLayer(Ko, last_block_channel, channels[0], n_vertex, act_func)self.fc1 = nn.Linear(in_features=channels[0], out_features=channels[1], bias=bias)self.fc2 = nn.Linear(in_features=channels[1], out_features=end_channel, bias=bias)self.tc1_ln = nn.LayerNorm([n_vertex, channels[0]], eps=1e-12) # 归一化self.relu = nn.ReLU()self.dropout = nn.Dropout(p=droprate) # 正则化def forward(self, x):x = self.tmp_conv1(x)x = self.tc1_ln(x.permute(0, 2, 3, 1))x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x).permute(0, 3, 1, 2) # 负责将时空特征映射为最终的预测值return x
代码部分使用了一个时域卷积块和两个全连接层。这是为什么?这样的设计虽然与论文描述的输出层结构有所不同,但增加了额外的全连接层是为了增强模型的表达能力和预测精度。
第一个全连接层用于对时域卷积输出的特征进行降维或变换。通过这个全连接层,模型可以将高维度的时空特征压缩或转化为新的特征表示,使得模型能够更好地抽象复杂的关系。第二个全连接层才用于最终的输出,即生成最后的预测结果。
5. 其他代码
5.1 models.py
这个python文件里主要是对整个STGCN模型进行整合,一共有两个类,分别是 STGCNChebGraphConv 和 STGCNGraphConv 。这分别代表着使用切比雪夫多项式还是一阶近似。
其中的大多数代码都是对 layers.py 中的函数方法调用,传参。有一行代码需要理解:
Ko = args.n_his - (len(blocks) - 3) * 2 * (args.Kt - 1)
这句代码的作用是计算经过多个时空卷积块处理后,保留下来的时间维度的大小。
- args.n_his:是输入数据的时间维度大小,通常指输入的历史时间步数。
- len(blocks) - 3:blocks表示 STGCN 模型中不同层的配置的列表。它的长度再减3是为了去掉输出层相关的三层结构(
TNFF
,即两个全连接层和最后的时序处理层),仅关注时空卷积部分。- 2 * (args.Kt - 1):Kt 是每个时空卷积块中时间卷积核的大小。它的大小再减1是因为每次卷积操作后时间维度会 - 1(时间卷积是滑动窗口的形式)。而每个时空卷积块中有两个时间卷积层,所以总的时间维度减少量会再乘2。
因此,对输入的时间维度 n_his,每经过一个时空卷积块,时间维度会减少 2 * (args.Kt - 1)。有 len(blocks) - 3 个这样的时空卷积块,因此总的时间维度减少的量是 (len(blocks) - 3) * 2 * (args.Kt - 1)。
5.2 main.py
main 函数是对整个代码运行环境的配置,包括配置环境变量,设置命令行参数,数据类型转换,确定模型、优化器等等。这些代码在其他网络模型同样受用,大差不差。