机器学习第四十九周周报 GT

文章目录

  • week49 GY
  • 摘要
  • Abstract
    • 1. 题目
    • 2. Abstract
    • 3. 网络结构
      • 3.1 graphon
      • 3.2 框架概览
    • 4. 文献解读
      • 4.1 Introduction
      • 4.2 创新点
      • 4.3 实验过程
        • 4.3.1 有效性
        • 4.3.2 可转移性
        • 4.3.3 消融研究
        • 4.3.4 运行时间
    • 5. 结论
    • 6.代码复现
      • 小结
      • 参考文献

week49 GY

摘要

本周阅读了题为Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns的论文。该文将图的预训练性能不理想归因于预训练与下游数据集之间的结构分歧。此外,研究将这种差异的原因确定为预训练和下游图之间生成模式的差异。基于理论分析,该文提出了一种基于graphon的GNN微调策略G-TUNING,以使预训练的模型适应下游数据集。最后,实证证明了G-TUNING的有效性。

Abstract

This week's report reviews the paper entitled Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns. This paper attributes the unsatisfactory downstream performance of graph pre-training to the structural discrepancy between pre-training and downstream datasets. Furthermore, the study identifies the cause of this disparity as the difference in generative patterns between pre-training and downstream graphs. Based on theoretical analysis, the paper proposes a graphon-based GNN fine-tuning strategy named G-TUNING to adapt the pre-trained model to downstream datasets. Finally, empirical evidence demonstrates the effectiveness of G-TUNING.

1. 题目

标题:Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns

作者:Yifei Sun, Qi Zhu, Yang Yang, Chunping Wang, Tianyu Fan, Jiajun Zhu, Lei Chen

发布:AAAI2024

链接:https://export.arxiv.org/abs/2312.13583

代码链接:https://github.com/zjunet/G-Tuning

2. Abstract

该文认为结构差异的根本原因是预训练图与下游图之间生成模式的差异,并据此提出G-TUNING来保留下游图的生成模式。给定一个下游图G,核心思想是调整预训练的GNN,使其能够重建G的生成模式,即其graphon W。然而,精确重建graphon在计算上代价高昂。为克服这一挑战,文中给出了一个理论分析,证明存在一组称为graphon bases的替代graphon:利用这些graphon bases的线性组合,可以有效地近似W。这一理论发现构成了所提模型的基础,使其能够高效地学习graphon bases及其相应系数。与现有算法相比,G-TUNING在域内和域外迁移学习实验上分别平均提高了0.5%和2.6%。

3. 网络结构

(原文图:G-TUNING 网络结构/框架示意图)

3.1 graphon

graphon是"graph function"(图函数)的缩写,可以理解为节点数不可数的图,或图生成模型的一种推广。

graphon表示图的生成模式 P(G;Θ)。形式上,graphon是一个连续对称函数 W:[0,1]^2→[0,1]。给定两个点 u_i, u_j∈[0,1] 作为"节点",W(u_i,u_j)∈[0,1] 表示它们之间形成一条边的概率。

graphon的主要思想是:从观察到的图中抽取子图时,随着子图规模的增大,这些子图的结构与原图的结构越来越相似,并在某种意义上收敛到一个极限对象,即graphon。收敛性通过同态密度的收敛来定义。同态密度 t(F,G) 度量图F在图G中以同态方式出现的相对频率:t(F,G) = |hom(F,G)| / |V_G|^{|V_F|},可以看作从F的顶点到G的顶点的一个随机映射恰为同态的概率。因此,收敛性可以形式化为 lim_{n→∞} t(F,G_n) = t(F,W)。当graphon作为图的生成模式时,从 P(G;W) 中抽取的N个节点的图G,其邻接矩阵A按下式采样:
v ∼ U ( 0 , 1 ) , v ∈ V ; A i j ∼ Ber ( W ( v i , v j ) ) , ∀ i , j ∈ [ N ] v\sim \mathbb U(0,1), v\in V;A_{ij}\sim \text{Ber}(W(v_i,v_j)),\forall i,j\in [N] vU(0,1),vV;AijBer(W(vi,vj)),i,j[N]
现有研究主要采用二维阶跃函数来表示graphon,这样的阶跃函数可以看作一个矩阵。

沿用上述工作,本文采用阶跃函数 W∈[0,1]^{D×D} 来表示一个graphon,其中D是一个超参数。
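
下面给出一个最小示意(假设性示例,函数名为说明而设,非论文官方实现):按上式从一个 D×D 的阶跃函数 graphon W 采样生成一张 N 个节点的图的邻接矩阵。

import torch

def sample_graph_from_graphon(W: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """按 v~U(0,1), A_ij~Ber(W(v_i,v_j)) 从阶跃函数 graphon W∈[0,1]^{D×D} 采样邻接矩阵。"""
    D = W.shape[0]
    v = torch.rand(num_nodes)                      # 每个节点对应 [0,1] 上的隐变量 v_i
    idx = torch.clamp((v * D).long(), max=D - 1)   # v_i 落入阶跃函数的哪个分块
    probs = W[idx][:, idx]                         # P(A_ij = 1) = W(v_i, v_j)
    A = torch.bernoulli(torch.triu(probs, diagonal=1))
    return A + A.t()                               # 无向图:对称化,对角线为 0

# 用法示意:一个 4x4 的随机对称 graphon,采样 10 个节点的图
W = torch.rand(4, 4)
W = (W + W.t()) / 2
A = sample_graph_from_graphon(W, num_nodes=10)
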

3.2 框架概览

G-TUNING旨在通过保留生成模式,使预训练的GNN适应微调图。在微调过程中,预训练的GNN Φ接收下游图 G_t = {G_1,…,G_n},并将其送入任务特定层 f_φ,用微调标签Y进行训练。对于特定图 G_i=(A,X),通过预训练模型Φ获得预训练节点嵌入H:
\mathcal L_{task}=\mathcal L_{CE}(f_{\phi}(H),Y),\quad H=\Phi(A,X)
原有微调策略可能无法提升性能,因为预训练图与微调图之间存在较大差异,即出现负迁移。为缓解这一问题,该文建议通过重建下游图的graphon W(整体工作流程见原文图2),使预训练的GNN Φ保留下游图 G_t 的生成模式。在微调开始时,嵌入H仍带有来自预训练数据的偏差,因此需要同时利用下游图的嵌入H和图结构A来进行重建。

具体来说,设计了一个graphon重构模块Ω来完成重建:Ω通过辅助损失 L_aux,为每个下游图逼近一个估计出的oracle graphon(即 W∈[0,1]^{D×D},D为oracle graphon的大小)。最后,在G-TUNING的框架下(原文图2),利用下游任务损失和重构损失共同优化预训练的GNN编码器Φ、任务层 f_φ 和graphon重构模块Ω的参数,如下所示:
\mathcal L=\mathcal L_{task}+\lambda\, \mathcal L_{\text{G-TUNING}}(W,\hat W)
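
一个极简的训练步骤示意(假设性伪实现,其中 model、task_head、graphon_recon、gw_loss、W_oracle 等名字均为说明而虚构):总损失 = 下游任务交叉熵 + λ·graphon 重构损失。

import torch
import torch.nn.functional as F

lam = 0.1  # 超参数 λ(假设取值)

def finetune_step(model, task_head, graphon_recon, optimizer, A, X, y, W_oracle, gw_loss):
    """单步微调:L = L_task + λ * L_G-TUNING(W, W_hat)。所有模块名均为示意。"""
    H = model(A, X)                        # 预训练 GNN Φ 得到节点嵌入 H
    logits = task_head(H)                  # 任务特定层 f_φ
    loss_task = F.cross_entropy(logits, y)
    W_hat = graphon_recon(A, H)            # graphon 重构模块 Ω 估计出的 W_hat
    loss = loss_task + lam * gw_loss(W_oracle, W_hat)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
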
近似graphon的一种直接方法,是学习一个从图结构A和节点嵌入H到目标graphon W的映射函数。

文中首先建立了一个graphon分解定理,并利用它进行高效的graphon近似。具体来说,任何graphon都可以通过graphon bases B_k∈B 的线性组合来重构。

综上所述,将graphon重构模块Ω设计为另一个GNN,把编码后的节点表示H和图结构A转换为系数 α = {α_1,…,α_C}:
\alpha =\Psi (A,H)
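
下面是graphon解码过程的一个最小示意(假设性实现,GraphonDecoder 及其参数均为说明用的名称;为简化起见用均值池化代替论文中 Ψ(A,H) 的GNN读出):由节点嵌入产生系数 α,再对 C 个可学习的 graphon bases 做线性组合得到近似的 W_hat。

import torch
import torch.nn as nn

class GraphonDecoder(nn.Module):
    """示意:学习 C 个 M×M 的 graphon bases,并由图级表示产生组合系数 α。"""
    def __init__(self, hidden_dim: int, num_bases: int = 8, base_size: int = 16):
        super().__init__()
        self.bases = nn.Parameter(torch.rand(num_bases, base_size, base_size))  # graphon bases B_k
        self.coef_head = nn.Linear(hidden_dim, num_bases)                        # 产生系数 α

    def forward(self, H: torch.Tensor) -> torch.Tensor:
        g = H.mean(dim=0)                                   # 简化:均值池化代替 Ψ(A,H) 中的 GNN
        alpha = torch.softmax(self.coef_head(g), dim=-1)    # 系数 α = {α_1,...,α_C}
        B = (self.bases + self.bases.transpose(1, 2)) / 2   # 保持每个 base 对称
        W_hat = torch.sigmoid((alpha.view(-1, 1, 1) * B).sum(dim=0))  # 线性组合后压缩到 [0,1]
        return W_hat

# 用法示意
decoder = GraphonDecoder(hidden_dim=300)
H = torch.randn(20, 300)       # 20 个节点的嵌入
W_hat = decoder(H)             # 得到 16x16 的近似 graphon
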
复杂度分析:现在分析G-TUNING相对于普通微调的额外时间复杂度。设|V|和|E|为平均节点数和边数,d为隐藏维度,C为graphon bases的数量。G-TUNING的额外时间复杂度包括两部分:(i) graphon解码器耗时 O(CM^2 + |V|d);(ii) oracle graphon估计耗时 O(|E|D + |V|D^2),其中D为oracle graphon的大小。因此,总体额外时间复杂度为 O(|E|D + |V|D^2 + CM^2 + |V|d);在 M, D ≪ |V| 的假设下,这与普通微调过程的 O(|E|d + |V|d) 处于相同的数量级。

4. 文献解读

4.1 Introduction

在本文中,目标是通过提出一种微调策略G-TUNING来解决这些挑战,该策略与预训练数据和预训练算法无关。具体来说,它在微调期间对下游图执行graphon重建。为了实现高效的重建,文中给出了一个理论结果(定理1):给定一个graphon W,可以找到一组其他的graphon(称为graphon bases),其线性组合能够逼近W。随后,开发了一个graphon解码器,将预训练模型输出的嵌入转换为一组系数;这些系数与结构感知的可学习graphon bases相结合,得到重构的graphon。为了保证重构的保真度,引入了基于GW差异(Gromov-Wasserstein discrepancy)的损失,从而最小化近似graphon与oracle graphon之间的距离(Xu et al. 2021)。此外,通过优化所提出的G-TUNING目标,可以获得关于任务相关判别性子图的可证明结果(定理2)。
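
关于GW差异损失的一个示意(假设性用法,非论文官方实现):可以借助 POT 库(Python Optimal Transport)中的 Gromov-Wasserstein 距离来度量近似 graphon 与 oracle graphon 之间的差异;官方仓库代码中实际提供的是可微的 (sliced) FGW 距离实现(见后文代码中的 sliced_fgw_distance / fgw_distance)。

import numpy as np
import ot  # POT: Python Optimal Transport,假设已安装(pip install pot)

def gw_discrepancy(W_hat: np.ndarray, W_oracle: np.ndarray) -> float:
    """把两个graphon(阶跃函数矩阵)视为结构矩阵,计算二者的GW差异。"""
    p = ot.unif(W_hat.shape[0])       # 两侧节点均取均匀分布
    q = ot.unif(W_oracle.shape[0])
    return ot.gromov.gromov_wasserstein2(W_hat, W_oracle, p, q, loss_fun='square_loss')

# 用法示意:两个大小不同的随机对称矩阵
W_hat = np.random.rand(16, 16); W_hat = (W_hat + W_hat.T) / 2
W_orc = np.random.rand(32, 32); W_orc = (W_orc + W_orc.T) / 2
print(gw_discrepancy(W_hat, W_orc))
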

4.2 创新点

该文的主要贡献包括以下几个方面:

  1. 确定下游图的生成模式是弥合预训练和微调之间差距的关键步骤。
  2. 基于理论结果,设计了模型架构G-TUNING,以有效地将graphon重构为具有严格泛化结果的生成模式。
  3. 从经验上看,该方法在8个域内和7个域外迁移学习数据集上比最佳基线平均提高了0.47%和2.62%。

4.3 实验过程

目标是在两种设置下实际评估G-TUNING在15个数据集上的性能。

具体来说,回答以下问题:

  • (有效性)G-TUNING是否提高了微调的性能
  • (可转移性)G-TUNING能比基线更好地实现可转移性吗
  • (完整性)G-TUNING的每个组成部分对性能的贡献是什么
  • (效率)G-TUNING能否在可接受的时间消耗下提高微调的性能

基线。有大量的GNN预训练方法,但只有少数的微调策略可用。

  • 几个最初为CNN设计的代表性基线,包括StochNorm、DELTA及其固定注意力系数的版本(Feature-Map)、L2_SP和BSS。
  • 一个专门用于改进GNN微调、且与预训练策略无关的基线,即GTOT-Tuning。
  • 为了验证重建图的有效性,引入了VGAE-Tuning进行比较,该方法使用VGAE作为辅助损失来重建下游图的邻接矩阵。

根据各基线作者发布的代码复现基线,并按照其代码和论文中描述的设置来设定超参数。

4.3.1 有效性

评估了G-TUNING在分子性质预测任务上的有效性,使用以无监督上下文预测任务预训练的模型作为骨干。具体来说,在包含200万个未标记分子的ZINC15数据集(Sterling and Irwin 2015)上,通过自监督的上下文预测任务预训练GIN (Xu et al. 2019)。接下来,在从MoleculeNet (Wu et al. 2018)获得的8个二元分类数据集上对骨干模型进行微调,采用8:1:1比例的scaffold划分。由于该框架对骨干GNN是不可知的,因此专注于评估该模型能否取得更好的微调结果。

对于每个数据集,实验运行5次,并报告平均ROC-AUC及相应的标准差。
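
对应的评估方式可以用如下极简示意表示(其中 run_finetune 为示意用的假设函数,代表一次完整的微调与评估):

import numpy as np
from sklearn.metrics import roc_auc_score

aucs = []
for seed in range(5):                       # 每个数据集运行 5 次
    y_true, y_score = run_finetune(seed)    # 假设函数:返回测试集标签与预测分数
    aucs.append(roc_auc_score(y_true, y_score))
print(f"ROC-AUC: {np.mean(aucs):.4f} ± {np.std(aucs):.4f}")
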

(原文表1:域内迁移学习的 ROC-AUC 结果)

表1显示,G-TUNING在8个数据集中的6个上获得了最佳性能,平均排名也最高。注意到,约束来自预训练模型的嵌入(如Feature-Map或DELTA)有时会带来比普通微调更差的性能。从监督学习与G-TUNING w/o Pre之间的比较来看,尽管将G-TUNING损失应用于监督学习时偶尔会出现性能轻微下降,但大多数监督训练仍受益于G-TUNING损失。从普通微调与表中最后两行的比较可以看出,未经预训练的G-TUNING性能低于经过预训练的G-TUNING,但有时优于普通微调。总体而言,结果证明在数据集来自相同领域的情况下,G-TUNING可以通过保留生成模式来补偿结构差异,从而获得更好的性能。

4.3.2 可转移性

在跨域设置中评估G-TUNING,此时预训练数据集和下游数据集不来自同一领域,较大的结构差异会进一步降低性能。因此采用GCC (Qiu et al. 2020)作为骨干模型,以其子图判别作为预训练任务。按照GCC的设置,在7个不同的数据集上进行预训练,并在来自TUDataset (Morris et al. 2020)的7个下游图分类基准上评估方法:IMDB-M、IMDB-B、MUTAG、PROTEINS、ENZYMES、MSRC_21和RDT-M12K。这些数据集涵盖了广泛的领域。报告10折交叉验证的结果。

(原文表2:域外迁移学习结果)

从表2中可以发现,模型在7个数据集中的6个上优于所有基线,并在MUTAG上给出了具有竞争力的结果(比最好的基线低1.92%)。与普通微调和第二好的基线相比,G-TUNING在PROTEINS上分别提高了7.63%和4.71%。与之前的实验(表1)相比,可以观察到G-TUNING带来了更显著的改进,因为它显式地保留了生成模式。尽管GTOT也包含结构信息,但它有时甚至比普通微调表现更差(如在PROTEINS和ENZYMES上)。总体而言,当预训练图与微调图表现出较大的结构差异时,G-TUNING清楚地显示了其有效性。

4.3.3 消融研究

(原文消融实验结果)

  • 首先,通过与直接重建graphon的方式(Direct-Rec)进行比较,检验所提graphon分解方法的有效性。结果表明,Direct-Rec在普通微调之外带来的改进是有限的,在某些情况下甚至出现负迁移,原因可能在于直接重建复杂的语义信息、以及从A和H中捕获graphon属性都比较困难。在四个数据集上,G-TUNING始终优于Direct-Rec。
  • 接下来,将不同的编码器架构(两层MLP、GCN (Welling and Kipf 2016)、GraphSAGE (Hamilton, Ying, and Leskovec 2017)和GAT (Veličković et al. 2018))与默认骨干(即GIN (Xu et al. 2019))进行比较。在图3中可以观察到MLP编码器表现最差,这证明了结合结构信息来重构graphon的有效性。
  • 最后,用KL散度、Wasserstein距离和余弦相似度代替GW差异损失,可以观察到GW差异损失显著优于其他选择。作者认为,余弦相似度可能对表示边概率的绝对取值不敏感;KL散度不对称,重构graphon时难以收敛;Wasserstein距离虽然同样基于最优传输,但无法捕捉两个graphon之间的几何结构。

(原文图4:graphon bases 数量的超参数研究)

超参研究:还研究了graphon bases数量从2到512的影响。更多的基可以表示更多的信息,从而更好地逼近oracle graphon。图4显示,当基的数量从2增加到32时,性能有所提升;然而当数量继续增加时,改善越来越小。文中将这种现象归因于参数数量增加带来的优化难度。此外,随着基数量的增加,G-TUNING的运行时间呈指数增长(绿色曲线)。因此,G-TUNING只需要少量的graphon bases即可提升微调性能。

4.3.4 运行时间
  1. 进行了运行时间对比,结果见表3。额外的时间开销主要由两部分组成:(i) 基于预训练模型的graphon逼近;(ii) oracle graphon估计。
  2. 在不失一般性的前提下,报告了域内设置下各数据集上该方法与基线的计算效率(每个训练epoch的秒数)。

(原文表3:运行时间对比)

从表3中可以看到,G-TUNING在大多数情况下并不是最慢的微调方法,其时间开销接近于VGAE-Tuning。如表中第一行的图数量所示,随着图数量的增加,该方法的时间消耗与其他基线相比更具可比性,这意味着该方法具有出色的可扩展性。因此,该方法的时间消耗保持在可接受的范围内。

5. 结论

在该文中,将图的预训练性能不理想归因于预训练与下游数据集之间的结构分歧。此外,将这种差异的原因确定为预训练和下游图之间生成模式的差异。基于理论分析,提出了一种基于graphon的GNN微调策略G-TUNING,以使预训练的模型适应下游数据集。最后,实证证明了G-TUNING的有效性。

6.代码复现

https://github.com/zjunet/G-Tuning

Mole-based model

import random
from ot_distance import sliced_fgw_distance, fgw_distance
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax, subgraph
import torch_geometric.utils as PyG_utils
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
import torch.nn as nn
# from torch_geometric.nn.conv import GATConv
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.datasets import TUDataset
from abc import ABC

# 节点/边离散特征的取值个数(含额外的 mask token)
num_atom_type = 120  # including the extra mask tokens
num_chirality_tag = 3
num_bond_type = 6  # including aromatic and self-loop edge, and extra masked tokens
num_bond_direction = 3class GINConv(MessagePassing):"""Extension of GIN aggregation to incorporate edge information by concatenation.Args:emb_dim (int): dimensionality of embeddings for nodes and edges.embed_input (bool): whether to embed input or not.See https://arxiv.org/abs/1810.00826"""def __init__(self, emb_dim, aggr="add"):super(GINConv, self).__init__()# multi-layer perceptronself.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.ReLU(),torch.nn.Linear(2 * emb_dim, emb_dim))self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)self.aggr = aggrdef forward(self, x, edge_index, edge_attr):# add self loops in the edge space# edge_index = add_self_loops(edge_index, num_nodes=x.size(0))# add features corresponding to self-loop edges.# self_loop_attr = torch.zeros(edge_index[0].size(0), 2)# self_loop_attr[:, 0] = 4  # bond type for self-loop edge# self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)# edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])try:  # PyG 1.6.return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)except:  # PyG 1.0.3return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings)def message(self, x_j, edge_attr):return x_j + edge_attrdef update(self, aggr_out):return self.mlp(aggr_out)class GCNConv(MessagePassing):def __init__(self, emb_dim, aggr="add"):super(GCNConv, self).__init__()self.emb_dim = emb_dimself.linear = torch.nn.Linear(emb_dim, emb_dim)self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)self.aggr = aggrdef norm(self, edge_index, num_nodes, dtype):### assuming that self-loops have been already added in edge_indexedge_weight = torch.ones((edge_index.size(1),), dtype=dtype,device=edge_index.device)row, col = edge_indexdeg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)deg_inv_sqrt = deg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]def forward(self, x, edge_index, edge_attr):# add self loops in the edge space# edge_index, edge_weight = add_self_loops(edge_index, num_nodes=x.size(0)) pyg 1.6edge_index = add_self_loops(edge_index, num_nodes=x.size(0))# add features corresponding to self-loop edges.self_loop_attr = torch.zeros(x.size(0), 2)self_loop_attr[:, 0] = 4  # bond type for self-loop edgeself_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])try:norm = self.norm(edge_index, x.size(0), x.dtype)except:norm = self.norm(edge_index[0], x.size(0), x.dtype)x = self.linear(x)try:  # PyG 1.6.return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)except:  # PyG 1.0.3return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings, norm=norm)# return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)# return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm = norm)def 
message(self, x_j, edge_attr, norm):return norm.view(-1, 1) * (x_j + edge_attr)class GATConv(MessagePassing):def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):super(GATConv, self).__init__()self.aggr = aggrself.emb_dim = emb_dimself.heads = headsself.negative_slope = negative_slopeself.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)self.reset_parameters()def reset_parameters(self):glorot(self.att)zeros(self.bias)def forward(self, x, edge_index, edge_attr):# add self loops in the edge spaceedge_ind = add_self_loops(edge_index, num_nodes=x.size(0))# add features corresponding to self-loop edges.self_loop_attr = torch.zeros(x.size(0), 2)self_loop_attr[:, 0] = 4  # bond type for self-loop edgeself_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)# edge_ind = edge_ind[0]print("edge_index", edge_ind)return self.propagate(self.aggr, edge_index=edge_ind, x=x, edge_attr=edge_embeddings)def message(self, edge_index, x_i, x_j, edge_attr):edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)x_j += edge_attralpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)alpha = F.leaky_relu(alpha, self.negative_slope)alpha = softmax(alpha, edge_index[0])return x_j * alpha.view(-1, self.heads, 1)def update(self, aggr_out):aggr_out = aggr_out.mean(dim=1)aggr_out = aggr_out + self.biasreturn aggr_outclass GraphSAGEConv(MessagePassing):def __init__(self, emb_dim, aggr="mean"):super(GraphSAGEConv, self).__init__()self.emb_dim = emb_dimself.linear = torch.nn.Linear(emb_dim, emb_dim)self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)self.aggr = aggrdef forward(self, x, edge_index, edge_attr):# add self loops in the edge spaceedge_index = add_self_loops(edge_index, num_nodes=x.size(0))# add features corresponding to self-loop edges.self_loop_attr = torch.zeros(x.size(0), 2)self_loop_attr[:, 0] = 4  # bond type for self-loop edgeself_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])x = self.linear(x)try:  # PyG 1.6.return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)except:  # PyG 1.0.3return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings)def message(self, x_j, edge_attr):return x_j + edge_attrdef update(self, aggr_out):return F.normalize(aggr_out, p=2, dim=-1)class GNN(torch.nn.Module):"""Args:num_layer (int): the number of GNN layersemb_dim (int): dimensionality of embeddingsJK (str): last, concat, max or sum.max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregationdrop_ratio (float): dropout rategnn_type: 
gin, gcn, graphsage, gatOutput:node representations"""def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0, gnn_type="gin"):super(GNN, self).__init__()self.num_layer = num_layerself.drop_ratio = drop_ratioself.gnn_type = gnn_typeself.JK = JKif self.num_layer < 2:raise ValueError("Number of GNN layers must be greater than 1.")self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)###List of MLPsself.gnns = torch.nn.ModuleList()for layer in range(num_layer):if gnn_type == "gin":self.gnns.append(GINConv(emb_dim, aggr="add"))elif gnn_type == "gcn":self.gnns.append(GCNConv(emb_dim))elif gnn_type == "gat":self.gnns.append(GATConv(emb_dim))elif gnn_type == "graphsage":self.gnns.append(GraphSAGEConv(emb_dim))###List of batchnormsself.batch_norms = torch.nn.ModuleList()for layer in range(num_layer):self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))# def forward(self, x, edge_index, edge_attr):def forward(self, *argv):batch = Noneif len(argv) == 3:x, edge_index, edge_attr = argv[0], argv[1], argv[2]elif len(argv) == 1:data = argv[0]x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attrelif len(argv) == 4:x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]else:raise ValueError("unmatched number of arguments.")x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])h_list = [x]for layer in range(self.num_layer):h = self.gnns[layer](h_list[layer], edge_index, edge_attr)h = self.batch_norms[layer](h)# h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)if layer == self.num_layer - 1:# remove relu for the last layerh = F.dropout(h, self.drop_ratio, training=self.training)else:h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)h_list.append(h)### Different implementations of Jk-concatif self.JK == "concat":node_representation = torch.cat(h_list, dim=1)elif self.JK == "last":node_representation = h_list[-1]elif self.JK == "max":h_list = [h.unsqueeze_(0) for h in h_list]node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0]elif self.JK == "sum":h_list = [h.unsqueeze_(0) for h in h_list]node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0]return node_representationclass GNN_graphpred(torch.nn.Module):"""Extension of GIN to incorporate edge information by concatenation.Args:num_layer (int): the number of GNN layersemb_dim (int): dimensionality of embeddingsnum_tasks (int): number of tasks in multi-task learning scenariodrop_ratio (float): dropout rateJK (str): last, concat, max or sum.graph_pooling (str): sum, mean, max, attention, set2setgnn_type: gin, gcn, graphsage, gatSee https://arxiv.org/abs/1810.00826JK-net: https://arxiv.org/abs/1806.03536"""def __init__(self, num_layer, emb_dim, num_tasks, JK="last", drop_ratio=0, graph_pooling="mean", gnn_type="gin",backbone=None, args=None):'''backbone is gnn default'''super(GNN_graphpred, self).__init__()self.num_layer = num_layerself.drop_ratio = drop_ratioself.JK = JKself.emb_dim = emb_dimself.num_tasks = num_tasksself.emb_f = Noneself.gnn_type = gnn_typeself.param_args = argsif self.num_layer < 2:raise ValueError("Number of GNN layers must be greater than 1.")if backbone is None:self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type=gnn_type)else:self.gnn = backbone# self.backbone = self.gnn# Different kind of graph poolingif graph_pooling == "sum":self.pool = 
global_add_poolelif graph_pooling == "mean":self.pool = global_mean_poolelif graph_pooling == "max":self.pool = global_max_poolelif graph_pooling == "attention":if self.JK == "concat":self.pool = GlobalAttention(gate_nn=torch.nn.Linear((self.num_layer + 1) * emb_dim, 1))else:self.pool = GlobalAttention(gate_nn=torch.nn.Linear(emb_dim, 1))elif graph_pooling[:-1] == "set2set":set2set_iter = int(graph_pooling[-1])if self.JK == "concat":self.pool = Set2Set((self.num_layer + 1) * emb_dim, set2set_iter)else:self.pool = Set2Set(emb_dim, set2set_iter)else:raise ValueError("Invalid graph pooling type.")# For graph-level binary classificationif graph_pooling[:-1] == "set2set":self.mult = 2else:self.mult = 1if self.JK == "concat":self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)else:self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)def from_pretrained(self, model_file):# self.gnn = GNN(self.num_layer, self.emb_dim, JK = self.JK, drop_ratio = self.drop_ratio)model = torch.load(model_file, map_location='cpu')self.gnn.load_state_dict(model)  # self.args.device))def forward(self, *argv):if len(argv) == 4:x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]elif len(argv) == 1:data = argv[0]x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batchelse:raise ValueError("unmatched number of arguments.")node_representation = self.gnn(x, edge_index, edge_attr)self.emb_f = self.pool(node_representation, batch)return self.graph_pred_linear(self.emb_f)def get_bottleneck(self):return self.emb_fclass GraphonEncoder(torch.nn.Module):def __init__(self, feature_length, hidden_size, out_size):super(GraphonEncoder, self).__init__()self.feature_length, self.hidden_size, self.out_size = feature_length, hidden_size, out_sizeself.fc1 = torch.nn.Linear(feature_length, hidden_size)self.fc2 = torch.nn.Linear(hidden_size, out_size)def forward(self, x):x = x.view(-1, self.feature_length)# print(x, x.shape)# print(edge_index, edge_index.shape)x = F.dropout(x, p=0.9, training=self.training)x = F.relu(self.fc1(x))x = F.dropout(x, p=0.9, training=self.training)x = self.fc2(x)return xdef sampling_gaussian(mu, logvar, num_sample):std = torch.exp(0.5 * logvar)samples = Nonefor i in range(num_sample):eps = torch.randn_like(std)if i == 0:samples = mu + eps * stdelse:samples = torch.cat((samples, mu + eps * std), dim=0)return samplesdef sampling_gmm(mu, logvar, num_sample):std = torch.exp(0.5 * logvar)n = int(num_sample / mu.size(0)) + 1samples = Nonefor i in range(n):eps = torch.randn_like(std)if i == 0:samples = mu + eps * stdelse:samples = torch.cat((samples, mu + eps * std), dim=0)return samples[:num_sample, :]class Prior(nn.Module, ABC):def __init__(self, data_size: list, prior_type: str = 'gmm'):super(Prior, self).__init__()# data_size = [num_component, z_dim]self.data_size = data_sizeself.number_components = data_size[0]self.output_size = data_size[1]self.prior_type = prior_typeif self.prior_type == 'gmm':self.mu = nn.Parameter(torch.randn(data_size), requires_grad=True)self.logvar = nn.Parameter(torch.randn(data_size), requires_grad=True)else:self.mu = nn.Parameter(torch.zeros(1, self.output_size), requires_grad=False)self.logvar = nn.Parameter(torch.ones(1, self.output_size), requires_grad=False)def forward(self):return self.mu, self.logvardef sampling(self, num_sample):if self.prior_type == 'gmm':return sampling_gmm(self.mu, self.logvar, num_sample)else:return sampling_gaussian(self.mu, self.logvar, 
num_sample)class GraphonNewEncoder(torch.nn.Module):def __init__(self, feature_length, hidden_size, out_size, encoder_type):super(GraphonNewEncoder, self).__init__()self.feature_length, self.hidden_size, self.out_size = feature_length, hidden_size, out_sizeself.encoder_type = encoder_typeself.fc1 = torch.nn.Linear(feature_length, hidden_size)self.fc2 = torch.nn.Linear(hidden_size, out_size)# self.gnn = GNN(2, feature_length, "last", 0.2, gnn_type='gin')self.gnns_en = torch.nn.ModuleList()for layer in range(2):if encoder_type == "gin":self.gnns_en.append(GINConv(feature_length, aggr="add"))elif encoder_type == "gcn":self.gnns_en.append(GCNConv(feature_length))elif encoder_type == "gat":self.gnns_en.append(GATConv(feature_length))elif encoder_type == "graphsage":self.gnns_en.append(GraphSAGEConv(feature_length))self.fc3 = torch.nn.Linear(feature_length, out_size)self.batch_norms = torch.nn.ModuleList()for layer in range(2):self.batch_norms.append(torch.nn.BatchNorm1d(feature_length))def forward(self, x, edge_index, edge_attr, batch):# node_representation = self.gnn(x, edge_index.long(), None)# if self.encoder_type == 'gat':#     print("before", edge_index.shape, edge_index.dtype, edge_index)#     edge_index = edge_index.type(torch.LongTensor)#     print("after", edge_index.shape, edge_index.dtype, edge_index)# print("self.encoder_type", self.encoder_type)if self.encoder_type != 'mlp':h_list = [x, ]for layer in range(2):inp = h_list[layer]h = self.gnns_en[layer](inp, edge_index, edge_attr)h = self.batch_norms[layer](h)# h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)if layer == 1:# remove relu for the last layerh = F.dropout(h, 0.2, training=self.training)else:h = F.dropout(F.relu(h), 0.2, training=self.training)h_list.append(h)x = h_list[-1] \+ h_list[0]x = global_mean_pool(x, batch)x = self.fc3(x)else:x = x.view(-1, self.feature_length)# print(x, x.shape)# print(edge_index, edge_index.shape)x = F.dropout(x, p=0.9, training=self.training)x = F.relu(self.fc1(x))x = F.dropout(x, p=0.9, training=self.training)x = global_mean_pool(x, batch)x = self.fc2(x)return xclass GraphonFactorization(torch.nn.Module, ABC):def __init__(self, num_factors: int, graphs: TUDataset, seed: int, param_args, node_type: str = 'categorical'):"""A basic graphon model based on Fourier transformationArgs:num_factors: the number of sin/cos bases for one graphongraphs: the graphs used as the prior of the modelseed: random seednode_type: 'binary', 'categorical' and 'continuous'"""super(GraphonFactorization, self).__init__()self.num_factors = num_factorsself.node_type = node_typeself.factors_graphon = nn.ParameterList()self.factors_signal = nn.ParameterList()self.num_partitions = []indices = list(range(len(graphs)))random.seed(seed)random.shuffle(indices)# indices = np.random.RandomState(seed).permutation(len(graphs))for c in range(self.num_factors):sample = graphs[indices[c]]adj = torch.sparse_coo_tensor(sample.edge_index,torch.ones(sample.edge_index.shape[1]),size=[sample.x.shape[0], sample.x.shape[0]])adj = adj.to_dense()# print(adj.shape)if len(adj.shape) > 2:adj = adj.sum(2)# attribute = sample.xdegrees = torch.sum(adj, dim=1)idx = torch.argsort(degrees)# print(idx.shape)adj = adj[idx, :][:, idx]# attribute = attribute[idx, :]num_partitions = adj.shape[0]graphon = nn.Parameter(data=(adj - 0.5), requires_grad=True)# if self.node_type == "binary" or "categorical":#     signal = nn.Parameter(data=(attribute - 0.5), requires_grad=True)# else:#     signal = nn.Parameter(data=attribute, 
requires_grad=True)self.num_partitions.append(num_partitions)self.factors_graphon.append(graphon)# self.factors_signal.append(signal)# self.dim = self.factors_signal[0].shape[1]self.sigmoid = nn.Sigmoid()# self.relu = nn.ReLU()self.softmax0 = nn.Softmax(dim=0)self.softmax1 = nn.Softmax(dim=1)self.softmax2 = nn.Softmax(dim=2)# 这里三层softmax是?self.batch_size = param_args['batch_size']self.fc = torch.nn.Linear(param_args['batch_size'], 1)# x = (torch.arange(0, 100) + 0.5).view(1, -1) / 100# self.register_buffer('positions', x)self.num_components =param_args['n_components']self.prior_type = param_args['prior_type']self.prior = Prior(data_size=[self.num_components, self.num_factors],prior_type=self.prior_type)def sampling_z(self, num_samples):return self.prior.sampling(num_samples)def sampling(self, vs: torch.Tensor):"""Sampling graphon factorsArgs:vs: (n_nodes)Returns:graphons: (n_factors, n_nodes, n_nodes)signals: (n_factors, n_nodes, n_nodes)"""n_nodes = vs.shape[0]graphons = torch.zeros(self.num_factors, n_nodes, n_nodes).to(vs.device)# signals = torch.zeros(1, self.num_factors, n_nodes, self.dim).to(vs.device)for c in range(self.num_factors):idx = torch.floor(self.num_partitions[c] * vs).long()graphons[c, :, :] = self.factors_graphon[c][idx, :][:, idx]# signals[0, c, :, :] = self.factors_signal[c][idx, :]graphons = self.sigmoid(graphons)# return graphons, signalsreturn graphonsdef forward(self, zs: torch.Tensor, vs: torch.Tensor):"""Given a graphon model, sample a batch of graphs from itArgs:zs: (batch_size, n_factors) latent representationsvs: (n_nodes) random variables ~ Uniform([0, 1])Returns:graphon: (batch_size, n_nodes, n_nodes)signal: (batch_size, n_nodes, dim)graph: (batch_size, n_nodes, n_nodes) adjacency matrixattribute: (batch_size, n_nodes, dim) node attributes"""tzs = zs.t()tzs_pad = tzsif self.batch_size - tzs.shape[1] != 0:pad = torch.zeros(tzs.shape[0],self.batch_size - tzs.shape[1],device=tzs.device)tzs_pad = torch.cat((tzs, pad), dim=1)zs_hat_one = self.fc(tzs_pad)zs_hat = self.softmax1(zs_hat_one)# graphons, signals = self.sampling(vs)  # basis# print('zs_hat', zs_hat.shape)graphons_basis = self.sampling(vs)  # basis# graphons_basis = torch.sigmoid(graphons)  #TODO: here change to [0,1]assert (graphons_basis.max().item() <= 1and graphons_basis.min().item() >= 0)# print('graphons', graphons.shape)graphon_est = (zs_hat.view(self.num_factors, 1, 1) * graphons_basis).sum(0)  # ( n_nodes, n_nodes)# signal = (zs_hat.view(-1, self.num_factors, 1, 1) * signals).sum(1)  # (batch, n_nodes, dim)# if self.node_type == 'binary':#     signal = self.sigmoid(signal)# if self.node_type == 'categorical':#     signal = self.softmax2(signal)# graphs = torch.bernoulli(graphon)  # TODO: ???!?!# graphs += graphs.clone().permute(0, 2, 1)  # Change: add .clone()# graphs[graphs > 1] = 1# if self.node_type == "binary":#     attributes = torch.bernoulli(signal)# elif self.node_type == "categorical":#     distribution = torch.distributions.one_hot_categorical.OneHotCategorical(signal)#     attributes = distribution.sample()# else:#     distribution = torch.distributions.normal.Normal(signal, scale=2)#     attributes = distribution.sample()# return graphon, signal, graphs, attributesreturn graphon_estclass GraphonNewFactorization(torch.nn.Module, ABC):def __init__(self, num_factors: int, graphs_pre, graphs_down: TUDataset, seed: int, args,node_type: str = 'categorical'):"""A basic graphon model based on Fourier transformationArgs:num_factors: the number of sin/cos bases for one graphongraphs: the 
graphs used as the prior of the modelseed: random seednode_type: 'binary', 'categorical' and 'continuous'"""super(GraphonNewFactorization, self).__init__()self.num_factors = num_factorsself.node_type = node_typeself.factors_graphon = nn.ParameterList()self.factors_signal = nn.ParameterList()self.num_partitions = []num_mul_nodes = args.nnodes * args.ngraphsindices_pre = list(range(len(graphs_pre)))random.seed(seed)random.shuffle(indices_pre)indices_pre = indices_pre[:num_mul_nodes]num_pre_factors = int(self.num_factors / 2)print('Dealing with pretrain basis')for c in range(num_pre_factors):sample = graphs_pre[indices_pre[c]]node_ids = list(range(sample.x.shape[0]))random.shuffle(node_ids)node_sample = node_ids[:num_mul_nodes]edge, _ = subgraph(node_sample, sample.edge_index, relabel_nodes=True)adj = torch.sparse_coo_tensor(edge,torch.ones(edge.shape[1]),size=[num_mul_nodes, num_mul_nodes])adj = adj.to_dense()# print(adj.shape)if len(adj.shape) > 2:adj = adj.sum(2)# attribute = sample.xdegrees = torch.sum(adj, dim=1)idx = torch.argsort(degrees)# print(idx.shape)adj = adj[idx, :][:, idx]# attribute = attribute[idx, :]num_partitions = adj.shape[0]graphon = nn.Parameter(data=(adj - 0.5), requires_grad=True)self.num_partitions.append(num_partitions)self.factors_graphon.append(graphon)indices_down = list(range(len(graphs_down)))random.seed(seed)random.shuffle(indices_down)# indices = np.random.RandomState(seed).permutation(len(graphs))print('Dealing with downstream basis')for c in range(self.num_factors - num_pre_factors):sample = graphs_down[indices_down[c]]# adj = torch.sparse_coo_tensor(sample.edge_index,#                               torch.ones(sample.edge_index.shape[1]),#                               size=[sample.x.shape[0], sample.x.shape[0]])node_ids = list(range(sample.x.shape[0]))random.shuffle(node_ids)node_sample = node_ids[:num_mul_nodes]edge, _ = subgraph(node_sample, sample.edge_index, relabel_nodes=True)adj = torch.sparse_coo_tensor(edge,torch.ones(edge.shape[1]),size=[num_mul_nodes, num_mul_nodes])adj = adj.to_dense()# print(adj.shape)if len(adj.shape) > 2:adj = adj.sum(2)# attribute = sample.xdegrees = torch.sum(adj, dim=1)idx = torch.argsort(degrees)# print(idx.shape)adj = adj[idx, :][:, idx]# attribute = attribute[idx, :]num_partitions = adj.shape[0]graphon = nn.Parameter(data=(adj - 0.5), requires_grad=True)# if self.node_type == "binary" or "categorical":#     signal = nn.Parameter(data=(attribute - 0.5), requires_grad=True)# else:#     signal = nn.Parameter(data=attribute, requires_grad=True)self.num_partitions.append(num_partitions)self.factors_graphon.append(graphon)# self.factors_signal.append(signal)# self.dim = self.factors_signal[0].shape[1]self.sigmoid = nn.Sigmoid()# self.relu = nn.ReLU()self.softmax0 = nn.Softmax(dim=0)self.softmax1 = nn.Softmax(dim=1)self.softmax2 = nn.Softmax(dim=2)# 这里三层softmax是?self.batch_size = args.batch_sizeself.fc = torch.nn.Linear(args.batch_size, 1)# x = (torch.arange(0, 100) + 0.5).view(1, -1) / 100# self.register_buffer('positions', x)self.num_components = args.n_componentsself.prior_type = args.prior_typeself.prior = Prior(data_size=[self.num_components, self.num_factors],prior_type=self.prior_type)def sampling_z(self, num_samples):return self.prior.sampling(num_samples)def sampling(self, vs: torch.Tensor):"""Sampling graphon factorsArgs:vs: (n_nodes)Returns:graphons: (n_factors, n_nodes, n_nodes)signals: (n_factors, n_nodes, n_nodes)"""n_nodes = vs.shape[0]graphons = torch.zeros(self.num_factors, n_nodes, n_nodes).to(vs.device)# 
signals = torch.zeros(1, self.num_factors, n_nodes, self.dim).to(vs.device)for c in range(self.num_factors):idx = torch.floor(self.num_partitions[c] * vs).long()graphons[c, :, :] = self.factors_graphon[c][idx, :][:, idx]# signals[0, c, :, :] = self.factors_signal[c][idx, :]graphons = self.sigmoid(graphons)# return graphons, signalsreturn graphonsdef forward(self, zs: torch.Tensor, vs: torch.Tensor):"""Given a graphon model, sample a batch of graphs from itArgs:zs: (batch_size, n_factors) latent representationsvs: (n_nodes) random variables ~ Uniform([0, 1])Returns:graphon: (batch_size, n_nodes, n_nodes)signal: (batch_size, n_nodes, dim)graph: (batch_size, n_nodes, n_nodes) adjacency matrixattribute: (batch_size, n_nodes, dim) node attributes"""tzs = zs.t()tzs_pad = tzsif self.batch_size - tzs.shape[1] != 0:pad = torch.zeros(tzs.shape[0],self.batch_size - tzs.shape[1],device=tzs.device)tzs_pad = torch.cat((tzs, pad), dim=1)zs_hat_one = self.fc(tzs_pad)zs_hat = self.softmax1(zs_hat_one)# graphons, signals = self.sampling(vs)  # basis# print('zs_hat', zs_hat.shape)graphons_basis = self.sampling(vs)  # basis# graphons_basis = torch.sigmoid(graphons)  #TODO: here change to [0,1]assert (graphons_basis.max().item() <= 1and graphons_basis.min().item() >= 0)# print('graphons', graphons.shape)graphon_est = (zs_hat.view(self.num_factors, 1, 1) * graphons_basis).sum(0)  # ( n_nodes, n_nodes)# signal = (zs_hat.view(-1, self.num_factors, 1, 1) * signals).sum(1)  # (batch, n_nodes, dim)# if self.node_type == 'binary':#     signal = self.sigmoid(signal)# if self.node_type == 'categorical':#     signal = self.softmax2(signal)# graphs = torch.bernoulli(graphon)  # TODO: ???!?!# graphs += graphs.clone().permute(0, 2, 1)  # Change: add .clone()# graphs[graphs > 1] = 1# if self.node_type == "binary":#     attributes = torch.bernoulli(signal)# elif self.node_type == "categorical":#     distribution = torch.distributions.one_hot_categorical.OneHotCategorical(signal)#     attributes = distribution.sample()# else:#     distribution = torch.distributions.normal.Normal(signal, scale=2)#     attributes = distribution.sample()# return graphon, signal, graphs, attributesreturn graphon_estdef raml(graphons_hat, graphons_lbl, args):# adj = torch.sparse_coo_tensor(data.edge_index,#                               torch.ones(data.edge_index.shape[1]),#                               size=[data.x.shape[0], data.x.shape[0]])# adj = adj.to_dense()# if len(adj.shape) > 2:#     adj = adj.sum(2)# log_p_x = torch.zeros(graphons_hat.shape[0], args.n_graphs).to(graphons_hat.device)# for b in range(graphons_hat.shape[0]):# d_fgw = torch.zeros(args.n_graphs).to(graphons_hat.device)# adj2 = adj[data.batch == b, :][:, data.batch == b]# s2 = data.x[data.batch == b, :]# for k in range(args.n_graphs):#     adj0 = graphons_hat[b, k * args.n_nodes:(k + 1) * args.n_nodes, :][:,#            k * args.n_nodes:(k + 1) * args.n_nodes]#     # s0 = signals[b, k * args.n_nodes:(k + 1) * args.n_nodes, :]#     adj1 = graphons_lbl[b, k * args.n_nodes:(k + 1) * args.n_nodes, :][:,#            k * args.n_nodes:(k + 1) * args.n_nodes]# s1 = attributes[b, k * args.n_nodes:(k + 1) * args.n_nodes, :]# if node_type == 'binary':#     log_p_x[b, k] = F.binary_cross_entropy(input=adj0, target=adj1, reduction='mean')# elif node_type == 'categorical':#     log_p_x[b, k] = F.binary_cross_entropy(input=adj0, target=adj1, reduction='mean')# else:# log_p_x[b, k] = F.binary_cross_entropy(input=adj0, target=adj1, reduction='mean')# d_fgw[k] = fgw_distance(adj1, adj2, 
args)# d_fgw[k] = fgw_distance(adj1, adj0, args)# print(b, k, d_fgw[k])# print('graphons_hat, graphons_lbl')# print(graphons_hat.shape)# print(graphons_lbl.shape)d_fgw = fgw_distance(graphons_hat, graphons_lbl, args)# print('d_fgw', d_fgw)# q_x = F.softmax(-2 * d_fgw / torch.min(d_fgw), dim=0).detach()  # TODO: detach() ??# log_p_x[b, :] *= q_x# print(q_x.shape)# log_p_x[b, :] = q_x# return log_p_x.mean()return d_fgwdef distance_tensor(pts_src: torch.Tensor, pts_dst: torch.Tensor, p: int = 2):"""Returns the matrix of ||x_i-y_j||_p^p.:param pts_src: [R, D] matrix:param pts_dst: [C, D] matrix:param p::return: [R, C, D] distance matrix"""x_col = pts_src.unsqueeze(1)y_row = pts_dst.unsqueeze(0)distance = torch.abs(x_col - y_row) ** preturn distancedef sliced_fgw_distance(posterior_samples, prior_samples, num_projections=50, p=2, beta=0.1):# derive latent space dimension size from random samples drawn from latent prior distributionembedding_dim = prior_samples.size(1)# generate random projections in latent spaceprojections = torch.randn(size=(embedding_dim, num_projections)).to(posterior_samples.device)projections /= (projections ** 2).sum(0).sqrt().unsqueeze(0)# calculate projections through the encoded samplesposterior_projections = posterior_samples.matmul(projections)  # batch size x #projectionsprior_projections = prior_samples.matmul(projections)  # batch size x #projectionsposterior_projections = torch.sort(posterior_projections, dim=0)[0]prior_projections1 = torch.sort(prior_projections, dim=0)[0]prior_projections2 = torch.sort(prior_projections, dim=0, descending=True)[0]posterior_diff = distance_tensor(posterior_projections, posterior_projections, p=p)prior_diff1 = distance_tensor(prior_projections1, prior_projections1, p=p)prior_diff2 = distance_tensor(prior_projections2, prior_projections2, p=p)# print(posterior_projections.size(), prior_projections1.size())# print(posterior_diff.size(), prior_diff1.size())w1 = torch.sum((posterior_projections - prior_projections1) ** p, dim=0)w2 = torch.sum((posterior_projections - prior_projections2) ** p, dim=0)# print(w1.size(), torch.sum(w1))gw1 = torch.mean(torch.mean((posterior_diff - prior_diff1) ** p, dim=0), dim=0)gw2 = torch.mean(torch.mean((posterior_diff - prior_diff2) ** p, dim=0), dim=0)# print(gw1.size(), torch.sum(gw1))fgw1 = (1 - beta) * w1 + beta * gw1fgw2 = (1 - beta) * w2 + beta * gw2return torch.sum(torch.min(fgw1, fgw2))if __name__ == "__main__":pass
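
上述模型的一个最小调用示意(假设性示例,特征与数据均为随机构造):按代码中的约定,节点特征 x 为 [原子类型, 手性标签] 两列整型索引,边特征 edge_attr 为 [键类型, 键方向] 两列整型索引。

import torch

# 构造一个 6 个节点、5 条无向边(双向存储)的玩具分子图
x = torch.zeros(6, 2, dtype=torch.long)                            # [atom_type, chirality_tag]
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],
                           [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])
edge_attr = torch.zeros(edge_index.size(1), 2, dtype=torch.long)   # [bond_type, bond_direction]
batch = torch.zeros(6, dtype=torch.long)                           # 所有节点属于第 0 张图

model = GNN_graphpred(num_layer=2, emb_dim=300, num_tasks=1,
                      JK="last", drop_ratio=0.5, graph_pooling="mean", gnn_type="gin")
# model.from_pretrained("model_gin/contextpred.pth")  # 加载预训练权重,文件路径为假设
model.eval()
with torch.no_grad():
    pred = model(x, edge_index, edge_attr, batch)
print(pred.shape)  # torch.Size([1, 1])
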

小结

结构差异现象的核心根源,在于预训练图与目标应用图(即下游图)之间在生成模式上存在显著差异。针对此问题,文中创新性地提出了G-TUNING方法,旨在保持并适应下游图的独特生成模式。具体而言,对于给定的下游图G,G-TUNING的核心策略是调整预训练的图神经网络(GNN),使其能够重新构造出G的生成模式,即其graphon W。

然而,直接精确重建graphon W在计算上极具挑战性且成本高昂。为突破这一瓶颈,文章进一步提供了深入的理论分析,证明了存在一组称为graphon bases的替代graphon,它们可以作为构建块来近似表示W。通过利用这些graphon bases的线性组合,G-TUNING能够以高效的方式逼近真实的graphon生成模式。

这一理论发现为G-TUNING模型的构建奠定了坚实基础,因为它允许模型有效地学习graphon bases及其相应的组合系数,从而在不牺牲性能的前提下,显著降低计算复杂度。实验结果显示,与现有技术相比,G-TUNING在域内迁移学习和跨域迁移学习场景下,分别实现了平均0.5%和2.6%的性能提升,充分验证了其有效性和优越性。

参考文献

[1] Yifei Sun, Qi Zhu, Yang Yang, Chunping Wang, Tianyu Fan, Jiajun Zhu, Lei Chen. Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns[C]. AAAI 2024. https://export.arxiv.org/abs/2312.13583
