图神经网络(GNN)基本概念与核心原理
图神经网络(GNN)是一类专门处理图结构数据的神经网络模型 (GTAT: empowering graph neural networks with cross attention | Scientific Reports)。图结构数据由节点(表示实体)和边(表示实体间关系)构成,每个节点和边都可以带有特征信息。GNN的核心思想是通过多轮**消息传递(message passing)**来迭代更新节点的表示:每层GNN会让每个节点收集并聚合其邻居节点的特征,然后通过一个神经网络变换这些聚合信息,更新自身的表示 (Graph neural network - Wikipedia) (GTAT: empowering graph neural networks with cross attention | Scientific Reports)。这样,多层堆叠的GNN可以让信息在图中从一个节点传递到远处的节点,从而学习到图的全局结构特征。
- 图结构和特征:图由节点和边组成,节点可对应机器、任务、地理位置等实体,节点特征描述实体属性(如机器人状态、任务需求等),边可表示实体间的联系或拓扑结构。
- 消息传递与聚合:在每一层GNN中,每个节点会收集所有邻居节点的特征(如将邻居特征求和或求平均),并结合自身特征输入一个神经网络进行变换。这样,节点能“看到”局部邻域的信息,形成新的表示。
- 迭代更新与表达:通过多层GNN的迭代,每个节点的信息融合来自更远节点的影响,最终输出的节点表示(或全图表示)可用于后续任务,如节点分类、图分类或回归等。经过训练后的GNN能够自动提取图结构中的有效信息,无需手工设计特征。
经过若干层GNN后,我们可以得到每个节点或整个图的高维嵌入(embedding),并据此完成分类、回归等任务。这种基于图结构的神经网络具有很强的表达能力,能够捕捉节点间的复杂关系 (GTAT: empowering graph neural networks with cross attention | Scientific Reports) (Graph neural network - Wikipedia)。
GNN示例:基于GCN的简单实现
下面以PyTorch Geometric为例,演示一个简单的两层图卷积网络(GCN)实现,用于对图中节点进行分类。代码中对每行添加了中文注释说明。
import torch
import torch.nn.functional as F