图神经网络实战(17)——深度图生成模型
- 0. 前言
- 1. 变分图自编码器
- 2. 自回归模型
- 3. 生成对抗网络
- 小结
- 系列链接
0. 前言
我们已经学习了经典的图生成算法,虽然它们能够完成图生成任务,但也存在一些问题,促使基于图神经网络 (Graph Neural Networks, GNN) 的图生成技术的出现。深度图生成模型基于 GNN
架构,比传统技术更具表达能力。然而,缺点在于它们往往过于复杂,无法像经典方法那样进行分析和理解。主要的深度生成模型架构包括:变分自编码器 (Variational Autoencoder, VAE)、生成对抗网络 (Generative Adversarial Network, GAN)、自回归模型 (Autoregressive Model)、归一化流模型 (Normalizing Flow Model) 或扩散模型 (Diffusion Model) 等,但相较而言,前三种模型更加成熟。在本节中,将介绍三类图生成模型:基于变分自编码器 (Variational Autoencoder
, VAE
) 的模型、基于自回归模型 (Autoregressive Model
) 和基于生成对抗网络 (Generative Adversarial Network
, GAN
) 的模型。
1. 变分图自编码器
我们已经知道变分自编码器 (Variational Autoencoder, VAE)可用于近似邻接矩阵,而变分图自编码器 (Variational Graph Autoencoder, VGAE) 模型由两个部分组成:编码器和解码器。编码器使用共享第一层的两个图卷积网络 (Graph Convolutional Network, GCN) 来学习每个潜正态分布的均值和方差。然后,解码器对学习到的分布进行采样,执行潜变量之间的内积。最后,得到了近似邻接矩阵 A ^ = σ ( Z T Z ) \hat A = σ(Z^TZ) A^=σ(ZTZ)。
在使用图神经网络预测链接一节中,使用 A ^ \hat A A^ 来预测链接。然而,这并不是它的唯一应用,它可以直接给出一个网络的邻接矩阵,模仿训练过程中所看到的图。除了预测链接之外,我们还可以使用这个输出来生成新的图。以下是由 VGAE
模型创建的邻接矩阵的示例:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoiddevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')transform = T.Compose([T.NormalizeFeatures(),T.ToDevice(device),T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True, add_negative_train_samples=False),
])dataset = Planetoid('.', name='Cora', transform=transform)train_data, val_data, test_data = dataset[0]from torch_geometric.nn import GCNConv, VGAEclass Encoder(torch.nn.Module):def __init__(self, dim_in, dim_out):super().__init__()self.conv1 = GCNConv(dim_in, 2 * dim_out)self.conv_mu = GCNConv(2 * dim_out, dim_out)self.conv_logstd = GCNConv(2 * dim_out, dim_out)def forward(self, x, edge_index):x = self.conv1(x, edge_index).relu()return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)model = VGAE(Encoder(dataset.num_features, 16)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()optimizer.zero_grad()z = model.encode(train_data.x, train_data.edge_index)loss = model.recon_loss(z, train_data.pos_edge_label_index) + (1 / train_data.num_nodes) * model.kl_loss()loss.backward()optimizer.step()return float(loss)@torch.no_grad()
def test(data):model.eval()z = model.encode(data.x, data.edge_index)return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)for epoch in range(301):loss = train()val_auc, val_ap = test(val_data)if epoch % 50 == 0:print(f'Epoch: {epoch:>3} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}')val_auc, val_ap = test(val_data)
print(f'\nTest AUC: {val_auc:.4f} | Test AP: {val_ap:.4f}')z = model.encode(test_data.x, test_data.edge_index)
adj = torch.where((z @ z.T) > 0.9, 1, 0)
print(adj)Epoch: 0 | Val AUC: 0.7145 | Val AP: 0.7259
Epoch: 50 | Val AUC: 0.7030 | Val AP: 0.7175
Epoch: 100 | Val AUC: 0.7359 | Val AP: 0.7561
Epoch: 150 | Val AUC: 0.8535 | Val AP: 0.8618
Epoch: 200 | Val AUC: 0.8942 | Val AP: 0.8978
Epoch: 250 | Val AUC: 0.9005 | Val AP: 0.9050
Epoch: 300 | Val AUC: 0.9101 | Val AP: 0.9138
'''
Test AUC: 0.9101 | Test AP: 0.9138
tensor([[1, 0, 1, ..., 0, 1, 1],[0, 1, 1, ..., 0, 1, 1],[1, 1, 1, ..., 0, 1, 1],...,[0, 0, 0, ..., 0, 0, 0],[1, 1, 1, ..., 0, 1, 1],[1, 1, 1, ..., 0, 1, 1]], device='cuda:0')
'''
这种技术已扩展到 VGAE
模型之外,能够输出节点和边的特征。GraphVAE
是最流行的基于 VAE
的图生成模型之一,该模型于 2018
年由 Simonovsky
和 Komodakis
提出,旨在生成逼真的分子,这需要具备区分节点(原子)和边(化学键)的能力。
GraphVAE
考虑图 G = ( A , E , F ) G= (A, E, F) G=(A,E,F) ,其中 A A A 是邻接矩阵, E E E 是边属性矩阵, F F F 是节点属性矩阵。GraphVAE
学习了具有预定节点数的图 G ~ = ( A ~ , E ~ , F ~ ) \widetilde G=(\widetilde A,\widetilde E,\widetilde F) G =(A ,E ,F ) 的概率,其中 A ~ \widetilde A A 包含节点 ( A ~ a , a ) (\widetilde A_{a, a}) (A a,a) 和边 ( A ~ a , b ) (\widetilde A_{a,b}) (A a,b) 的概率, E ~ \widetilde E E 表示边的类别概率, F ~ \widetilde F F 包含节点的类别概率。与 VGAE
相比,GraphVAE
的编码器是一个具有边条件图卷积 (conditional graph convolutions
, ECC
) 的前馈网络,其解码器是一个具有三个输出的多层感知机 (multilayer perceptron
, MLP
),整体架构如下所示:
还有许多其它基于 VAE
的图生成架构,但它们的作用并不局限于模仿图,还可以添加约束条件,引导生成的图类型:
- 添加约束的一种常用方法是在解码阶段进行检查,如约束图变分自编码器 (
Constrained Graph Variational Autoencoder
,CGVAE
)。在此架构中,编码器是一个门控图卷积网络 (Gated Graph Convolutional Network
,GGCN
),解码器是一个自回归模型。自回归解码器可以验证整个过程中每个步骤的每个约束条件 - 另一种添加约束条件的技术是使用基于
Lagrangian
的正则化器,这种正则化器计算速度更快,但生成的约束条件并不那么严格
2. 自回归模型
自回归模型 (Autoregressive Model) 也可以单独使用,自回归模型与其他模型的区别在于,模型过去的输出会作为当前输入的一部分。在此框架下,图生成成为一个连续的决策过程,既要考虑数据,又要考虑过去的决策。例如,在每一步中,自回归模型可以创建一个新节点或新链接,然后,生成的图被输入到模型中用于下一步生成,直到达到停止条件。这一过程如下图所示:
在实践中,可以使用循环神经网络 (Recurrent Neural Network, RNN) 来实现这种自回归模型。在 RNN
架构中,先前的输出被用作计算当前隐藏状态的输入。此外,RNN
还能处理任意长度的输入,这对于迭代生成图至关重要。但这种架构的计算比前馈网络慢,因为必须处理整个序列才能获得最终输出。最流行的两种 RNN
为门控递归单元 (Gated Recurrent Unit
, GRU
) 和长短期记忆 (Long Short-Term Memory
, LSTM
) 网络。
2018
年 You
等人提出了 GraphRNN
,是自回归模型在深度图生成方面的直接实现。该架构使用两个 RNN
:
- 一个图级
RNN
,用于生成节点序列(包括初始状态) - 一个边级
RNN
,用于预测每个新添加节点的连接情况
边级 RNN
将图级 RNN
的隐藏状态作为输入,然后使用自己的输出。下图展示了模型推理时的生成机制:
两个 RNN
实际上是在完成一个邻接矩阵,图级 RNN
创建的每个新节点都会增加一行和一列,而边级 RNN
会用 0
或 1
进行填充。总体而言,GraphRNN 执行以下步骤:
- 添加新节点:图级
RNN
对图进行初始化,并将其输出反馈给边级RNN
- 添加新连接:边级
RNN
会预测新节点是否与之前的每个节点相连 - 停止图生成:重复前两个步骤,直到边级
RNN
输出EOS
标记,标志着生成过程结束
GraphRNN
可以学习不同类型的图(网格、社交网络、蛋白质等),其性能完全优于传统技术。与 GraphVAE
相比,GraphRNN
是模仿给定图的首选架构。
3. 生成对抗网络
与变分自编码器 (Variational Autoencoder, VAE) 一样,生成对抗网络 (Generative Adversarial Network, GAN) 也是机器学习 (Machine Learning
, ML
) 中著名的生成模型。在 GAN
框架中,两个神经网络在零和博弈中以不同目标展开竞争。第一个神经网络是生成器 (generator
),负责创建新数据;第二个神经网络是判别器 (discriminator
),负责将每个样本分为真实样本(来自训练集)或虚假样本(由生成器创建)。
为了提升模型性能,研究人员提出了多种改进原始架构的方案。Wasserstein GAN (WGAN) 通过最小化两个概率分布之间的 Wasserstein
距离(或称推土机距离)来提高训练的稳定性。这一改进可以通过引入梯度惩罚而非原始梯度剪切进一步进行完善。
将这一框架应用于深度图生成中,与其它技术一样,GAN
可以生成图以优化某些约束条件,这在寻找具有特定性质的新化合物等应用中非常有效。由于其离散性,这一问题异常庞大且复杂。
小结
图生成是生成新图的技术,并且希望所生成的图具有真实世界中图的性质。由于传统图生成方法缺乏表达能力,因此提出了更加灵活的基于 GNN
的技术。本节中,我们介绍了三类深度图生成模型: 基于变分自编码器 (Variational Autoencoder
, VAE
) 的模型、基于自回归模型 (Autoregressive Model
) 和基于生成对抗网络 (Generative Adversarial Network
, GAN
) 的模型。
系列链接
图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接
图神经网络实战(15)——SEAL链接预测算法
图神经网络实战(16)——经典图生成算法