图神经网络实战(15)——SEAL链接预测算法

图神经网络实战(15)——SEAL链接预测算法

    • 0. 前言
    • 1. SEAL 框架
      • 1.1 基本原理
      • 1.2 算法流程
    • 2. 实现 SEAL 框架
      • 2.1 数据预处理
      • 2.2 模型构建与训练
    • 小结
    • 系列链接

0. 前言

我们已经学习了基于节点嵌入的链接预测算法,这种方法通过学习相关的节点嵌入来计算链接可能性。接下来,我们介绍另一类方法,通过查看目标节点周围的局部邻域执行链接预测任务,这类技术称为基于子图的算法,由 SEAL 广泛使用,可以说 SEAL 表示用于链接预测的子图、嵌入和属性 (Subgraphs, Embeddings, and Attributes for Link prediction) 的缩写,但并不完全准确。在本节中,我们将介绍 SEAL 框架,并使用 PyTorch Geometric 实现该框架。

1. SEAL 框架

1.1 基本原理

SEALZhangChen2018 年提出,是一个学习图结构特征以进行链接预测的框架。它将目标节点 ( x , y ) (x,y) (x,y) 和它们的 k 跳 (k-hop) 邻居所形成的子图定义为封闭子图 (enclosing subgraph)。每个封闭子图(而非整个图)都被用作预测链接可能性的输入。从另一个角度来看,SEAL 自动学习了一种用于链接预测的局部启发式方法。

1.2 算法流程

SEAL 框架包括三个步骤:

  1. 封闭子图提取 (Enclosing subgraph extraction),包括提取一组真实链接和一组虚假链接(负抽样)来形成训练数据
  2. 节点信息矩阵构建 (Node information matrix construction),包括节点标记、节点嵌入和节点特征三个部分
  3. 图神经网络 (Graph Neural Networks, GNN) 训练 (GNN training),将节点信息矩阵作为输入,并输出链接的可能性

这些步骤可以用下图进行总结:

SEAL 框架

封闭子图提取是一个简单的过程,列出目标节点及其 k 跳邻居,以提取它们的边和特征。k 值越大,SEAL 所能学习到的启发式算法的质量就越高,但同时也会创建更大、计算开销更大的子图。
节点信息构建的第一个部分是节点标记 (node labeling)。这一过程为每个节点分配一个特定的编号,如果没有进行标记,GNN 就无法区分目标节点和上下文节点(目标节点的邻居)。它还融合了距离,用来描述节点的相对位置和结构重要性。
在实践中,目标节点 x x x y y y 必须共享一个唯一的标签,以确定它们是目标节点。对于上下文节点 i i i j j j,如果它们与目标节点的距离相同,则必须共享相同的标签—— d ( i , x ) = d ( j , x ) d(i, x) = d(j, x) d(i,x)=d(j,x) d ( i , y ) = d ( j , y ) d(i, y) = d(j, y) d(i,y)=d(j,y)。我们称这种距离为双半径 (double radius),表示为 ( d ( i , x ) , d ( i , y ) ) (d(i, x), d(i, y)) (d(i,x),d(i,y))
SEAL 中使用双半径节点标记 (Double-Radius Node Labeling, DRNL) 算法,其工作原理如下:

  1. 首先,将标签 1 分配给节点 x x x y y y
  2. 将标签 2 分配给半径为 (1,1) 的节点
  3. 将标签 3 分配给半径为 (1,2)(2,1) 的节点
  4. 将标签 4 分配给半径为 (1,3)(3,1) 的节点,以此类推

DRNL 函数的数学表达式如下:
f ( i ) = 1 + m i n ( d ( i , x ) , d ( i , y ) ) + ( d / 2 ) [ ( d / 2 ) + ( d % 2 ) − 1 ] f(i)=1+min(d(i,x),d(i,y))+(d/2)[(d/2)+(d\%2)-1] f(i)=1+min(d(i,x),d(i,y))+(d/2)[(d/2)+(d%2)1]
其中, d = d ( i , x ) + d ( i , y ) d= d(i, x) + d(i, y) d=d(i,x)+d(i,y) ( d / 2 ) (d/2) (d/2) ( d % 2 ) (d\%2) (d%2) 分别是 d d d 除以 2 2 2 的整数商和余数。最后,对这些节点标签进行独热编码 (one-hot encode)。
节点信息矩阵构建过程中的其它两个部分比较容易获得。节点嵌入是可选的,可以使用其他算法(如 Node2Vec )计算。然后,将它们与节点特征和独热编码标签连接起来,构建最终的节点信息矩阵 (node information matrix)。
最后,训练 GNN,利用封闭子图的信息和邻接矩阵来预测链接。为此,SEAL 使用了深度图卷积神经网络 (Deep Graph Convolutional Neural Network, DGCNN),该架构执行以下三个步骤:

  1. 使用数个图卷积网络 (Graph Convolutional Network, GCN) 层计算节点嵌入,然后将其串联起来,类似于图同构网络 (Graph Isomorphism Network, GIN)
  2. 使用全局排序池化层按照一致的顺序排列这些嵌入,然后再将它们深入到卷积层,而卷积层不具备置换不变性
  3. 使用传统的卷积层和全连接层应用于排序后图表示,并输出链接概率

DGCNN 模型使用二进制交叉熵损失进行训练,输出的概率介于 01 之间

2. 实现 SEAL 框架

SEAL 框架需要进行大量的预处理,以提取并标注封闭子图。接下来,我们使用 PyTorch Geometric 来实现 SEAL 框架。

2.1 数据预处理

(1) 首先,导入所有必要的库:

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.sparse.csgraph import shortest_pathimport torch.nn.functional as F
from torch.nn import Conv1d, MaxPool1d, Linear, Dropout, BCEWithLogitsLossfrom torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, aggr
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

(2) 加载 Cora 数据集,并应用链接级随机拆分:

transform = RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True)
dataset = Planetoid('.', name='Cora', transform=transform)
train_data, val_data, test_data = dataset[0]

(3) 链接级随机拆分会在数据对象中创建新字段,用于存储每条正样本边(真实的边)和负样本边(虚假的边)的标签和索引:

print(train_data)# Data(x=[2708, 1433], edge_index=[2, 8976], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], pos_edge_label=[4488], pos_edge_label_index=[2, 4488], neg_edge_label=[4488], neg_edge_label_index=[2, 4488])

(4) 创建函数 seal_processing() 处理拆分后的数据集,并获得带有独热编码节点标签和节点特征的封闭子图,使用列表 data_list 存储这些子图:

def seal_processing(dataset, edge_label_index, y):data_list = []

对于数据集中的每一对节点(源和目的节点),提取 k 跳邻居(本节中 k = 2):

    for src, dst in edge_label_index.t().tolist():sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph([src, dst], 2, dataset.edge_index, relabel_nodes=True)src, dst = mapping.tolist()

使用双半径节点标记 (Double-Radius Node Labeling, DRNL) 函数计算距离。首先,从子图中删除目标节点:

        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)sub_edge_index = sub_edge_index[:, mask1 & mask2]

根据上一个子图计算源节点和目标节点的邻接矩阵:

        src, dst = (dst, src) if src > dst else (src, dst)adj = to_scipy_sparse_matrix(sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()idx = list(range(src)) + list(range(src + 1, adj.shape[0]))adj_wo_src = adj[idx, :][:, idx]idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))adj_wo_dst = adj[idx, :][:, idx]

计算每个节点与源节点/目标节点之间的距离:

        # Calculate the distance between every node and the source target noded_src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)d_src = np.insert(d_src, dst, 0, axis=0)d_src = torch.from_numpy(d_src)# Calculate the distance between every node and the destination target noded_dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst-1)d_dst = np.insert(d_dst, src, 0, axis=0)d_dst = torch.from_numpy(d_dst)

计算子图中每个节点的节点标签 z z z

        dist = d_src + d_dstz = 1 + torch.min(d_src, d_dst) + dist // 2 * (dist // 2 + dist % 2 - 1)z[src], z[dst], z[torch.isnan(z)] = 1., 1., 0.z = z.to(torch.long)

在本节中,并未使用节点嵌入,但仍将特征和独热编码标签串联起来,以构建节点信息矩阵:

        node_labels = F.one_hot(z, num_classes=200).to(torch.float)node_emb = dataset.x[sub_nodes]node_x = torch.cat([node_emb, node_labels], dim=1)

创建一个 Data 对象并将其附加到列表 data_list 中,作为函数的最终输出:

        data = Data(x=node_x, z=z, edge_index=sub_edge_index, y=y)data_list.append(data)return data_list

(5) 调用 deal_processing 提取每个数据集的封闭子图。将正样本和负样本分开,以获得正确的预测标签:

train_pos_data_list = seal_processing(train_data, train_data.pos_edge_label_index, 1)
train_neg_data_list = seal_processing(train_data, train_data.neg_edge_label_index, 0)val_pos_data_list = seal_processing(val_data, val_data.pos_edge_label_index, 1)
val_neg_data_list = seal_processing(val_data, val_data.neg_edge_label_index, 0)test_pos_data_list = seal_processing(test_data, test_data.pos_edge_label_index, 1)
test_neg_data_list = seal_processing(test_data, test_data.neg_edge_label_index, 0)

(6) 合并正负数据列表,重建训练、验证和测试数据集:

train_dataset = train_pos_data_list + train_neg_data_list
val_dataset = val_pos_data_list + val_neg_data_list
test_dataset = test_pos_data_list + test_neg_data_list

(7) 创建数据加载器,使用批数据训练 GNN

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

2.2 模型构建与训练

(1) 定义 DGCNN 类,其中参数 k 表示每个子图的节点数:

class DGCNN(torch.nn.Module):def __init__(self, dim_in, k=30):super().__init__()

创建四个 GCN 层,设定隐藏维度为 32

        self.gcn1 = GCNConv(dim_in, 32)self.gcn2 = GCNConv(32, 32)self.gcn3 = GCNConv(32, 32)self.gcn4 = GCNConv(32, 1)

实例化全局排序池化层 (深度图卷积神经网络 (Deep Graph Convolutional Neural Network, DGCNN) 架构的核心):

        self.global_pool = aggr.SortAggregation(k=k)

全局排序池化层提供的节点排序使我们能够使用传统的卷积层:

        self.conv1 = Conv1d(1, 16, 97, 97)self.conv2 = Conv1d(16, 32, 5, 1)self.maxpool = MaxPool1d(2, 2)

最后,实例化多层感知机 (Multilayer Perceptron, MLP) 用于获取预测:

        self.linear1 = Linear(352, 128)self.dropout = Dropout(0.5)self.linear2 = Linear(128, 1)

forward() 方法中,计算每个 GCN 的节点嵌入,并将结果串联起来:

    def forward(self, x, edge_index, batch):# 1. Graph Convolutional Layersh1 = self.gcn1(x, edge_index).tanh()h2 = self.gcn2(h1, edge_index).tanh()h3 = self.gcn3(h2, edge_index).tanh()h4 = self.gcn4(h3, edge_index).tanh()h = torch.cat([h1, h2, h3, h4], dim=-1)

对串联结果依次应用全局排序池化、卷积层和全连接层:

        # 2. Global sort poolingh = self.global_pool(h, batch)# 3. Traditional convolutional and dense layersh = h.view(h.size(0), 1, h.size(-1))h = self.conv1(h).relu()h = self.maxpool(h)h = self.conv2(h).relu()h = h.view(h.size(0), -1)h = self.linear1(h).relu()h = self.dropout(h)h = self.linear2(h).sigmoid()return h

(2) 将模型实例化,并使用 Adam 优化器和二进制交叉熵损失对其进行训练:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DGCNN(train_dataset[0].num_features).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
criterion = BCEWithLogitsLoss()

(3) 创建 train() 函数用于批训练:

def train():model.train()total_loss = 0for data in train_loader:data = data.to(device)optimizer.zero_grad()out = model(data.x, data.edge_index, data.batch)loss = criterion(out.view(-1), data.y.to(torch.float))loss.backward()optimizer.step()total_loss += float(loss) * data.num_graphsreturn total_loss / len(train_dataset)

(4)test() 函数中,计算 ROC AUC 分数和平均精度,以比较 SEAL 和变分图自编码器 (Variational Graph Autoencoder, VGAE) 的性能:

@torch.no_grad()
def test(loader):model.eval()y_pred, y_true = [], []for data in loader:data = data.to(device)out = model(data.x, data.edge_index, data.batch)y_pred.append(out.view(-1).cpu())y_true.append(data.y.view(-1).cpu().to(torch.float))auc = roc_auc_score(torch.cat(y_true), torch.cat(y_pred))ap = average_precision_score(torch.cat(y_true), torch.cat(y_pred))return auc, ap

(5)DGCNN 进行 31epoch 的训练:

for epoch in range(31):loss = train()val_auc, val_ap = test(val_loader)print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}')

模型训练过程监测

(6) 最后,在测试数据集上对其进行测试:

test_auc, test_ap = test(test_loader)
print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')# Test AUC: 0.7899 | Test AP 0.8174

可以看到,使用 SEAL 框架得到的结果与使用 VGAE 得到的结果(AUC0.8727AP0.8620) 相似。从理论上讲,基于子图的方法(如 SEAL )比基于节点的方法(如 VGAE )更具表达能力,基于子图的方法通过明确考虑目标节点周围的整个邻域来捕捉更多信息。通过 k 参数增加所考虑的邻域数量,可以进一步提高 SEAL 的准确性。

小结

链接预测是指利用图数据中已知的节点和边的信息,来推断图中未知的连接关系或者未来可能出现的连接关系,在机器学习和数据挖掘等领域具有广泛的应用。本节中介绍了用于链接预测的 SEAL 框架,其侧重于子图表示,每个链接周围的邻域作为预测链接概率的输入。并使用边级随机分割和负采样在 Cora 数据集上实现了 SEAL 模型执行链接预测任务。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/35186.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

决策树回归原理详解及Python代码示例

决策树回归原理详解 决策树回归(Decision Tree Regression)是一种非参数监督学习方法,它使用树形结构来对目标变量进行预测。与线性回归模型不同,决策树回归不需要预先假设数据的分布形式,因此能够很好地处理非线性和高…

2024年上半年软件设计师上午真题及答案解析

1.在计算机网络协议五层体系结构中,( B )工作在数据链路层。 A.路由器 B.以太网交换机 C.防火墙 D.集线器 网络层:路由器、防火墙 数据链路层:交换机、网桥 物理层:中继器、集线器 2.软件交付之后&#xff…

数据可视化期末考试(编程)

1.KNN 1.新增数据的分类 import pandas as pd # 您的原始数据字典 data { 电影名称: [电影1, 电影2, 电影3, 电影4, 电影5], 打斗镜头: [10, 5, 108, 115, 20], 接吻镜头: [110, 89, 5, 8, 200], 电影类型: [爱情片, 爱情片, 动作片, 动作片, 爱情片] } …

uni-app uni-data-picker级联选择器无法使用和清除选中的值

出现问题&#xff1a; 使用点击右边的叉号按钮无法清除已经选择的uni-data-picker值 解决办法&#xff1a; 在uni-app uni-data-picker使用中&#xff0c;要添加v-model&#xff0c;v-model在官网的示例中没有体现&#xff0c;但若不加则无法清除。 <uni-data-picker v-m…

OpenAI用GPT-4o打造癌症筛查AI助手;手机就能检测中风,准确率达 82%!中国气象局发布AI气象大模型...

AI for Science 企业动态速览—— * 皇家墨尔本大学用 AI 检测患者中风&#xff0c;准确率达 82% * OpenAI 用 GPT-4o 模型打造癌症筛查 AI 助手 * 中国气象局发布 AI 气象大模型风清、风雷、风顺 * AI 药企英矽智能&#xff1a;小分子抑制剂已完成中国 IIa 期临床试验全部患者…

GPT-5智能新纪元的曙光

在美国达特茅斯工程学院周四公布的采访中&#xff0c;OpenAI首席技术官米拉穆拉蒂被问及GPT-5是否会在明年发布&#xff0c;给出了肯定答案并表示将在一年半后发布。穆拉蒂在采访中还把GPT-4到GPT-5的飞跃描述为高中生到博士生的成长。 这一爆炸性的消息&#xff0c;震动了整体…

linux下进度条的实现

目录 一、代码一版 1.processbar.h 2.processbar.c 3.main.c 二、代码二版 1.processbar.h 2.processbar.c 3.main.c 三、改变文字颜色 一、代码一版 使用模块化编程 1.processbar.h #include<stdio.h> #define capacity 101 //常量使用宏定义 #define style…

Perl文件句柄深度解析:掌握文件操作的核心

Perl中的文件句柄是进行文件输入输出操作的关键。它们提供了一种机制&#xff0c;允许Perl脚本打开文件、读写数据、定位文件指针&#xff0c;以及关闭文件。理解文件句柄的使用对于编写高效的Perl脚本至关重要。本文将深入探讨Perl文件句柄的概念、使用方法和最佳实践。 1. 文…

【Pytorch使用教程】torch.backends.cudnn.benchmark = True的作用

在 PyTorch 中,设置 torch.backends.cudnn.benchmark = True 是一种优化深度学习应用程序性能的方法,特别是当你有固定输入大小的时候。 解释 CuDNN:CUDA Deep Neural Network library(CuDNN)是 NVIDIA 提供的一个 GPU 加速库,用于深度神经网络。PyTorch 在底层使用 Cu…

代码随想录——买股票的最佳时机Ⅱ(Leecode122)

添加链接描述 贪心 局部最优&#xff1a;手机每天的正利润 全局最优&#xff1a;求最大利润 class Solution {public int maxProfit(int[] prices) {int res 0;for(int i 1; i < prices.length; i){res Math.max(prices[i] - prices[i - 1], 0);}return res;} }

【计算机视觉】mmcv库详细介绍

文章目录 MMVC库概览特点和优势主要组件应用案例示例一:数据加载和处理示例二:模型训练和验证MMVC库概览 MMCV 是一个用于计算机视觉研究的开源库,它为各种视觉任务提供了底层的、高度优化的 API。该库涵盖了从数据加载到模型训练的各个方面,广泛应用于开源项目,如 MMDet…

webstorm无法识别tsconfig.json引用项目配置文件中的路径别名

问题 vite项目模板中&#xff0c;应用的ts配置内容写在tsconfig.app.json文件中&#xff0c;并在tsconfig.json通过项目引用的方式导入 {"files": [],"references": [{"path": "./tsconfig.app.json"},{"path": "./t…

2024-06-25 问AI: 在大语言模型中, Hugging Face 是什么?

文心一言 Hugging Face 在大语言模型领域中是一个非常重要的存在&#xff0c;它主要提供了一系列自然语言处理&#xff08;NLP&#xff09;相关的工具和资源。以下是关于 Hugging Face 的详细介绍&#xff1a; 公司背景&#xff1a;Hugging Face 是一家成立于2016年的开源模型…

2024年第十五届蓝桥杯青少组大赛8月24日开启

据蓝桥杯青少组官网显示&#xff0c;2024年第十五届蓝桥杯青少组大赛8月24日开启。 蓝桥杯青少组历届题库地址&#xff1a;http://www.6547.cn/question/cat/2 蓝桥杯青少组历届真题下载&#xff1a;http://www.6547.cn/wenku/list/10

python基础1.2----爬虫基础

python基础内容之爬虫 ## 1. 关于爬虫的特殊性 爬虫是一个很蛋疼的东西, 可能今天讲解的案例. 明天就失效了. 所以, 不要死盯着一个网站干. 要学会见招拆招(爬虫的灵魂) 爬虫程序如果编写的不够完善. 访问频率过高. 很有可能会对服务器造成毁灭性打击, 所以, 不要死盯着一个网…

MySQL用户管理和高级SQL语句

一、用户管理 1.新建用户 mysql> create user zhangsanlocalhost identified by pwd123; Query OK, 0 rows affected (0.00 sec)mysql> create user lisilocalhost identified by pwd123; Query OK, 0 rows affected (0.00 sec)mysql> create user wangwulocalhost …

统一视频接入平台LntonCVS视频共享交换平台智慧景区运用方案

随着夏季的到来&#xff0c;各地景区迎来了大量游客&#xff0c;而景区管理面临的挑战也愈加严峻&#xff0c;尤其是安全问题显得格外突出。 视频监控在预防各类安全事故方面发挥着重要作用&#xff0c;不论是自然景区还是人文景区&#xff0c;都潜藏着诸多安全隐患&#xff0…

每日一道算法题 成绩排序

题目 成绩排序_牛客题霸_牛客网 (nowcoder.com) Python nint(input()) flagint(input()) ans[] for _ in range(n):name,scoreinput().split( )ans.append([name,int(score)]) ans.sort(keylambda x:x[1],reverse not flag)for e in ans:print(e[0],e[1],sep )C #include &…

排序之插入排序----直接插入排序和希尔排序(1)

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 排序之插入排序----直接插入排序和希尔排序(1) 收录于专栏【数据结构初阶】 本专栏旨在分享学习数据结构学习的一点学习笔记&#xff0c;欢迎大家在评论区交流讨…

图形编辑器基于Paper.js教程04: Paper.js中的基础知识

背景 了解paper.js的基础知识&#xff0c;在往后的开发过程中会让你如履平地。 基础知识 paper.js 提供了两种编写方式&#xff0c;一种是纯粹的JavaScript编写&#xff0c;还有一种是使用官方提供的PaperScript。 区别就是在于&#xff0c;调用paper下的字对象是否需要加pa…