【图神经网络】GNNExplainer代码解读及其PyG实现

GNNExplainer代码解读及其PyG实现

  • 使用GNNExplainer
  • GNNExplainer源码速读
    • 前向传播
    • 损失函数
  • 基于GNNExplainer图分类解释的PyG代码示例
  • 参考资料

接上一篇博客图神经网络的可解释性方法及GNNexplainer代码示例,我们这里简单分析GNNExplainer源码,并用PyTorch Geometric手动实现。
GNNExplainer的源码地址:https://github.com/RexYing/gnn-model-explainer

使用GNNExplainer

(1)安装:

git clone https://github.com/RexYing/gnn-model-explainer

推荐使用python3.7以及创建虚拟环境:

virtualenv venv -p /usr/local/bin/python3
source venv/bin/activate

(2)训练一个GCN模型

python train.py --dataset=EXPERIMENT_NAME

其中EXPERIMENT_NAME表示想要复现的实验名称。

训练GCN模型的完整选项列表:

python train.py --help

(3)解释一个GCN模型
要运行解释器,请运行以下内容:

python explainer_main.py --dataset=EXPERIMENT_NAME

(4)可视化解释
使用Tensorboard:优化的结果可以通过Tensorboard可视化。

tensorboard --logdir log

GNNExplainer源码速读

GNNExplainer会从2个角度解释图:

  • 边(edge):会生成一个edge mask,表示每条边在图中出现的概率,值为0-1之间的浮点数。edge mask也可以当作一个权重,可以取topk的edge连成的子图来解释。
  • 结点特征(node feature):node feature(NF)即结点向量,比如一个结点128维表示128个特征,那么它同时会生成一个NF mask来表示每个特征的权重,这个可以不要。

代码目录

  • explainer目录下的ExplainModel类定义了GNNExplainer网络的模块结构,继承torch.nn.Module:

    • 在初始化init的时候,用construct_edge_maskconstruct_feat_mask函数初始化要学习的两个mask(分别对应于两个nn.Parameter类型的变量: n × n n×n n×n维的maskd维全0的feat_mask);diag_mask即主对角线上是0,其余元素均为1的矩阵,用于_masked_adj函数。
    • _masked_adj函数将mask用sigmod或ReLU激活后,加上自身转置再除以2,以转为对称矩阵,然后乘上diag_mask,最终将原邻接矩阵adj变换为masked_adj
  • Explainer类实现了解释的逻辑,主函数是其中的explain,用于解释原模型在单节点的预测结果,主要步骤:

    1. 取子图的adj, x, label图解释:取graph_idx对应的整个计算图;节点解释:调用extract_neighborhood函数取该节点num_gc_layers阶数的邻居。
    2. 将传入的模型预测输出pred转为pred_label
    3. 构建ExplainModule,进行num_epochs轮训练(前向+反向传播)
adj   = torch.tensor(sub_adj, dtype=torch.float)
x     = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
label = torch.tensor(sub_label, dtype=torch.long)if self.graph_mode:pred_label = np.argmax(self.pred[0][graph_idx], axis=0)print("Graph predicted label: ", pred_label)
else:pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)print("Node predicted label: ", pred_label[node_idx_new])explainer = ExplainModule(adj=adj,x=x,model=self.model,label=label,args=self.args,writer=self.writer,graph_idx=self.graph_idx,graph_mode=self.graph_mode,
)
if self.args.gpu:explainer = explainer.cuda()...# NODE EXPLAINER
def explain_nodes(self, node_indices, args, graph_idx=0):
...def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"):
...# GRAPH EXPLAINER
def explain_graphs(self, graph_indices):
...

explain_nodesexplain_nodes_gnn_statsexplain_graphs这三个函数都是在它的基础上实现的。

下面分析其中的forwardloss函数。

前向传播

首先把待学习的参数mask和feat_mask分别乘上原邻接矩阵和特征向量,得到变换后的masked_adjx。前者通过调用_masked_adj函数完成,后者的实现如下:

feat_mask = (torch.sigmoid(self.feat_mask)if self.use_sigmoidelse self.feat_mask
)
if marginalize:std_tensor = torch.ones_like(x, dtype=torch.float) / 2mean_tensor = torch.zeros_like(x, dtype=torch.float) - xz = torch.normal(mean=mean_tensor, std=std_tensor)x = x + z * (1 - feat_mask)
else:x = x * feat_mask

完整代码如下:
forward
这里需要说明的是marginalize为True的情况,参考论文中的Learning binary feature selector F:
Learning binary feature selector F

  • 如果同mask一样学习feature_mask,在某些情况下回导致重要特征也被忽略(学到的特征遮罩也是接近于0的值),因此,依据 X S X_S XS的经验边缘分布使用Monte Carlo方法来抽样得到 X = X S F X=X_S^F X=XSF.
  • 为了解决随机变量 X X X的反向传播的问题,引入了"重参数化"的技巧,即将其表示为一个无参的随机变量 Z Z Z的确定性变换: X = Z + ( X S − Z ) ⊙ F X=Z+(X_S-Z)\odot F X=Z+(XSZ)F s . t . ∑ j F j ≤ K F s.t. \sum_{j}F_j\le K_F s.t.jFjKF
    其中, Z Z Z是依据经验分布采样得到的 d d d维随机变量, K F K_F KF是表示保留的最大特征数的参数(utils/io_utils.py中的denoise_graph函数)。

接着将masked_adjx输入原始模型得到ExplainModule结果pred

损失函数

loss = pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss

可知,总的loss包含五项,除了对应于论文中损失函数公式的pred_loss,其余各项损失的作用参考论文Integrating additional constraints into explanations,它们的权重定义在coeffs中:

self.coeffs = {"size": 0.005,"feat_size": 1.0,"ent": 1.0,"feat_ent": 0.1,"grad": 0,"lap": 1.0,
}

Integrating additional constraints into explanations

  1. pred_loss
mi_obj = False
if mi_obj:pred_loss = -torch.sum(pred * torch.log(pred))
else:pred_label_node = pred_label if self.graph_mode else pred_label[node_idx]gt_label_node = self.label if self.graph_mode else self.label[0][node_idx]logit = pred[gt_label_node]pred_loss = -torch.log(logit)

其中pred是当前的预测结果,pred_label是原始特征上的预测结果。

  1. mask_ent_loss
# entropy
mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
  1. size_loss
# size
mask = self.mask
if self.mask_act == "sigmoid":mask = torch.sigmoid(self.mask)
elif self.mask_act == "ReLU":mask = nn.ReLU()(self.mask)
size_loss = self.coeffs["size"] * torch.sum(mask)
  1. feat_size_loss
# pre_mask_sum = torch.sum(self.feat_mask)
feat_mask = (torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
)
feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
  1. lap_loss
# laplacian
D = torch.diag(torch.sum(self.masked_adj[0], 0))
m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
L = D - m_adj
pred_label_t = torch.tensor(pred_label, dtype=torch.float)
if self.args.gpu:pred_label_t = pred_label_t.cuda()L = L.cuda()
if self.graph_mode:lap_loss = 0
else:lap_loss = (self.coeffs["lap"] * (pred_label_t @ L @ pred_label_t) / self.adj.numel())

补充

基于GNNExplainer图分类解释的PyG代码示例

对于图分类问题的解释,关键点有两个:

  • 要学习的Mask作用在整个图上,不用取子图
  • 标签预测和损失函数的对象是单个graph

实现代码如下:

#!/usr/bin/env python
# encoding: utf-8
# Created by BIT09 at 2023/4/28
import torch
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from tqdm import tqdm
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, to_networkxEPS = 1e-15class GNNExplainer(torch.nn.Module):r"""Args:model (torch.nn.Module): The GNN module to explain.epochs (int, optional): The number of epochs to train.(default: :obj:`100`)lr (float, optional): The learning rate to apply.(default: :obj:`0.01`)log (bool, optional): If set to :obj:`False`, will not log any learningprogress. (default: :obj:`True`)"""coeffs = {'edge_size': 0.001,'node_feat_size': 1.0,'edge_ent': 1.0,'node_feat_ent': 0.1,}def __init__(self, model, epochs=100, lr=0.01, log=True, node=False):  # disable node_feat_mask by defaultsuper(GNNExplainer, self).__init__()self.model = modelself.epochs = epochsself.lr = lrself.log = logself.node = nodedef __set_masks__(self, x, edge_index, init="normal"):(N, F), E = x.size(), edge_index.size(1)std = 0.1if self.node:self.node_feat_mask = torch.nn.Parameter(torch.randn(F) * 0.1)std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)self.edge_mask = torch.nn.Parameter(torch.zeros(E) * 50)for module in self.model.modules():if isinstance(module, MessagePassing):module.__explain__ = Truemodule.__edge_mask__ = self.edge_maskdef __clear_masks__(self):for module in self.model.modules():if isinstance(module, MessagePassing):module.__explain__ = Falsemodule.__edge_mask__ = Noneif self.node:self.node_feat_masks = Noneself.edge_mask = Nonedef __num_hops__(self):num_hops = 0for module in self.model.modules():if isinstance(module, MessagePassing):num_hops += 1return num_hopsdef __flow__(self):for module in self.model.modules():if isinstance(module, MessagePassing):return module.flowreturn 'source_to_target'def __subgraph__(self, node_idx, x, edge_index, **kwargs):num_nodes, num_edges = x.size(0), edge_index.size(1)if node_idx is not None:subset, edge_index, mapping, edge_mask = k_hop_subgraph(node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,num_nodes=num_nodes, flow=self.__flow__())x = x[subset]else:x = xedge_index = edge_indexrow, col = edge_indexedge_mask = row.new_empty(row.size(0), dtype=torch.bool)edge_mask[:] = Truemapping = Nonefor key, item in kwargs:if torch.is_tensor(item) and item.size(0) == num_nodes:item = item[subset]elif torch.is_tensor(item) and item.size(0) == num_edges:item = item[edge_mask]kwargs[key] = itemreturn x, edge_index, mapping, edge_mask, kwargsdef __graph_loss__(self, log_logits, pred_label):loss = -torch.log(log_logits[0, pred_label])m = self.edge_mask.sigmoid()loss = loss + self.coeffs['edge_size'] * m.sum()ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)loss = loss + self.coeffs['edge_ent'] * ent.mean()return lossdef visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,threshold=None, **kwargs):r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask:attr:`edge_mask`.Args:node_idx (int): The node id to explain.edge_index (LongTensor): The edge indices.edge_mask (Tensor): The edge mask.y (Tensor, optional): The ground-truth node-prediction labels usedas node colorings. (default: :obj:`None`)threshold (float, optional): Sets a threshold for visualizingimportant edges. If set to :obj:`None`, will visualize alledges with transparancy indicating the importance of edges.(default: :obj:`None`)**kwargs (optional): Additional arguments passed to:func:`nx.draw`.:rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`"""assert edge_mask.size(0) == edge_index.size(1)if node_idx is not None:# Only operate on a k-hop subgraph around `node_idx`.subset, edge_index, _, hard_edge_mask = k_hop_subgraph(node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,num_nodes=None, flow=self.__flow__())edge_mask = edge_mask[hard_edge_mask]subset = subset.tolist()if y is None:y = torch.zeros(edge_index.max().item() + 1,device=edge_index.device)else:y = y[subset].to(torch.float) / y.max().item()y = y.tolist()else:subset = []for index, mask in enumerate(edge_mask):node_a = edge_index[0, index]node_b = edge_index[1, index]if node_a not in subset:subset.append(node_a.item())if node_b not in subset:subset.append(node_b.item())y = [y for i in range(len(subset))]if threshold is not None:edge_mask = (edge_mask >= threshold).to(torch.float)data = Data(edge_index=edge_index, att=edge_mask, y=y,num_nodes=len(y)).to('cpu')G = to_networkx(data, edge_attrs=['att'])  # , node_attrs=['y']mapping = {k: i for k, i in enumerate(subset)}G = nx.relabel_nodes(G, mapping)kwargs['with_labels'] = kwargs.get('with_labels') or Truekwargs['font_size'] = kwargs.get('font_size') or 10kwargs['node_size'] = kwargs.get('node_size') or 800kwargs['cmap'] = kwargs.get('cmap') or 'cool'pos = nx.spring_layout(G)ax = plt.gca()for source, target, data in G.edges(data=True):ax.annotate('', xy=pos[target], xycoords='data', xytext=pos[source],textcoords='data', arrowprops=dict(arrowstyle="->",alpha=max(data['att'], 0.1),shrinkA=sqrt(kwargs['node_size']) / 2.0,shrinkB=sqrt(kwargs['node_size']) / 2.0,connectionstyle="arc3,rad=0.1",))nx.draw_networkx_nodes(G, pos, node_color=y, **kwargs)nx.draw_networkx_labels(G, pos, **kwargs)return ax, Gdef explain_graph(self, data, **kwargs):self.model.eval()self.__clear_masks__()x, edge_index, batch = data.x, data.edge_index, data.batchnum_edges = edge_index.size(1)# Only operate on a k-hop subgraph around `node_idx`.x, edge_index, _, hard_edge_mask, kwargs = self.__subgraph__(node_idx=None, x=x, edge_index=edge_index,**kwargs)# Get the initial prediction.with torch.no_grad():log_logits = self.model(data, **kwargs)probs_Y = torch.softmax(log_logits, 1)pred_label = probs_Y.argmax(dim=-1)self.__set_masks__(x, edge_index)self.to(x.device)if self.node:optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],lr=self.lr)else:optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr)epoch_losses = []for epoch in range(1, self.epochs + 1):epoch_loss = 0optimizer.zero_grad()if self.node:h = x * self.node_feat_mask.view(1, -1).sigmoid()log_logits = self.model(data, **kwargs)pred = torch.softmax(log_logits, 1)loss = self.__graph_loss__(pred, pred_label)loss.backward()optimizer.step()epoch_loss += loss.detach().item()epoch_losses.append(epoch_loss)edge_mask = self.edge_mask.detach().sigmoid()print(edge_mask)self.__clear_masks__()return edge_mask, epoch_lossesdef __repr__(self):return f'{self.__class__.__name__}()'

参考资料

  1. gnn-explainer
  2. 图神经网络的可解释性方法及GNNexplainer代码示例
  3. Pytorch实现GNNExplainer
  4. How to Explain Graph Neural Network — GNNExplainer
  5. https://gist.github.com/hongxuenong/9f7d4ce96352d4313358bc8368801707

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/645128.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于sentinel-2 遥感数据的水体提取(水体指数法)

本文框架设置如下: 简单介绍senintel-2数据;如何利用sentinel-2数据获取水体边界/范围 1 Sentinel-2数据介绍及下载方式 有Sentinel-2A/2B两颗卫星,其参数基本一致,因此两颗卫星的数据联合使用很方便。 分辨率有:1…

Laya2.13.3接入FGUI

下载与复制文件与Laya1.x类似,可以看我上一篇: Laya1.8.4接入FariyGui,以及其中踩的坑-CSDN博客 不同的是: 两个库文件需要在index.js中引入 新建一个脚本将fgui中搭建好的UI包引入: export default class GameApp…

食品加工厂可视化视频AI智能监管方案,助力工厂数字化运营

一、背景与需求分析 随着科技的不断进步和人们对食品安全和质量的日益关注,食品智慧工厂的建设成为了食品行业的一个重要趋势。智能化的食品工厂可以利用先进的技术和自动化系统,提高生产效率、降低监管成本,并确保产品的质量和安全。 行业…

Python爬虫——2023年西安全年气温数据并进行可视化处理

Python爬虫——2023年西安全年气温数据并进行可视化处理 一、网站选择 我们要找到西安历史气温数据,可以去一些天气网站上查找,但不一定每一个天气网站都会留有各城市的历史天气数据,因此我在这里给大家推荐两个网站方便大家进行历史气温的…

牛客30道题解析精修版

1.异常处理 都是Throwable的子类: ① Exception(异常):是程序本身可以处理的异常。 ② Error(错误): 是程序无法处理的错误。这些错误表示故障发生于虚拟机自身、或者发生在虚拟机试图执行应用时,一般不需要…

选择国产压测工具应注意什么?

随着互联网和信息技术的飞速发展,压力测试成为了确保软件系统稳定性和性能的不可或缺的环节。在压测工具的选择上,近年来国产压测工具逐渐崭露头角,但在使用时仍需谨慎。本文将探讨在选择国产压测工具时需要注意的关键因素。 功能完备性&…

GBASE南大通用数据库GBase 8s常见问题解析 -- 查找锁会话并解锁

本文摘自GBASE南大通用社区,by:wty,原文请点击:GBase 8s常见问题 -- 查找锁会话并解锁|GBASE社区|天津南大通用数据技术股份有限公司|GBASE-致力于成为用户最信赖的数据库产品供应商 问题现象 执行SQL时报错 244: Could not do…

电脑f盘满了怎么清理f盘空间?3个实用的方法!

磁盘空间的大小直接影响到计算机的性能和存储能力。当磁盘空间不足时,会导致系统运行缓慢、无法安装新的程序或存储文件,甚至会使计算机系统崩溃。那么如何解决磁盘空间不足的问题呢?下面提供了一些方法。 方法一:删除不需要的文…

跟着pink老师前端入门教程-day09

二十二、定位 22.1 为什么需要定位 1. 某个元素可以自由的在一个盒子内移动位置,并且压住其他盒子 2. 当我们滚动窗口时,盒子是固定屏幕某个位置的 解决方法: 1. 浮动可以让多个块级盒子一行没有缝隙排列显示,经常用于横向排…

STL第三讲

第三讲 stl六大部件:算法是函数模板,其他的是类模板 算法形式:传入两个迭代器(第三个参数可能有:一个比较的准则 算法需要的所有信息从迭代器获取 迭代器分类 基于红黑树的结构是双向迭代器; 基于hash的取…

【计算机网络】UDP协议与TCP协议

文章目录 一、端口号1.什么是端口号2.端口号范围划分3.认识知名端口号(Well-Know Port Number)4.netstat5.pidof 二、UDP协议1.UDP协议端格式2.UDP的特点3.面向数据报4.UDP的缓冲区5.UDP使用注意事项6.基于UDP的应用层协议 三、TCP协议1.TCP协议段格式1.1理解封装解包和分用1.2…

进程通信与socket编程实践之猜数字小游戏

socket是实现进程通信的一种重要方式,本文将通过socket编程实现服务器进程与客户端进程之间的通信,并在通信之外实现猜数字的小游戏。 1. 设计思路 本文设计的C/S结构的猜数字游戏功能如下:服务器端自动生成一个1-100之间的随机数字&#x…

100T数据存进服务器分几步?

大家好,我是豆小匠。 这期来聊聊数据存储相关的问题,包括: 容量评估。技术选型。容灾处理。 另外,文末赠送免费定制红包封面哦! 1. 容量评估 通过对容量&性能的评估,可以把业务需求转化成技术语言描…

鲲鹏微认证——openEuler开源操作系统迁移实践

文章目录 为什么要系统搬迁为什么选择欧拉欧拉系统迁移概述实施路径工具实战 为什么要系统搬迁 2020年12月,CentOs作为由开源社区免费提供的操作系统,宣布将对CentO58于2021年底停止服务,CentO57则于2024年6月底停止服务。 这将直接导致操作…

Linux系统SSH远程管理服务

目录 一、SSH服务介绍 1、SSH协议是什么? 2、SSH的优点 3、SSH的客户端与服务端 4、SSH的原理 4.1 公钥首次连接原理 4.2 ssh加密通讯原理 4.2.1 对称加密 4.2.2 非对称加密 4.2 ssh远程登录 二、服务端配置 1、常见配置项 1.1 修改默认端口 1.2 禁止…

未来已来:AI引领智能时代的多领域巨变

大家好,今天我们将深入探讨人工智能如何彻底改变我们的生活方式,领略未来的无限可能性。 1. 医疗革新:AI担任超级医生 医疗领域是AI最引人注目的战场之一。智能医学影像诊断系统,不仅能够精准识别病变,还能辅助医生提…

VS Code使用Git管理开发项目流程

以VSCode远程连接虚拟机开发为例,已经配置好SSH 在Github上搜索心仪的项目,比如权限管理 点击fork到自己账户仓库 虚拟机下创建一个目录 1)mkdir java_test 2)切换到java_test 初始化并克隆项目 1) git init:初始化仓库 2) g…

掼蛋功能之识别性格篇

常说:千人千面。大多数人一到牌局的场面,往往精神便会放松,面貌神情不再收敛,一言一行体现出的性格暴露无疑,具体表现为以下几种: 1、浮躁冲动型:此类人多数不讲究团队配合,自顾自出…

UE5 - Polycam扫描文件导入插件

Polycam是利用Gaussian Splatting进行3D重建的3D扫描相关软件,其对应有UE引擎的插件(Plugin_XV3dGS)可以把相关格式的文件导入到引擎; 首先Polycam的官网为:My Captures | Polycam 可以下载各种用户扫描文件&#xff…

vivado I/O和时钟规划设计流程步骤

I/O和时钟规划设计流程步骤 下图显示了左侧的项目设计流程步骤。水平箭头表示项目设计流程中可以执行I/O和时钟规划的点。中的步骤I/O和时钟规划设计流程如右图所示。 项目设计流程从一个空的I/O规划项目、RTL设计项目或合成后网表项目。使用这些项目类型中的任何一种&#xf…