图神经网络实战(14)——基于节点嵌入预测链接

图神经网络实战(14)——基于节点嵌入预测链接

    • 0. 前言
    • 1. 图自编码器
    • 2. 变分图自编码器
    • 3. 实现变分图自编码器
    • 小结
    • 系列链接

0. 前言

我们已经了解了如何使用图神经网络 (Graph Neural Networks, GNN) 生成节点嵌入,我们可以使用这些嵌入执行矩阵分解 (matrix factorization) 完成链接预测任务。本节将介绍两种用于链接预测的 GNN 架构——图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE)。

1. 图自编码器

图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE) 架构都是 KipfWelling2016 年所提出的。它们分别对应于两种流行的神经网络架构——自编码器 (Autoencoder) 和变分自编码器 (Variational Autoencoder, VAE)。为了便于理解,我们将首先介绍 GAEGAE 由两个模块组成:

  • 编码器 (encoder):一个经典的双层图卷积网络 (Graph Convolutional Network, GCN),使用以下方式计算节点嵌入:
    Z = G C N ( X , A ) Z=GCN(X,A) Z=GCN(X,A)
  • 解码器 (decoder):使用矩阵分解 (matrix factorization) 和 sigmoid 函数 σ σ σ 来近似邻接矩阵 A ^ \hat A A^,从而输出概率:
    A ^ = σ ( Z T Z ) \hat A=\sigma(Z^TZ) A^=σ(ZTZ)

需要注意的是,我们并不是要对节点或图进行分类,而是预测邻接矩阵 A ^ \hat A A^ 中每个元素的概率(介于 01 之间),因此使用两个邻接矩阵元素之间的二进制交叉熵损失(负对数似然)来训练 GAE
L B C E = ∑ i ∈ V , j ∈ V − A i j l o g ( A ^ i j ) − ( 1 − A i j ) l o g ( 1 − A ^ i j ) \mathcal L_{BCE}=\sum_{i\in V,j\in V}-A_{ij}log(\hat A_{ij})-(1-A_{ij})log(1-\hat A_{ij}) LBCE=iV,jVAijlog(A^ij)(1Aij)log(1A^ij)

然而,邻接矩阵通常非常稀疏,这会使 GAE 偏向于预测零值。有两种简单的方法可以修正这一偏差。首先,可以在上述损失函数中增加一个权重,使偏向于 A i i = 1 A_{ii}=1 Aii=1。其次,可以在训练过程中采样较少的零值,使标签更加均衡。
这种架构非常灵活,编码器可以换成其它类型的图神经网络 (Graph Neural Networks, GNN) (如 GraphSAGE、图同构网络 (Graph Isomorphism Network, GIN) 等),多层感知机 (Multilayer Perceptron, MLP) 也可以作为解码器,另一种改进方法是将 GAE 转换为变分图自编码器。

2. 变分图自编码器

图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE) 之间的区别与自编码器 (Autoencoder) 和变分自编码器 (Variational Autoencoder, VAE) 之间的区别相同。VGAE 不直接学习节点嵌入,而是学习正态分布,然后通过采样生成嵌入。VGAE 也由两个模块组成:

  • 编码器 (encoder):由共享第一层的两个图卷积网络 (Graph Convolutional Network, GCN) 组成。其目标是学习每个潜正态分布的参数,均值 μ μ μ (由 G C N μ GCN_μ GCNμ 学习)和方差 σ 2 σ^2 σ2 (在实践中通过 G C N σ GCN_σ GCNσ 学习其对数形式)
  • 解码器 (decoder):使用重参数化技巧 (reparametrization trick),从学习到的分布 ( μ , σ 2 ) (μ, σ^2) (μ,σ2) 中采样嵌入值 z i z_i zi, 。然后,它使用潜变量之间的内积来近似邻接矩阵 A ^ = σ ( Z T Z ) \hat A= σ(Z^TZ) A^=σ(ZTZ)

对于 VGAE,确保编码器的输出服从正态分布非常重要,因此需要在损失函数中添加一个新项,Kullback-Leibler 散度 (KL 散度),它用于测量两个分布之间的差异。VGAE 的总体损失如下,也称为证据下界 (evidence lower bound, ELBO):
L E L B O = L B C E − K L [ q ( Z ∣ X , A ) ∣ ∣ p ( Z ) ] \mathcal L_{ELBO}=\mathcal L_{BCE}-KL[q(Z|X,A)||p(Z)] LELBO=LBCEKL[q(ZX,A)∣∣p(Z)]
其中, q ( Z ∣ X , A ) q(Z|X,A) q(ZX,A) 表示编码器, p ( Z ) p(Z) p(Z) Z Z Z 的先验分布。通常可以使用ROC 曲线下面积 (area under the ROC, AUROC) 和平均精度 (average precision, AP) 这两个指标来评估模型的性能。
接下来,我们使用 PyTorch Geometric 实现 VGAE

3. 实现变分图自编码器

变分图自编码器 (Variational Graph Autoencoder, VGAE) 与其它类型的图神经网络 (Graph Neural Networks, GNN) (如 GraphSAGE、图同构网络 (Graph Isomorphism Network, GIN) 等)实现有两个主要区别:

  • 对数据集进行预处理,随机删除一些链接以进行预测
  • 创建一个编码器模型,并将其添加到 VGAE 类中,而不是直接从头开始实现 VGAE

接下来,使用 PyTorch Geometric (PyG) 构建 VGAE 模型。

(1) 首先,导入所需的库,并定义设备:

import numpy as np
import torch
import matplotlib.pyplot as plt
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoiddevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(2) 创建一个 transform 对象,对输入特征进行归一化处理,将张量转移到预定义的设备中,并随机分割链接(在本节中,我们按照 85: 5:10 的比例进行拆分),将 add_negative_train_samples 参数设置为 False,因为模型已经执行了负采样,所以数据集中不需要负采样:

transform = T.Compose([T.NormalizeFeatures(),T.ToDevice(device),T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True, add_negative_train_samples=False),
])

(3) 使用定义的 transform 对象加载 Cora 数据集:

dataset = Planetoid('.', name='Cora', transform=transform)

(4) RandomLinkSplit 方法会按预定比例拆分生成训练/验证/测试集,并存储这些数据集:

train_data, val_data, test_data = dataset[0]

(5) 接下来,实现编码器。首先,需要导入 GCNConv 和 VGAE

from torch_geometric.nn import GCNConv, VGAE

声明一个新类,在这个类中,需要三个图卷积网络 (Graph Convolutional Network, GCN) 层,一个作为共享层、一个用于近似均值 μ μ μ,第三个用于近似方差值(实践中使用对数标准差, log ⁡ σ \log\sigma logσ):

class Encoder(torch.nn.Module):def __init__(self, dim_in, dim_out):super().__init__()self.conv1 = GCNConv(dim_in, 2 * dim_out)self.conv_mu = GCNConv(2 * dim_out, dim_out)self.conv_logstd = GCNConv(2 * dim_out, dim_out)def forward(self, x, edge_index):x = self.conv1(x, edge_index).relu()return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

(6) 初始化 VGAE 并将编码器作为输入,默认情况下,VGAE 使用内积作为解码器:

model = VGAE(Encoder(dataset.num_features, 16)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

(7)train() 方法中,首先使用 model.encode() 计算嵌入矩阵 Z Z Z,此函数从学习到的分布中对样本嵌入进行采样。然后,使用 model.recon_loss() (二进制交叉熵损失)和 model.kl_loss() (KL 散度) 计算 ELBO 损失。解码器会被隐式调用来计算交叉熵损失:

def train():model.train()optimizer.zero_grad()z = model.encode(train_data.x, train_data.edge_index)loss = model.recon_loss(z, train_data.pos_edge_label_index) + (1 / train_data.num_nodes) * model.kl_loss()loss.backward()optimizer.step()return float(loss)

(8) test() 函数只需调用 VGAE 的专用方法:

@torch.no_grad()
def test(data):model.eval()z = model.encode(data.x, data.edge_index)return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)

(9) 对模型进行 301epoch 的训练,并打印 AUCAP 指标:

for epoch in range(301):loss = train()val_auc, val_ap = test(test_data)if epoch % 50 == 0:print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}') 

输出结果如下所示:

Epoch  0 | Loss: 3.4412 | Val AUC: 0.6842 | Val AP: 0.7043
Epoch 50 | Loss: 1.3321 | Val AUC: 0.6628 | Val AP: 0.6881
Epoch 100 | Loss: 1.1690 | Val AUC: 0.7512 | Val AP: 0.7526
Epoch 150 | Loss: 1.0348 | Val AUC: 0.8173 | Val AP: 0.8128
Epoch 200 | Loss: 0.9980 | Val AUC: 0.8415 | Val AP: 0.8364
Epoch 250 | Loss: 0.9698 | Val AUC: 0.8576 | Val AP: 0.8457
Epoch 300 | Loss: 0.9339 | Val AUC: 0.8727 | Val AP: 0.8620

(10) 在测试集上对模型进行评估:

test_auc, test_ap = test(test_data) 
print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')# Test AUC: 0.8727 | Test AP 0.8620

(11) 手动计算近似邻接矩阵 A ^ \hat A A^

z = model.encode(test_data.x, test_data.edge_index) 
Ahat = torch.sigmoid(z @ z.T)
print(Ahat)
'''
tensor([[0.8468, 0.5072, 0.7254,  ..., 0.7016, 0.8674, 0.8545],[0.5072, 0.8120, 0.7991,  ..., 0.4572, 0.6988, 0.6898],[0.7254, 0.7991, 0.8623,  ..., 0.5731, 0.8622, 0.8496],...,[0.7016, 0.4572, 0.5731,  ..., 0.6582, 0.6973, 0.6925],[0.8674, 0.6988, 0.8622,  ..., 0.6973, 0.9259, 0.9155],[0.8545, 0.6898, 0.8496,  ..., 0.6925, 0.9155, 0.9051]],device='cuda:0', grad_fn=<SigmoidBackward0>)
'''

VGAE 的训练速度很快,输出结果也很容易理解,但我们已经知道 GCN 并不是最具表达能力的运算符。为了提高模型的表达能力,我们需要采用更好的技术。

小结

链接预测可以帮助我们发现隐藏的关联规律,从而为网络分析、推荐系统等问题提供有效的解决方案。在本节中,介绍了如何使用图神经网络 (Graph Neural Networks, GNN) 实现链接预测,学习了基于节点嵌入的链接预测技术,包括图自编码器 (Graph Autoencoder, GAE) 和变分图自编码器 (Variational Graph Autoencoder, VGAE),并使用边级随机分割和负采样在 Cora 数据集上实现了 VGAE 模型。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/28518.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中华老字号李良济,展现百年匠心之魅力,释放千年中医药文化自信

6月14-16日&#xff0c;“潮品老字号 国货LU锋芒”江苏老字号博览会在南京隆重启幕&#xff0c;中华老字号李良济凭借过硬的品牌实力和优质的口碑再次受邀参加&#xff0c;并在展会上绽放百年匠心魅力&#xff0c;彰显千年中医药文化自信&#xff01; 百年匠心 以实力铸就荣耀…

计算机组成原理之定点乘法运算

文章目录 原码并行乘法与补码并行乘法原码算法运算规则存在的问题带符号的阵列乘法器习题原码阵列乘法器间接补码阵列乘法器直接补码阵列乘法器 补码与真值的转换 原码并行乘法与补码并行乘法 原码算法运算规则 存在的问题 理解流水式阵列乘法器&#xff08;并行乘法器&#x…

Java环境安装

下载JDK https://www.oracle.com/cn/java/technologies/downloads/#jdk22-windows 点开那个下载都可以但是要记住下载的路径因为下一步要添加环境变量 选择编辑系统环境变量 点击环境变量 点击新建 新建环境变量JAVA_HOME 并输入JDK在计算机保存的路径 打开cmd 输入java -…

GStreamer——教程——基础教程4:Time management

基础教程4&#xff1a;Time management&#xff08;时间管理&#xff09; 目标 本教程展示了如何使用GStreamer时间相关工具。特别是&#xff1a; 如何查询管道以获取流位置或持续时间等信息。如何寻找&#xff08;跳转&#xff09;到流内的不同位置&#xff08;时间&#x…

JVM调优-推荐启动参数

JVM&#xff08;Java Virtual Machine&#xff09;调优是为了提高Java应用程序的性能和稳定性。以下是一些常用的JVM启动参数及其作用&#xff0c;这些参数可以帮助优化JVM性能&#xff1a; 1. 堆内存设置&#xff1a; - -Xms<size>: 设置初始堆大小。例如&#xff0…

python模块之codecs

python 模块codecs python对多国语言的处理是支持的很好的&#xff0c;它可以处理现在任意编码的字符&#xff0c;这里深入的研究一下python对多种不同语言的处理。 有一点需要清楚的是&#xff0c;当python要做编码转换的时候&#xff0c;会借助于内部的编码&#xff0c;转换…

数据结构与算法笔记:基础篇 -递归树:如何借助树来求解递归算法的时间复杂度?

概述 我们都知道&#xff0c;递归代码的时间复杂度分析起来很麻烦。在《排序(下)》哪里讲过&#xff0c;如何用递推公式&#xff0c;求解归并排序、快速排序的时间复杂度&#xff0c;但是有些情况&#xff0c;比如快排的平均时间复杂度的分析&#xff0c;用递推公式的话&#…

《天软股票特色因子定期报告》

最新《天软股票特色因子定期报告》&#xff08;2024-06&#xff09;&#xff0c;抢先发布 内容概要如下&#xff1a; 天软特色因子A08006&#xff08;近一月日度买卖压力2&#xff09;从行业角度分析&#xff0c;在电子设备、石油石化行业表现稳定&#xff0c;无论在有效性、区…

【名词解释】Unity中的3D物理系统:触发器

在Unity的3D物理系统中&#xff0c;触发器&#xff08;Trigger&#xff09;是一种特殊的碰撞体&#xff0c;用于检测物体进入或离开一个特定区域的事件&#xff0c;但它不会像普通碰撞体那样产生物理碰撞反应。触发器通常用于实现非物理交互&#xff0c;如检测玩家进入特定区域…

复星杏脉算法面经2024年5月16日面试

复星杏脉算法面经2024年5月 面试记录&#xff1a;3个部分1. 自己介绍 2. 问八股 3.代码题先自我介绍20分钟问问题1. 梯度爆炸怎么解决&#xff0c;三个解决方案&#xff1a;梯度裁剪&#xff08;Gradient Clipping&#xff09;正则化&#xff08;Regularization&#xff09;调整…

C11与C++11关于Atomic原子类型的异同

"The C11 atomics were almost copynpasted from C11. All the work was done for C, and C (sensibly) incorporated it wholesale." 上面这句话源自&#xff1a;C11 atomic variables and the kernel [LWN.net] 翻译过来就是&#xff1a; "C11 中的原子操作…

HTML 颜色名

HTML 颜色名 HTML 颜色名是一组预定义的颜色&#xff0c;可以在 HTML 和 CSS 中使用。这些颜色名易于记忆&#xff0c;方便开发者快速选择和使用。本文将详细介绍 HTML 颜色名&#xff0c;包括它们的用途、优点以及如何在网页设计中使用它们。 HTML 颜色名的用途 HTML 颜色名…

熱門開源項目推薦

熱門開源項目推薦&#xff1a;探索未來的技術前沿 開源軟件的興起為科技領域帶來了革命性的變化&#xff0c;不僅促進了技術的發展&#xff0c;還創造了一個開放和協作的環境&#xff0c;讓全球的開發者可以共同參與、創新和改進。近年來&#xff0c;開源大模型成為了技術社區…

时政|连续高温

危害 会对人的健康乃至生命安全产生严重影响&#xff0c;近年来&#xff0c;几乎每年都有因热致死的病例面对高温天气&#xff0c;不能仅仅止于调侃“天热”&#xff0c;止于变着花样表达自己的感受&#xff0c;还是要提高警惕&#xff0c;重视并防范高温导致的中暑、热痉挛、…

nginx+tomcat+nfs →web集群部署

nginxtomcatnfs →web集群部署 一.安装前介绍 NGINX是一个高性能的Web服务器和反向代理服务器。它能够处理静态内容&#xff0c;缓存请求结果&#xff0c;以及将请求转发给后端服务器。通过反向代理&#xff0c;NGINX能够实现请求的负载均衡、安全性增强、SSL加密等功能。此外…

Linux中文件查找相关命令比较

Linux中与文件定位的命令有find、locate、whereis、which&#xff0c;type。 一、find find命令最强&#xff0c;能搜索各种场景下的文件&#xff0c;需要配合相关参数&#xff0c;搜索速度慢。在文件系统中递归查找文件。 find /path/to/search -name "filename"…

第67集《摄大乘论》

《摄大乘论》&#xff0c;和尚尼慈悲、诸位法师、诸位居士&#xff0c;阿弥陀佛&#xff01;(阿弥陀佛&#xff01;)请大家打开《讲义》第二二六页&#xff0c;庚十、业。 这一大科是讲到法身的功德。我们从前面的学习&#xff0c;可以把法身的功德分两部分来作个总结&#xf…

位运算算法:编程世界中的魔法符号

✨✨✨学习的道路很枯燥&#xff0c;希望我们能并肩走下来! 文章目录 目录 文章目录 前言 一. 常见位运算总结 二、常见位运算题目 2.1 位1的个数 2.2 比特数记位&#xff08;典型dp&#xff09; 2.3 汉明距离 2.4 只出现一次的数字&#xff08;1&#xff09; 2.5 只出…

【JVM】CMS 收集器的垃圾收集过程

CMS&#xff08;Concurrent Mark-Sweep&#xff09;收集器是Java虚拟机&#xff08;JVM&#xff09;中的一种垃圾收集器&#xff0c;它主要面向老年代&#xff08;Old Generation&#xff09;的垃圾回收。CMS收集器的目标是最小化垃圾收集的停顿时间&#xff0c;从而提高应用程…

OpenGL系列(六)变换

在三角形和纹理贴图示例中&#xff0c;顶点使用的是归一化设备坐标&#xff0c;在该坐标系下&#xff0c;顶点的每个轴的取值为-1到1&#xff0c;超出范围的顶点不可见。 基于归一化设备坐标的物体的形状随着设备的大小变换而变化&#xff0c;这里产生的第一个问题是&#xff0…