1.前言
本次代码文章来自于《2024-AAAI-Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting》,基本模型结构如下图所示:
文章讲解视频链接
代码开源链接
接下来就开始代码解读了。
2.代码解读
class nconv(nn.Module):def __init__(self):super(nconv, self).__init__()def forward(self, x, A):x = torch.einsum('ncvl,nwv->ncwl', (x, A))return x.contiguous()
让我们逐行分析:
-
def __init__(self):
这是构造函数,初始化nconv
类的实例。这里没有额外的初始化参数,因为它没有定义任何需要学习的参数。 -
super(nconv, self).__init__():
这一行调用了父类nn.Module
的构造函数,确保了所有必要的初始化步骤得以执行。 -
def forward(self, x, A):
定义了前向传播方法,这是每个nn.Module
子类必须实现的方法。这个方法接受两个输入参数:x
: 输入张量,形状为(N, C, V, L)
,其中N
是批量大小,C
是通道数,V
是顶点数,L
是序列长度。A
: 图的邻接矩阵,形状为(N, W, V)
,其中W
是边的权重数,V
是顶点数。这里的W
和V
应该对应于图中的权重和顶点。
-
x = torch.einsum('ncvl,nwv->ncwl', (x, A))
这一行是核心计算部分,使用了torch.einsum
函数来执行一个高效的多维数组乘法和求和操作。einsum
的第一个参数是一个字符串,描述了输入张量的维度标签和输出张量的维度标签。这里的标签解释如下:'ncvl'
表示输入张量x
的四个维度:N
(批量大小),C
(通道数),V
(顶点数),L
(序列长度)。'nwv'
表示输入张量A
的三个维度:N
(批量大小),W
(边的权重数),V
(顶点数)。'ncwl'
表示输出张量的四个维度:N
(批量大小),C
(通道数),W
(边的权重数),L
(序列长度)。
这个表达式实际上是在进行类似于图卷积的操作,其中输入特征
x
与图的邻接矩阵A
相乘,以传播信息通过图的边。 -
return x.contiguous()
最后返回处理后的张量。contiguous()
方法用于确保返回的张量在内存中是连续存储的,这对于后续可能的操作(如索引或视图转换)来说是必要的。
总的来说,nconv
模块接收输入特征和图的邻接矩阵,然后通过 torch.einsum
实现了一种特定的卷积操作,用于处理图结构数据。
class pconv(nn.Module):def __init__(self):super(pconv, self).__init__()def forward(self, x, A):x = torch.einsum('bcnt, bmn->bc', (x, A))return x.contiguous()
pconv
类定义了一个自定义的PyTorch模块,该模块实现了一种特定类型的卷积操作,其中输入张量与一个可学习的或预定义的邻接矩阵(A
)进行乘法运算。这种类型的卷积通常在图神经网络(Graph Neural Networks, GNNs)中使用,其中A
可以代表图的邻接矩阵,用于编码节点之间的连接性。下面是对 pconv
类的详细解释:
__init__
方法
pconv
类继承自 nn.Module
,这是所有PyTorch神经网络模块的基类。构造函数 __init__
中没有定义任何额外的参数或层,这意味着 pconv
不包含任何可学习的参数,即它不会在训练过程中更新其权重。
forward
方法
forward
方法定义了当数据通过这个模块时的操作。它接受两个参数:
x
: 输入张量,形状为(batch_size, channels, nodes, time_steps)
。其中:batch_size
表示一个批次中的样本数量。channels
表示每个节点在每个时间步上的特征数量。nodes
表示图中的节点数量。time_steps
表示时间序列的长度。
A
: 邻接矩阵,形状为(batch_size, nodes, nodes)
。A
可以是预定义的,也可以是可学习的,它编码了图中节点之间的关系。
内部操作
在 forward
方法内部,使用了 torch.einsum
函数来执行一个高效的矩阵乘法操作。einsum
是一个通用的函数,用于执行各种类型的张量运算,这里用来实现输入张量 x
与邻接矩阵 A
的乘法。
torch.einsum('bcnt, bmn->bc', (x, A))
这行代码中,字符串 'bcnt, bmn->bc'
定义了输入张量的子标模式以及期望的输出模式。具体来说:
'bcnt'
指代x
的四个维度,分别对应于 batch size (b
)、channels (c
)、nodes (n
) 和 time steps (t
)。'bmn'
指代A
的三个维度,分别对应于 batch size (b
)、源节点 (m
) 和目标节点 (n
)。'bc'
是输出张量的模式,意味着输出将是一个二维张量,其维度为 batch size 和 channels。
输出
x = torch.einsum('bcnt, bmn->bc', (x, A))
计算的结果是一个形状为 (batch_size, channels)
的张量,这表明对于每一个样本,我们得到了一个压缩后的特征表示,其中时间步和节点维度被聚合掉了。
最后,return x.contiguous()
确保返回的张量是连续存储的,这对于后续的某些操作可能很重要,例如当张量需要在GPU上进行高效计算时。这是因为非连续的内存布局可能会导致性能下降。
class linear(nn.Module):def __init__(self, c_in, c_out):super(linear, self).__init__()self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)def forward(self, x):return self.mlp(x)
linear
类是一个自定义的 PyTorch 模块,它实质上实现了一个线性变换。
class gcn(nn.Module):def __init__(self, c_in, c_out, dropout, support_len=3, order=2):super(gcn, self).__init__()self.nconv = nconv()c_in = (order * support_len + 1) * c_inself.mlp = linear(c_in, c_out)self.dropout = dropoutself.order = orderdef forward(self, x, support):out = [x]for a in support:x1 = self.nconv(x, a)out.append(x1)for k in range(2, self.order + 1):x2 = self.nconv(x1, a)out.append(x2)x1 = x2h = torch.cat(out, dim=1)h = self.mlp(h)return h
gcn
类定义了一个基于图卷积网络(Graph Convolutional Network, GCN)的模块,它在图结构数据上执行多阶卷积操作,以捕获不同层次的节点间关联。下面是对 gcn
类的详细解析:
初始化方法 __init__
在构造函数 __init__
中,gcn
类继承自 nn.Module
并初始化以下组件:
nconv
: 实例化nconv
类,用于执行图卷积操作。mlp
: 实例化linear
类,用于线性变换和聚合来自不同阶卷积的结果。dropout
: 设置 dropout 比率,用于正则化和防止过拟合。order
: 设置图卷积的阶数,控制卷积操作的深度,即卷积在图上扩展的层数。
c_in
的值被重新定义为 (order * support_len + 1) * c_in
,这考虑到了 support_len
个支持矩阵在 order
阶的卷积中产生的特征通道数。+1
是因为原始输入 x
也会被拼接到最终的输出中。
前向传播方法 forward
在 forward
方法中,gcn
类执行以下操作:
- 初始化一个列表
out
来保存每一阶卷积的结果,首先添加原始输入x
。 - 对于
support
中的每一个邻接矩阵a
,执行以下操作:- 使用
nconv
对输入x
和邻接矩阵a
进行一次卷积,结果存储在x1
中,并添加到out
。 - 接下来,对于
order
中的每一阶(从 2 开始),重复使用nconv
对前一阶的结果x1
和同一个邻接矩阵a
进行卷积,结果存储在x2
中,再添加到out
,并将x2
设为下一次迭代的输入x1
。
- 使用
- 将
out
中的所有结果在通道维度(dim=1)上进行拼接,形成一个包含所有阶卷积结果的张量h
。 - 将
h
传递给mlp
层,进行线性变换和通道数的调整,最终输出调整后的特征表示。
总结
gcn
类通过多次调用 nconv
模块来执行多阶图卷积,捕捉图中节点间的多层次关系。通过将不同阶的卷积结果拼接起来,它能够整合从局部到全局的节点信息。最后,mlp
层负责将这些多阶特征映射到期望的输出维度,以便进一步的处理或分类。这种设计使得 gcn
能够有效处理复杂图结构数据,并在诸如社交网络分析、分子结构预测等任务中发挥重要作用。
class pgcn(nn.Module):def __init__(self, c_in, c_out, dropout, support_len=3, order=2, temp=1):super(pgcn, self).__init__()self.nconv = nconv()self.temp = tempc_in = (order * support_len + 1) * c_inself.mlp = linear(c_in, c_out)self.dropout = dropoutself.order = orderdef forward(self, x, support):out = [x]for a in support:x1 = self.nconv(x, a)out.append(x1)for k in range(2, self.order + 1):x2 = self.nconv(x1, a)out.append(x2)x1 = x2h = torch.cat(out, dim=1)h = self.mlp(h)h = h[:,:,:,-h.size(3):-self.temp]return h
pgcn
类定义了一个个性化的图卷积网络(Personalized Graph Convolutional Network)模块,它在图卷积的基础上引入了个性化参数,允许模型在处理图数据时考虑到更加细致的节点特性或时间序列特性。下面是对 pgcn
类的详细解析:
初始化方法 __init__
pgcn
类继承自 nn.Module
并初始化以下组件:
nconv
: 实例化nconv
类,用于执行图卷积操作。temp
: 一个个性化参数,用于在输出中裁剪时间序列数据,这可能用于处理具有周期性或季节性模式的时间序列数据,通过移除某些时间点的数据来增强模型对特定时间模式的学习能力。mlp
: 实例化linear
类,用于线性变换和聚合来自不同阶卷积的结果。dropout
: 设置 dropout 比率,用于正则化和防止过拟合。order
: 设置图卷积的阶数,控制卷积操作的深度。
与 gcn
类似,c_in
的值被重新定义为 (order * support_len + 1) * c_in
,考虑到了多阶卷积产生的特征通道数。
前向传播方法 forward
forward
方法中,pgcn
类执行的操作与 gcn
类似,但在输出阶段有一个关键的区别:
- 初始化一个列表
out
来保存每一阶卷积的结果,首先添加原始输入x
。 - 对于
support
中的每一个邻接矩阵a
,执行多阶卷积操作,将结果存储在out
中。 - 将
out
中的所有结果在通道维度(dim=1)上进行拼接,形成一个包含所有阶卷积结果的张量h
。 - 将
h
传递给mlp
层,进行线性变换和通道数的调整。 - 个性化裁剪:在
h
上执行一个个性化裁剪操作,通过h = h[:,:,:,-h.size(3):-self.temp]
,这将从h
的最后一个维度(通常是时间序列的长度)开始,去除从末尾开始的self.temp
个时间点的数据。这种裁剪可以用于去除不需要的时间点,例如去除最近的短期波动,以便模型更专注于长期趋势或周期性模式。
总结
pgcn
类通过在标准图卷积网络的基础上引入个性化参数 temp
,增强了模型处理时间序列图数据的能力。通过裁剪时间序列的末端,模型可以更好地聚焦于数据中的长期模式,这对于处理具有季节性或周期性特性的数据集尤为重要。
def __init__(self, device, num_nodes, dropout=0.3, topk=35,out_dim=12, residual_channels=16, dilation_channels=16, end_channels=512,kernel_size=2, blocks=4, layers=2, days=288, dims=40, order=2, in_dim=9, normalization="batch"):super(STPGNN, self).__init__()skip_channels = 8self.alpha = nn.Parameter(torch.tensor(-5.0)) self.topk = topkself.dropout = dropoutself.blocks = blocksself.layers = layersself.filter_convs = nn.ModuleList()self.gate_convs = nn.ModuleList()self.residual_convs = nn.ModuleList()self.skip_convs = nn.ModuleList()self.normal = nn.ModuleList()self.gconv = nn.ModuleList()self.residual_convs_a = nn.ModuleList()self.skip_convs_a = nn.ModuleList()self.normal_a = nn.ModuleList()self.pgconv = nn.ModuleList()self.start_conv_a = nn.Conv2d(in_channels=in_dim,out_channels=1,kernel_size=(1, 1))self.start_conv = nn.Conv2d(in_channels=in_dim,out_channels=residual_channels,kernel_size=(1, 1))receptive_field = 1self.supports_len = 1self.nodevec_p1 = nn.Parameter(torch.randn(days, dims).to(device), requires_grad=True).to(device)self.nodevec_p2 = nn.Parameter(torch.randn(num_nodes, dims).to(device), requires_grad=True).to(device)self.nodevec_p3 = nn.Parameter(torch.randn(num_nodes, dims).to(device), requires_grad=True).to(device)self.nodevec_pk = nn.Parameter(torch.randn(dims, dims, dims).to(device), requires_grad=True).to(device)
这段代码是 STPGNN
类的初始化方法 __init__
的一部分,它主要负责构建模型的架构和初始化必要的参数。下面是详细的解析:
构建网络组件
-
Convolution Layers and Residual Connections:
self.filter_convs
,self.gate_convs
: 这两个列表存储了因果卷积(Causal Convolution)层,它们用于处理时间序列数据,通过滤波器(filter)和门控(gate)机制捕捉时间依赖性。self.residual_convs
,self.skip_convs
: 这些列表分别存储了残差卷积和跳跃连接卷积层,用于在网络中建立残差连接和跳跃连接,有助于梯度传播并避免深度网络中的梯度消失/爆炸问题。self.normal
: 这个列表包含了归一化层,如批量归一化(Batch Normalization)或层归一化(Layer Normalization),用于加速训练过程和提升模型性能。
-
Graph Convolution Layers:
self.gconv
,self.pgconv
: 这两个列表分别存储了图卷积(Graph Convolution)层和个性化图卷积(Personalized Graph Convolution)层,用于处理图结构数据,捕捉节点间的空间依赖性。self.residual_convs_a
,self.skip_convs_a
,self.normal_a
: 这些组件与前面提到的组件类似,但是专门用于辅助分支,可能是为了处理特定类型的信息或者用于构建个性化的图卷积。
-
Input Layers:
self.start_conv_a
,self.start_conv
: 这两个卷积层用于调整输入数据的维度,self.start_conv_a
可能用于特定的辅助特征提取,而self.start_conv
则是主输入层,用于调整输入特征至残差通道数。
参数初始化
-
Receptive Field:
receptive_field
是一个变量,初始化为1,它表示网络能够感知的时间序列的宽度。随着网络的深入,这个值会增加,表示网络可以捕捉到更远的历史信息。 -
Node Embeddings and Adjacency Matrix Parameters:
self.nodevec_p1
,self.nodevec_p2
,self.nodevec_p3
,self.nodevec_pk
: 这些参数是节点嵌入向量和用于构建动态邻接矩阵的参数,它们在训练过程中是可学习的。self.nodevec_p1
代表时间相关的节点嵌入,self.nodevec_p2
和self.nodevec_p3
代表空间相关的节点嵌入,而self.nodevec_pk
用于构建核心节点之间的关联,这些参数一起用于构建一个适应性更强的图结构,使得模型能够根据输入数据动态调整节点之间的关联强度。
通过上述组件和参数的初始化,STPGNN
构建了一个能够处理时空序列数据的深度学习模型,结合了时间序列分析和图结构数据处理的优势,适用于如交通流量预测、环境监测等需要同时考虑时间和空间依赖性的任务。
for b in range(blocks):additional_scope = kernel_size - 1new_dilation = 1for i in range(layers):# dilated convolutionsself.filter_convs.append(nn.Conv2d(in_channels=residual_channels,out_channels=dilation_channels,kernel_size=(1, kernel_size), dilation=new_dilation))self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,out_channels=dilation_channels,kernel_size=(1, kernel_size), dilation=new_dilation))self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,out_channels=residual_channels,kernel_size=(1, 1)))self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,out_channels=skip_channels,kernel_size=(1, 1)))self.residual_convs_a.append(nn.Conv1d(in_channels=dilation_channels,out_channels=residual_channels,kernel_size=(1, 1)))self.pgconv.append(pgcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len, order=order, temp=new_dilation))self.gconv.append(gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len, order=order))if normalization == "batch":self.normal.append(nn.BatchNorm2d(residual_channels))self.normal_a.append(nn.BatchNorm2d(residual_channels))elif normalization == "layer":self.normal.append(nn.LayerNorm([residual_channels, num_nodes, 13 - receptive_field - new_dilation + 1]))self.normal_a.append(nn.LayerNorm([residual_channels, num_nodes, 13 - receptive_field - new_dilation + 1]))new_dilation *= 2receptive_field += additional_scopeadditional_scope *= 2
这段代码是 STPGNN
类初始化方法的一部分,它主要负责构建多层因果卷积块,这些块是构成整个网络的基础单元。以下是详细解析:
构建因果卷积块
- Looping through blocks and layers:
- 外层循环
for b in range(blocks)
控制着构建的残差块数量,每个块由多个层组成。 - 内层循环
for i in range(layers)
控制着每个残差块内的层数量。
- 外层循环
卷积层的配置
- Dilated Convolutions:
self.filter_convs
和self.gate_convs
分别存储了滤波器和门控机制的扩张卷积层,用于捕捉时间序列数据中的长期依赖关系。扩张卷积(Dilated Convolution)通过增加卷积核之间的空洞来扩大感受野,而无需增加网络深度或输入尺寸。self.residual_convs
存储了用于残差连接的1x1卷积层,它们用于将输入与扩张卷积的输出相加,形成残差块的核心部分。self.skip_convs
存储了用于跳跃连接的1x1卷积层,它们将中间层的输出传递到网络的最后阶段,帮助网络学习长期依赖。
图卷积层的配置
- Graph Convolution Layers:
self.pgconv
和self.gconv
分别存储了个性化图卷积(Personalized Graph Convolution)和图卷积(Graph Convolution)层,用于处理图结构数据,捕捉节点间的空间依赖性。这些层在每个因果卷积层之后被调用,将时间序列特征与图结构特征相结合。
归一化层的配置
- Normalization Layers:
- 根据
normalization
参数的值,选择批量归一化(nn.BatchNorm2d
)或层归一化(nn.LayerNorm
)。归一化层有助于加速训练过程,减少内部协变量偏移,提高模型的泛化能力。
- 根据
扩张因子和感受野的更新
- Updating Dilation Factor and Receptive Field:
new_dilation *= 2
更新了扩张因子,每次内层循环都会翻倍,这样扩张卷积的感受野会随着层数的增加而指数级增长。receptive_field += additional_scope
和additional_scope *= 2
更新了网络的感受野,反映了随着扩张卷积的深入,网络能够捕捉到的时间序列的宽度也在增加。
通过这种方式,STPGNN
构建了一个能够同时处理时间序列数据和图结构数据的深度学习模型,能够捕捉到数据中的长期依赖和空间依赖,非常适合应用于如交通流量预测等需要同时考虑时间和空间因素的任务。
def dgconstruct(self, time_embedding, source_embedding, target_embedding, core_embedding):adp = torch.einsum('ai, ijk->ajk', time_embedding, core_embedding)adp = torch.einsum('bj, ajk->abk', source_embedding, adp)adp = torch.einsum('ck, abk->abc', target_embedding, adp)adp = F.softmax(F.relu(adp), dim=2)return adpdef pivotalconstruct(self, x, adj, k):x = x.squeeze(1)x = x.sum(dim=0)y = x.sum(dim=1).unsqueeze(0)adjp = torch.einsum('ij, jk->ik', x[:,:-1], x.transpose(0, 1)[1:,:]) / yadjp = adjp * adjscore = adjp.sum(dim=0) + adjp.sum(dim=1)N = x.size(0)_, topk_indices = torch.topk(score,k)mask = torch.zeros(N, dtype=torch.bool,device=x.device)mask[topk_indices] = Truemasked_matrix = adjp * mask.unsqueeze(1) * mask.unsqueeze(0)adjp = F.softmax(F.relu(masked_matrix), dim=1)return (adjp.unsqueeze(0))
这段代码定义了两个函数,dgconstruct
和pivotalconstruct
,它们分别用于构建动态图结构和识别关键节点。
dgconstruct
函数接受四个参数:time_embedding
(时间嵌入),source_embedding
(源节点嵌入),target_embedding
(目标节点嵌入),和core_embedding
(核心嵌入)。此函数的目标是通过四者间的交互作用来构建动态的邻接矩阵adp
,这个矩阵描述了在特定时间下,源节点与目标节点之间的影响强度。具体步骤如下:
- 首先,使用
torch.einsum
函数,将时间嵌入与核心嵌入相乘,生成一个中间矩阵adp
。 - 接下来,将源节点嵌入与上一步得到的
adp
相乘,进一步细化节点间的影响关系。 - 最后,目标节点嵌入与当前的
adp
相乘,完成动态邻接矩阵的构建。 - 应用ReLU激活函数和Softmax归一化函数,使矩阵元素非负且按列归一化,确保每个源节点到所有目标节点的边权总和为1。
pivotalconstruct
函数则用于识别交通网络中的关键节点。它接受三个参数:x
(输入特征矩阵),adj
(静态邻接矩阵),和k
(关键节点数量)。以下是详细步骤:
- 将输入特征矩阵
x
的维度调整,使其变为二维,然后沿列方向求和,得到节点的时间序列特征。 - 对节点的时间序列特征进行行求和,得到节点的总流量,然后将其转置并扩展维度,便于后续计算。
- 利用
torch.einsum
计算节点间的时间序列特征相互作用矩阵adjp
,并通过除以节点总流量进行标准化。 - 将
adjp
与静态邻接矩阵adj
相乘,过滤掉不存在物理连接的节点间关系。 - 计算每个节点的“重要性”分数,这是通过将
adjp
矩阵的行和列求和得到的。 - 使用
torch.topk
函数找到具有最高分数的前k
个节点,这些节点即为关键节点。 - 创建一个布尔掩码
mask
,用于标记哪些节点是关键节点。 - 应用掩码到
adjp
矩阵,仅保留关键节点间的关系。 - 最后,对关键节点的邻接矩阵应用ReLU和Softmax,确保矩阵非负且按列归一化,得到最终的关键节点邻接矩阵
adjp
,并增加一个维度以适应后续操作。
通过以上两个函数,dgconstruct
构建了基于动态特征的邻接矩阵,而pivotalconstruct
则识别出了网络中对交通流动有重要影响的关键节点及其相互关系。这两个矩阵将用于后续的图神经网络层,以捕捉交通网络中的空间和时间依赖性。
def forward(self, inputs, ind):"""input: (B, F, N, T)"""in_len = inputs.size(3)num_nodes = inputs.size(2)if in_len < self.receptive_field:xo = nn.functional.pad(inputs, (self.receptive_field - in_len, 0, 0, 0))else:xo = inputsx = self.start_conv(xo[:, [0]])x_a = self.start_conv_a(xo[:, [0]])skip = 0adj = self.dgconstruct(self.nodevec_p1[ind], self.nodevec_p2, self.nodevec_p3, self.nodevec_pk)pivweight = nn.Parameter(torch.randn(num_nodes, num_nodes).to(x.device), requires_grad=True).to(x.device)adj_p = self.pivotalconstruct(x_a, pivweight, self.topk)supports = [adj]supports_a = [adj_p]for i in range(self.blocks * self.layers):residual = xfilter = self.filter_convs[i](residual)filter = torch.tanh(filter)gate = self.gate_convs[i](residual)gate = torch.sigmoid(gate)x = filter * gatex_a = self.pgconv[i](residual, supports_a)x = self.gconv[i](x, supports)alpha_sigmoid = torch.sigmoid(self.alpha) x = alpha_sigmoid * x_a + (1 - alpha_sigmoid) * xx = x + residual[:, :, :, -x.size(3):]s = xs = self.skip_convs[i](s)if isinstance(skip, int): # B F N Tskip = s.transpose(2, 3).reshape([s.shape[0], -1, s.shape[2], 1]).contiguous()else:skip = torch.cat([s.transpose(2, 3).reshape([s.shape[0], -1, s.shape[2], 1]), skip], dim=1).contiguous()x = self.normal[i](x)x = F.relu(skip)x = F.relu(self.end_conv_1(x))x = self.end_conv_2(x)return x
这段代码实现了一个深度学习模型的前向传播过程,该模型被设计用于处理时序数据,如交通流量预测。模型的输入是一个四维张量,形状为(B, F, N, T),其中B代表批量大小,F代表特征数,N代表节点数,T代表时间步长。模型的架构包含了卷积、门控机制、残差连接、跳过连接以及图神经网络组件。
-
输入预处理:
- 首先检查输入的时间长度T是否小于模型的受感野(receptive_field),如果小于,则使用
nn.functional.pad
对输入进行填充,确保输入的时间序列长度满足要求。
- 首先检查输入的时间长度T是否小于模型的受感野(receptive_field),如果小于,则使用
-
起始卷积:
- 使用
start_conv
和start_conv_a
进行起始卷积,分别对输入的首个特征通道进行处理,得到x和x_a。
- 使用
-
动态图构建:
dgconstruct
函数用于构建动态邻接矩阵,根据节点特征构建图结构。这将用于图卷积操作。pivotalconstruct
函数用于构建关键节点图,它使用x_a和关键节点权重矩阵pivweight来构造关键节点的邻接矩阵adj_p。
-
多层残差模块:
- 模型包含多个残差块,每个残差块由多层组成。每层首先应用残差连接,之后进行卷积操作,包括滤波器和门控机制。
- 滤波器和门控卷积的结果分别经过tanh和sigmoid激活函数,之后相乘,产生门控信号控制信息流。
- 使用
pgconv
进行关键节点图上的卷积,并使用gconv
进行常规图卷积。 - 引入一个可学习的参数alpha_sigmoid,通过sigmoid函数得到一个0到1之间的值,用于加权融合关键节点图卷积和常规图卷积的结果。
- 结果再与残差项相加,之后进行跳过连接,将结果存储在skip变量中,用于后续的跳跃连接操作。
-
跳跃连接与输出:
- 跳跃连接将每一层的输出收集起来,进行整合,形成skip变量。
- 经过跳跃连接后,结果经过
end_conv_1
和end_conv_2
卷积层处理,最终得到模型的输出。
整个模型通过这种结构能够同时捕捉空间和时间依赖性,特别是在处理像交通流量预测这样的问题时,它能有效利用图结构和时序特性,从而做出更准确的预测。