论文代码解读STPGNN

1.前言

本次代码文章来自于《2024-AAAI-Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting》,基本模型结构如下图所示:

文章讲解视频链接

代码开源链接

接下来就开始代码解读了。

 

2.代码解读 

class nconv(nn.Module):def __init__(self):super(nconv, self).__init__()def forward(self, x, A):x = torch.einsum('ncvl,nwv->ncwl', (x, A))return x.contiguous()

让我们逐行分析:

  1. def __init__(self): 这是构造函数,初始化 nconv 类的实例。这里没有额外的初始化参数,因为它没有定义任何需要学习的参数。

  2. super(nconv, self).__init__(): 这一行调用了父类 nn.Module 的构造函数,确保了所有必要的初始化步骤得以执行。

  3. def forward(self, x, A): 定义了前向传播方法,这是每个 nn.Module 子类必须实现的方法。这个方法接受两个输入参数:

    • x: 输入张量,形状为 (N, C, V, L),其中 N 是批量大小,C 是通道数,V 是顶点数,L 是序列长度。
    • A: 图的邻接矩阵,形状为 (N, W, V),其中 W 是边的权重数,V 是顶点数。这里的 W 和 V 应该对应于图中的权重和顶点。
  4. x = torch.einsum('ncvl,nwv->ncwl', (x, A)) 这一行是核心计算部分,使用了 torch.einsum 函数来执行一个高效的多维数组乘法和求和操作。einsum 的第一个参数是一个字符串,描述了输入张量的维度标签和输出张量的维度标签。这里的标签解释如下:

    • 'ncvl' 表示输入张量 x 的四个维度:N(批量大小),C(通道数),V(顶点数),L(序列长度)。
    • 'nwv' 表示输入张量 A 的三个维度:N(批量大小),W(边的权重数),V(顶点数)。
    • 'ncwl' 表示输出张量的四个维度:N(批量大小),C(通道数),W(边的权重数),L(序列长度)。

    这个表达式实际上是在进行类似于图卷积的操作,其中输入特征 x 与图的邻接矩阵 A 相乘,以传播信息通过图的边。

  5. return x.contiguous() 最后返回处理后的张量。contiguous() 方法用于确保返回的张量在内存中是连续存储的,这对于后续可能的操作(如索引或视图转换)来说是必要的。

总的来说,nconv 模块接收输入特征和图的邻接矩阵,然后通过 torch.einsum 实现了一种特定的卷积操作,用于处理图结构数据。

class pconv(nn.Module):def __init__(self):super(pconv, self).__init__()def forward(self, x, A):x = torch.einsum('bcnt, bmn->bc', (x, A))return x.contiguous()

pconv 类定义了一个自定义的PyTorch模块,该模块实现了一种特定类型的卷积操作,其中输入张量与一个可学习的或预定义的邻接矩阵(A)进行乘法运算。这种类型的卷积通常在图神经网络(Graph Neural Networks, GNNs)中使用,其中A可以代表图的邻接矩阵,用于编码节点之间的连接性。下面是对 pconv 类的详细解释:

__init__ 方法

pconv 类继承自 nn.Module,这是所有PyTorch神经网络模块的基类。构造函数 __init__ 中没有定义任何额外的参数或层,这意味着 pconv 不包含任何可学习的参数,即它不会在训练过程中更新其权重。

forward 方法

forward 方法定义了当数据通过这个模块时的操作。它接受两个参数:

  • x: 输入张量,形状为 (batch_size, channels, nodes, time_steps)。其中:
    • batch_size 表示一个批次中的样本数量。
    • channels 表示每个节点在每个时间步上的特征数量。
    • nodes 表示图中的节点数量。
    • time_steps 表示时间序列的长度。
  • A: 邻接矩阵,形状为 (batch_size, nodes, nodes)A 可以是预定义的,也可以是可学习的,它编码了图中节点之间的关系。

内部操作

forward 方法内部,使用了 torch.einsum 函数来执行一个高效的矩阵乘法操作。einsum 是一个通用的函数,用于执行各种类型的张量运算,这里用来实现输入张量 x 与邻接矩阵 A 的乘法。

torch.einsum('bcnt, bmn->bc', (x, A)) 这行代码中,字符串 'bcnt, bmn->bc' 定义了输入张量的子标模式以及期望的输出模式。具体来说:

  • 'bcnt' 指代 x 的四个维度,分别对应于 batch size (b)、channels (c)、nodes (n) 和 time steps (t)。
  • 'bmn' 指代 A 的三个维度,分别对应于 batch size (b)、源节点 (m) 和目标节点 (n)。
  • 'bc' 是输出张量的模式,意味着输出将是一个二维张量,其维度为 batch size 和 channels。

输出

x = torch.einsum('bcnt, bmn->bc', (x, A)) 计算的结果是一个形状为 (batch_size, channels) 的张量,这表明对于每一个样本,我们得到了一个压缩后的特征表示,其中时间步和节点维度被聚合掉了。

最后,return x.contiguous() 确保返回的张量是连续存储的,这对于后续的某些操作可能很重要,例如当张量需要在GPU上进行高效计算时。这是因为非连续的内存布局可能会导致性能下降。

 

class linear(nn.Module):def __init__(self, c_in, c_out):super(linear, self).__init__()self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)def forward(self, x):return self.mlp(x)

linear 类是一个自定义的 PyTorch 模块,它实质上实现了一个线性变换。

 

class gcn(nn.Module):def __init__(self, c_in, c_out, dropout, support_len=3, order=2):super(gcn, self).__init__()self.nconv = nconv()c_in = (order * support_len + 1) * c_inself.mlp = linear(c_in, c_out)self.dropout = dropoutself.order = orderdef forward(self, x, support):out = [x]for a in support:x1 = self.nconv(x, a)out.append(x1)for k in range(2, self.order + 1):x2 = self.nconv(x1, a)out.append(x2)x1 = x2h = torch.cat(out, dim=1)h = self.mlp(h)return h

gcn 类定义了一个基于图卷积网络(Graph Convolutional Network, GCN)的模块,它在图结构数据上执行多阶卷积操作,以捕获不同层次的节点间关联。下面是对 gcn 类的详细解析:

初始化方法 __init__

在构造函数 __init__ 中,gcn 类继承自 nn.Module 并初始化以下组件:

  • nconv: 实例化 nconv 类,用于执行图卷积操作。
  • mlp: 实例化 linear 类,用于线性变换和聚合来自不同阶卷积的结果。
  • dropout: 设置 dropout 比率,用于正则化和防止过拟合。
  • order: 设置图卷积的阶数,控制卷积操作的深度,即卷积在图上扩展的层数。

c_in 的值被重新定义为 (order * support_len + 1) * c_in,这考虑到了 support_len 个支持矩阵在 order 阶的卷积中产生的特征通道数。+1 是因为原始输入 x 也会被拼接到最终的输出中。

前向传播方法 forward

forward 方法中,gcn 类执行以下操作:

  1. 初始化一个列表 out 来保存每一阶卷积的结果,首先添加原始输入 x
  2. 对于 support 中的每一个邻接矩阵 a,执行以下操作:
    • 使用 nconv 对输入 x 和邻接矩阵 a 进行一次卷积,结果存储在 x1 中,并添加到 out
    • 接下来,对于 order 中的每一阶(从 2 开始),重复使用 nconv 对前一阶的结果 x1 和同一个邻接矩阵 a 进行卷积,结果存储在 x2 中,再添加到 out,并将 x2 设为下一次迭代的输入 x1
  3. 将 out 中的所有结果在通道维度(dim=1)上进行拼接,形成一个包含所有阶卷积结果的张量 h
  4. 将 h 传递给 mlp 层,进行线性变换和通道数的调整,最终输出调整后的特征表示。

总结

gcn 类通过多次调用 nconv 模块来执行多阶图卷积,捕捉图中节点间的多层次关系。通过将不同阶的卷积结果拼接起来,它能够整合从局部到全局的节点信息。最后,mlp 层负责将这些多阶特征映射到期望的输出维度,以便进一步的处理或分类。这种设计使得 gcn 能够有效处理复杂图结构数据,并在诸如社交网络分析、分子结构预测等任务中发挥重要作用。

 

class pgcn(nn.Module):def __init__(self, c_in, c_out, dropout, support_len=3, order=2, temp=1):super(pgcn, self).__init__()self.nconv = nconv()self.temp = tempc_in = (order * support_len + 1) * c_inself.mlp = linear(c_in, c_out)self.dropout = dropoutself.order = orderdef forward(self, x, support):out = [x]for a in support:x1 = self.nconv(x, a)out.append(x1)for k in range(2, self.order + 1):x2 = self.nconv(x1, a)out.append(x2)x1 = x2h = torch.cat(out, dim=1)h = self.mlp(h)h = h[:,:,:,-h.size(3):-self.temp]return h

pgcn 类定义了一个个性化的图卷积网络(Personalized Graph Convolutional Network)模块,它在图卷积的基础上引入了个性化参数,允许模型在处理图数据时考虑到更加细致的节点特性或时间序列特性。下面是对 pgcn 类的详细解析:

初始化方法 __init__

pgcn 类继承自 nn.Module 并初始化以下组件:

  • nconv: 实例化 nconv 类,用于执行图卷积操作。
  • temp: 一个个性化参数,用于在输出中裁剪时间序列数据,这可能用于处理具有周期性或季节性模式的时间序列数据,通过移除某些时间点的数据来增强模型对特定时间模式的学习能力。
  • mlp: 实例化 linear 类,用于线性变换和聚合来自不同阶卷积的结果。
  • dropout: 设置 dropout 比率,用于正则化和防止过拟合。
  • order: 设置图卷积的阶数,控制卷积操作的深度。

gcn 类似,c_in 的值被重新定义为 (order * support_len + 1) * c_in,考虑到了多阶卷积产生的特征通道数。

前向传播方法 forward

forward 方法中,pgcn 类执行的操作与 gcn 类似,但在输出阶段有一个关键的区别:

  1. 初始化一个列表 out 来保存每一阶卷积的结果,首先添加原始输入 x
  2. 对于 support 中的每一个邻接矩阵 a,执行多阶卷积操作,将结果存储在 out 中。
  3. 将 out 中的所有结果在通道维度(dim=1)上进行拼接,形成一个包含所有阶卷积结果的张量 h
  4. 将 h 传递给 mlp 层,进行线性变换和通道数的调整。
  5. 个性化裁剪:在 h 上执行一个个性化裁剪操作,通过 h = h[:,:,:,-h.size(3):-self.temp],这将从 h 的最后一个维度(通常是时间序列的长度)开始,去除从末尾开始的 self.temp 个时间点的数据。这种裁剪可以用于去除不需要的时间点,例如去除最近的短期波动,以便模型更专注于长期趋势或周期性模式。

总结

pgcn 类通过在标准图卷积网络的基础上引入个性化参数 temp,增强了模型处理时间序列图数据的能力。通过裁剪时间序列的末端,模型可以更好地聚焦于数据中的长期模式,这对于处理具有季节性或周期性特性的数据集尤为重要。

 

    def __init__(self, device, num_nodes, dropout=0.3, topk=35,out_dim=12, residual_channels=16, dilation_channels=16, end_channels=512,kernel_size=2, blocks=4, layers=2, days=288, dims=40, order=2, in_dim=9, normalization="batch"):super(STPGNN, self).__init__()skip_channels = 8self.alpha = nn.Parameter(torch.tensor(-5.0))  self.topk = topkself.dropout = dropoutself.blocks = blocksself.layers = layersself.filter_convs = nn.ModuleList()self.gate_convs = nn.ModuleList()self.residual_convs = nn.ModuleList()self.skip_convs = nn.ModuleList()self.normal = nn.ModuleList()self.gconv = nn.ModuleList()self.residual_convs_a = nn.ModuleList()self.skip_convs_a = nn.ModuleList()self.normal_a = nn.ModuleList()self.pgconv = nn.ModuleList()self.start_conv_a = nn.Conv2d(in_channels=in_dim,out_channels=1,kernel_size=(1, 1))self.start_conv = nn.Conv2d(in_channels=in_dim,out_channels=residual_channels,kernel_size=(1, 1))receptive_field = 1self.supports_len = 1self.nodevec_p1 = nn.Parameter(torch.randn(days, dims).to(device), requires_grad=True).to(device)self.nodevec_p2 = nn.Parameter(torch.randn(num_nodes, dims).to(device), requires_grad=True).to(device)self.nodevec_p3 = nn.Parameter(torch.randn(num_nodes, dims).to(device), requires_grad=True).to(device)self.nodevec_pk = nn.Parameter(torch.randn(dims, dims, dims).to(device), requires_grad=True).to(device)

这段代码是 STPGNN 类的初始化方法 __init__ 的一部分,它主要负责构建模型的架构和初始化必要的参数。下面是详细的解析:

构建网络组件

  • Convolution Layers and Residual Connections:

    • self.filter_convsself.gate_convs: 这两个列表存储了因果卷积(Causal Convolution)层,它们用于处理时间序列数据,通过滤波器(filter)和门控(gate)机制捕捉时间依赖性。
    • self.residual_convsself.skip_convs: 这些列表分别存储了残差卷积和跳跃连接卷积层,用于在网络中建立残差连接和跳跃连接,有助于梯度传播并避免深度网络中的梯度消失/爆炸问题。
    • self.normal: 这个列表包含了归一化层,如批量归一化(Batch Normalization)或层归一化(Layer Normalization),用于加速训练过程和提升模型性能。
  • Graph Convolution Layers:

    • self.gconvself.pgconv: 这两个列表分别存储了图卷积(Graph Convolution)层和个性化图卷积(Personalized Graph Convolution)层,用于处理图结构数据,捕捉节点间的空间依赖性。
    • self.residual_convs_aself.skip_convs_aself.normal_a: 这些组件与前面提到的组件类似,但是专门用于辅助分支,可能是为了处理特定类型的信息或者用于构建个性化的图卷积。
  • Input Layers:

    • self.start_conv_aself.start_conv: 这两个卷积层用于调整输入数据的维度,self.start_conv_a 可能用于特定的辅助特征提取,而 self.start_conv 则是主输入层,用于调整输入特征至残差通道数。

参数初始化

  • Receptive Field: receptive_field 是一个变量,初始化为1,它表示网络能够感知的时间序列的宽度。随着网络的深入,这个值会增加,表示网络可以捕捉到更远的历史信息。

  • Node Embeddings and Adjacency Matrix Parameters:

    • self.nodevec_p1self.nodevec_p2self.nodevec_p3self.nodevec_pk: 这些参数是节点嵌入向量和用于构建动态邻接矩阵的参数,它们在训练过程中是可学习的。self.nodevec_p1 代表时间相关的节点嵌入,self.nodevec_p2 和 self.nodevec_p3 代表空间相关的节点嵌入,而 self.nodevec_pk 用于构建核心节点之间的关联,这些参数一起用于构建一个适应性更强的图结构,使得模型能够根据输入数据动态调整节点之间的关联强度。

通过上述组件和参数的初始化,STPGNN 构建了一个能够处理时空序列数据的深度学习模型,结合了时间序列分析和图结构数据处理的优势,适用于如交通流量预测、环境监测等需要同时考虑时间和空间依赖性的任务。

 
        for b in range(blocks):additional_scope = kernel_size - 1new_dilation = 1for i in range(layers):# dilated convolutionsself.filter_convs.append(nn.Conv2d(in_channels=residual_channels,out_channels=dilation_channels,kernel_size=(1, kernel_size), dilation=new_dilation))self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,out_channels=dilation_channels,kernel_size=(1, kernel_size), dilation=new_dilation))self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,out_channels=residual_channels,kernel_size=(1, 1)))self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,out_channels=skip_channels,kernel_size=(1, 1)))self.residual_convs_a.append(nn.Conv1d(in_channels=dilation_channels,out_channels=residual_channels,kernel_size=(1, 1)))self.pgconv.append(pgcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len, order=order, temp=new_dilation))self.gconv.append(gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len, order=order))if normalization == "batch":self.normal.append(nn.BatchNorm2d(residual_channels))self.normal_a.append(nn.BatchNorm2d(residual_channels))elif normalization == "layer":self.normal.append(nn.LayerNorm([residual_channels, num_nodes, 13 - receptive_field - new_dilation + 1]))self.normal_a.append(nn.LayerNorm([residual_channels, num_nodes, 13 - receptive_field - new_dilation + 1]))new_dilation *= 2receptive_field += additional_scopeadditional_scope *= 2

这段代码是 STPGNN 类初始化方法的一部分,它主要负责构建多层因果卷积块,这些块是构成整个网络的基础单元。以下是详细解析:

构建因果卷积块

  • Looping through blocks and layers:
    • 外层循环 for b in range(blocks) 控制着构建的残差块数量,每个块由多个层组成。
    • 内层循环 for i in range(layers) 控制着每个残差块内的层数量。

卷积层的配置

  • Dilated Convolutions:
    • self.filter_convs 和 self.gate_convs 分别存储了滤波器和门控机制的扩张卷积层,用于捕捉时间序列数据中的长期依赖关系。扩张卷积(Dilated Convolution)通过增加卷积核之间的空洞来扩大感受野,而无需增加网络深度或输入尺寸。
    • self.residual_convs 存储了用于残差连接的1x1卷积层,它们用于将输入与扩张卷积的输出相加,形成残差块的核心部分。
    • self.skip_convs 存储了用于跳跃连接的1x1卷积层,它们将中间层的输出传递到网络的最后阶段,帮助网络学习长期依赖。

图卷积层的配置

  • Graph Convolution Layers:
    • self.pgconv 和 self.gconv 分别存储了个性化图卷积(Personalized Graph Convolution)和图卷积(Graph Convolution)层,用于处理图结构数据,捕捉节点间的空间依赖性。这些层在每个因果卷积层之后被调用,将时间序列特征与图结构特征相结合。

归一化层的配置

  • Normalization Layers:
    • 根据 normalization 参数的值,选择批量归一化(nn.BatchNorm2d)或层归一化(nn.LayerNorm)。归一化层有助于加速训练过程,减少内部协变量偏移,提高模型的泛化能力。

扩张因子和感受野的更新

  • Updating Dilation Factor and Receptive Field:
    • new_dilation *= 2 更新了扩张因子,每次内层循环都会翻倍,这样扩张卷积的感受野会随着层数的增加而指数级增长。
    • receptive_field += additional_scope 和 additional_scope *= 2 更新了网络的感受野,反映了随着扩张卷积的深入,网络能够捕捉到的时间序列的宽度也在增加。

通过这种方式,STPGNN 构建了一个能够同时处理时间序列数据和图结构数据的深度学习模型,能够捕捉到数据中的长期依赖和空间依赖,非常适合应用于如交通流量预测等需要同时考虑时间和空间因素的任务。

 

    def dgconstruct(self, time_embedding, source_embedding, target_embedding, core_embedding):adp = torch.einsum('ai, ijk->ajk', time_embedding, core_embedding)adp = torch.einsum('bj, ajk->abk', source_embedding, adp)adp = torch.einsum('ck, abk->abc', target_embedding, adp)adp = F.softmax(F.relu(adp), dim=2)return adpdef pivotalconstruct(self, x, adj, k):x = x.squeeze(1)x = x.sum(dim=0)y = x.sum(dim=1).unsqueeze(0)adjp = torch.einsum('ij, jk->ik', x[:,:-1], x.transpose(0, 1)[1:,:]) / yadjp = adjp * adjscore = adjp.sum(dim=0) + adjp.sum(dim=1)N = x.size(0)_, topk_indices = torch.topk(score,k)mask = torch.zeros(N, dtype=torch.bool,device=x.device)mask[topk_indices] = Truemasked_matrix = adjp * mask.unsqueeze(1) * mask.unsqueeze(0)adjp = F.softmax(F.relu(masked_matrix), dim=1)return (adjp.unsqueeze(0))

这段代码定义了两个函数,dgconstructpivotalconstruct,它们分别用于构建动态图结构和识别关键节点。

dgconstruct函数接受四个参数:time_embedding(时间嵌入),source_embedding(源节点嵌入),target_embedding(目标节点嵌入),和core_embedding(核心嵌入)。此函数的目标是通过四者间的交互作用来构建动态的邻接矩阵adp,这个矩阵描述了在特定时间下,源节点与目标节点之间的影响强度。具体步骤如下:

  1. 首先,使用torch.einsum函数,将时间嵌入与核心嵌入相乘,生成一个中间矩阵adp
  2. 接下来,将源节点嵌入与上一步得到的adp相乘,进一步细化节点间的影响关系。
  3. 最后,目标节点嵌入与当前的adp相乘,完成动态邻接矩阵的构建。
  4. 应用ReLU激活函数和Softmax归一化函数,使矩阵元素非负且按列归一化,确保每个源节点到所有目标节点的边权总和为1。

pivotalconstruct函数则用于识别交通网络中的关键节点。它接受三个参数:x(输入特征矩阵),adj(静态邻接矩阵),和k(关键节点数量)。以下是详细步骤:

  1. 将输入特征矩阵x的维度调整,使其变为二维,然后沿列方向求和,得到节点的时间序列特征。
  2. 对节点的时间序列特征进行行求和,得到节点的总流量,然后将其转置并扩展维度,便于后续计算。
  3. 利用torch.einsum计算节点间的时间序列特征相互作用矩阵adjp,并通过除以节点总流量进行标准化。
  4. adjp与静态邻接矩阵adj相乘,过滤掉不存在物理连接的节点间关系。
  5. 计算每个节点的“重要性”分数,这是通过将adjp矩阵的行和列求和得到的。
  6. 使用torch.topk函数找到具有最高分数的前k个节点,这些节点即为关键节点。
  7. 创建一个布尔掩码mask,用于标记哪些节点是关键节点。
  8. 应用掩码到adjp矩阵,仅保留关键节点间的关系。
  9. 最后,对关键节点的邻接矩阵应用ReLU和Softmax,确保矩阵非负且按列归一化,得到最终的关键节点邻接矩阵adjp,并增加一个维度以适应后续操作。

通过以上两个函数,dgconstruct构建了基于动态特征的邻接矩阵,而pivotalconstruct则识别出了网络中对交通流动有重要影响的关键节点及其相互关系。这两个矩阵将用于后续的图神经网络层,以捕捉交通网络中的空间和时间依赖性。

    def forward(self, inputs, ind):"""input: (B, F, N, T)"""in_len = inputs.size(3)num_nodes = inputs.size(2)if in_len < self.receptive_field:xo = nn.functional.pad(inputs, (self.receptive_field - in_len, 0, 0, 0))else:xo = inputsx = self.start_conv(xo[:, [0]])x_a = self.start_conv_a(xo[:, [0]])skip = 0adj = self.dgconstruct(self.nodevec_p1[ind], self.nodevec_p2, self.nodevec_p3, self.nodevec_pk)pivweight = nn.Parameter(torch.randn(num_nodes, num_nodes).to(x.device), requires_grad=True).to(x.device)adj_p = self.pivotalconstruct(x_a, pivweight, self.topk)supports = [adj]supports_a = [adj_p]for i in range(self.blocks * self.layers):residual = xfilter = self.filter_convs[i](residual)filter = torch.tanh(filter)gate = self.gate_convs[i](residual)gate = torch.sigmoid(gate)x = filter * gatex_a = self.pgconv[i](residual, supports_a)x = self.gconv[i](x, supports)alpha_sigmoid = torch.sigmoid(self.alpha)  x = alpha_sigmoid * x_a +  (1 - alpha_sigmoid) * xx = x + residual[:, :, :, -x.size(3):]s = xs = self.skip_convs[i](s)if isinstance(skip, int):  # B F N Tskip = s.transpose(2, 3).reshape([s.shape[0], -1, s.shape[2], 1]).contiguous()else:skip = torch.cat([s.transpose(2, 3).reshape([s.shape[0], -1, s.shape[2], 1]), skip], dim=1).contiguous()x = self.normal[i](x)x = F.relu(skip)x = F.relu(self.end_conv_1(x))x = self.end_conv_2(x)return x

这段代码实现了一个深度学习模型的前向传播过程,该模型被设计用于处理时序数据,如交通流量预测。模型的输入是一个四维张量,形状为(B, F, N, T),其中B代表批量大小,F代表特征数,N代表节点数,T代表时间步长。模型的架构包含了卷积、门控机制、残差连接、跳过连接以及图神经网络组件。

  1. 输入预处理

    • 首先检查输入的时间长度T是否小于模型的受感野(receptive_field),如果小于,则使用nn.functional.pad对输入进行填充,确保输入的时间序列长度满足要求。
  2. 起始卷积

    • 使用start_convstart_conv_a进行起始卷积,分别对输入的首个特征通道进行处理,得到x和x_a。
  3. 动态图构建

    • dgconstruct函数用于构建动态邻接矩阵,根据节点特征构建图结构。这将用于图卷积操作。
    • pivotalconstruct函数用于构建关键节点图,它使用x_a和关键节点权重矩阵pivweight来构造关键节点的邻接矩阵adj_p。
  4. 多层残差模块

    • 模型包含多个残差块,每个残差块由多层组成。每层首先应用残差连接,之后进行卷积操作,包括滤波器和门控机制。
    • 滤波器和门控卷积的结果分别经过tanh和sigmoid激活函数,之后相乘,产生门控信号控制信息流。
    • 使用pgconv进行关键节点图上的卷积,并使用gconv进行常规图卷积。
    • 引入一个可学习的参数alpha_sigmoid,通过sigmoid函数得到一个0到1之间的值,用于加权融合关键节点图卷积和常规图卷积的结果。
    • 结果再与残差项相加,之后进行跳过连接,将结果存储在skip变量中,用于后续的跳跃连接操作。
  5. 跳跃连接与输出

    • 跳跃连接将每一层的输出收集起来,进行整合,形成skip变量。
    • 经过跳跃连接后,结果经过end_conv_1end_conv_2卷积层处理,最终得到模型的输出。

整个模型通过这种结构能够同时捕捉空间和时间依赖性,特别是在处理像交通流量预测这样的问题时,它能有效利用图结构和时序特性,从而做出更准确的预测。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/24011.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NDIS Filter开发-网络数据的传输

和NIC小端口驱动不同的是&#xff0c;无需考虑网络数据具体是如何传输的&#xff0c;只需要针对NBL进行处理即可。Filter驱动程序可以启动发送请求和接收指示&#xff0c;或“过滤”其他驱动程序的请求和指示。Filter模块堆叠在微型端口适配器上。 驱动程序堆栈中的Filter模块…

谷粒商城实战(033 业务-秒杀功能4-高并发问题解决方案sentinel 1)

Java项目《谷粒商城》架构师级Java项目实战&#xff0c;对标阿里P6-P7&#xff0c;全网最强 总时长 104:45:00 共408P 此文章包含第326p-第p331的内容 关注的问题 sentinel&#xff08;哨兵&#xff09; sentinel来实现熔断、降级、限流等操作 腾讯开源的tendis&#xff0c…

ctfshow web

【nl】难了 <?php show_source(__FILE__); error_reporting(0); if(strlen($_GET[1])<4){echo shell_exec($_GET[1]); } else{echo "hack!!!"; } ?> //by Firebasky //by Firebasky ?1>nl //先写个文件 ?1*>b //这样子会把所有文件名写在b里…

JSON 无法序列化

JSON 无法序列化通常出现在尝试将某些类型的数据转换为 JSON 字符串时&#xff0c;这些数据类型可能包含不可序列化的内容。 JSON 序列化器通常无法处理特定类型的数据&#xff0c;例如日期时间对象、自定义类实例等。在将数据转换为 JSON 字符串之前&#xff0c;确保所有数据都…

「动态规划」如何求地下城游戏中,最低初始健康点数是多少?

174. 地下城游戏https://leetcode.cn/problems/dungeon-game/description/ 恶魔们抓住了公主并将她关在了地下城dungeon的右下角。地下城是由m x n个房间组成的二维网格。我们英勇的骑士最初被安置在左上角的房间里&#xff0c;他必须穿过地下城并通过对抗恶魔来拯救公主。骑士…

【Text2SQL 论文】C3:使用 ChatGPT 实现 zero-shot Text2SQL

论文&#xff1a;C3: Zero-shot Text-to-SQL with ChatGPT ⭐⭐⭐⭐ arXiv:2307.07306&#xff0c;浙大 Code&#xff1a;C3SQL | GitHub 一、论文速读 使用 ChatGPT 来解决 Text2SQL 任务时&#xff0c;few-shots ICL 的 setting 需要输入大量的 tokens&#xff0c;这有点昂贵…

MacOS M系列芯片一键配置多个不同版本的JDK

第一步&#xff1a;下载JDK。 官网下载地址&#xff1a;Java Archive | Oracle 选择自己想要下载的版本&#xff0c;一般来说下载一个jdk8和一个jdk11就够用了。 M系列芯片选择这两个&#xff0c;第一个是压缩包&#xff0c;第二个是dmg可以安装的。 第二步&#xff1a;编辑…

eclipse插件开发(二)RCP第三方库的引入方式

RCP第三方库的引入 最近在RCP开发过程中遇到JSON串与对象互转的问题&#xff0c;如何像spring开发模式一样引入第三方库呢&#xff1f;eclipse插件开发中用到p2库&#xff0c;但也支持maven库的引入。关键在于.target这个关键文件。 .target 文件用于定义一个目标平台&#x…

民主测评要做些什么?

民主测评&#xff0c;作为一种重要的民主管理工具&#xff0c;旨在通过广泛征求群众意见&#xff0c;对特定对象或事项进行客观、公正的评价。它不仅是推动民主参与、民主监督的重要手段&#xff0c;也是提升治理效能、促进社会和谐的有效途径。以下将详细介绍民主测评的主要过…

如何以非交互方式将参数传递给交互式脚本

文章目录 问题回答1. 使用 Here Document2. 使用 echo 管道传递3. 使用文件描述符4. 使用 expect 工具 参考 问题 我有一个 Bash 脚本&#xff0c;它使用 read 命令以交互方式读取命令参数&#xff0c;例如 yes/no 选项。是否有一种方法可以在非交互式脚本中调用这个脚本&…

Chrome 源码阅读:跟踪一个鼠标事件的流程

我们通过在关键节点打断点的方式&#xff0c;去分析一个鼠标事件的流程。 我们知道chromium是多进程模型&#xff0c;那么&#xff0c;我们可以推测&#xff1a;一个鼠标消息先从主进程产生&#xff0c;再通过跨进程通信发送给渲染进程&#xff0c;渲染进程再发送给WebFrame&a…

【FAS】《CN103106397B》

原文 CN103106397B-基于亮瞳效应的人脸活体检测方法-授权-2013.01.19 华南理工大学 方法 / 点评 核心方法用的是传统的形态学和模板匹配&#xff0c;亮点是双红外发射器做差分 差分&#xff1a;所述FPGA芯片控制两组红外光源&#xff08;一近一远&#xff09;交替亮灭&…

RDMA (2)

iWARP(RDMA)怎么工作的 招式1:bypass内核 非iWARP时,当应用向网络适配器发出读或者写命令时,命令穿过用户空间以及内核空间,因此需要在用户空间和内核空间间进行切换。 iWARP使用RDMA,让应用直接将命令送达到网络适配器。这规避了对内核的调用,减少了开销和延迟。 招式2…

【Kubernetes】三证集齐 Kubernetes实现资源超卖(附镜像包)

目录 插叙前言一、思考和原理二、实现步骤0. 资料包1. TLS证书签发2. 使用 certmanager 生成签发证书3. 获取secret的内容 并替换CA_BUNDLE4.部署svc deploy 三、测试验证1. 观察pod情况2. 给node 打上不需要超售的标签【可以让master节点资源不超卖】3. 资源实现超卖4. 删除还…

IP域名关系的研究与系统设计(学习某知名测绘系统)

IP域名关系库管理包括域名库检索和whois库检索&#xff0c;详情如下。 域名库检索支持以下5项功能&#xff1a; 1.通过过滤器检索 筛选条件包含IP地址、口令、工具名称、可利用的漏洞编号、创建时间&#xff1b; 2.通过关键字检索 在查询框中输入域名库名称的部分关键词&a…

计算机组成结构—IO系统概述

目录 一、I/O 系统的发展 1. 早期阶段 2. 接口模块和 DMA 阶段 3. 通道结构阶段 4. 处理机阶段 二、I/O 系统的组成 1. I/O 软件 2. I/O 硬件 三、I/O 设备 1. I/O 设备分类 2. I/O 设备的组成 在计算机中&#xff0c;除 CPU 和主存两大模块之外&#xff0c;第三个重…

Apple开发者应用商店(AppStore)描述文件及ADHOC描述文件生成

创建AD HOC描述文件 1.选中Profiles,然后点击加号创建 2.创建已注册设备可安装描述文件 3.选择要注册的id 4.选择证书 5.选择设备 6.输入文件名,点击生成 7.生成成功,点击下载

Java使用OpenCV计算两张图片相似度

业务&#xff1a;找出两个表的重复的图片。 图片在表里存的是二进制值&#xff0c;存在大量由于一些特殊情况例如扫描有差异&#xff0c;导致图片存的二进制值不同&#xff0c;但图片其实是一样来的。 所以找出两个表重复相同的图片&#xff0c;不可能只是单纯的比较二进制值…

flask招聘数据分析及展示平台-计算机毕业设计源码39292

目 录 摘要 1 绪论 1.1研究意义 1.2国内外研究进展 1.3flask框架介绍 2 1.4论文结构与章节安排 3 2 招聘数据分析及展示平台分析 4 2.1 可行性分析 4 2.2 系统流程分析 4 2.2.1数据增加流程 5 2.3.2数据修改流程 5 2.3.3数据删除流程 5 2.3 系统功能分析 5 2.3.1 功能性分…

亚马逊新品如何快速吸引流量?自养号测评助卖家一臂之力

在亚马逊平台上每天都会有大量的新品推出&#xff0c;而这些新品中有部分可能并没有什么流量和订单&#xff0c;有些可能上架后立马就能获得流量了&#xff0c;那么亚马逊上新品一般几天出单&#xff1f; 一、亚马逊上新品一般几天出单&#xff1f; 亚马逊上新品出单的时间因…