DGL在异构图上的GraphConv模块

回顾同构图GraphConv模块

首先回顾一下同构图中实现GraphConv的主要思路(以GraphSAGE为例):
在初始化模块首先是获取源节点和目标节点的输入维度,同时获取输出的特征维度。根据SAGE论文提出的三种聚合操作,需要获取所使用的聚合类型,方便后面使用Pytorch中的nn模块实现。最后是特征归一化操作。
其具体的代码段为:

获取相关输入特征

        # 获取源节点和目标节点的输入特征维度self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)# 输出特征维度self._out_feats = out_featsself._aggre_type = aggregator_typeself.norm = normself.activation = activation

根据聚合类型选择Pytorch对应的nn模块中的函数

        # 聚合类型:mean、pool、lstm、gcnif aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))if aggregator_type == 'pool':self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)if aggregator_type == 'lstm':self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)if aggregator_type in ['mean', 'pool', 'lstm']:self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)

权重初始化

构造函数的最后调用了 reset_parameters() 进行权重初始化。

def reset_parameters(self):"""重新初始化可学习的参数"""gain = nn.init.calculate_gain('relu')if self._aggre_type == 'pool':nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)if self._aggre_type == 'lstm':self.lstm.reset_parameters()if self._aggre_type != 'gcn':nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2

forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:

  1. 检测输入图对象是否符合规范。
  2. 消息传递和聚合
  3. 聚合后,更新特征作为输出。

检测输入图对象的规范性

# 输入图对象的规范检测
with graph.local_scope():# 指定图类型,然后根据图类型扩展输入特征feat_src, feat_dst = expand_as_pair(feat, graph)

对于expand_as_pair()函数,其实现的操作是如果输入的特征不是一对的话(源节点和目标节点),就根据图Graph将特征变成一对,但要求图必须是一个block,其对应的源码为:

def expand_as_pair(input_, g=None):"""Return a pair of same element if the input is not a pair.如果输入不是一对,则返回相同元素的一对。If the graph is a block, obtain the feature of destination nodes from the source nodes.如果图是块,则从源节点中获取目的节点的特征。Parameters----------input_ : Tensor, dict[str, Tensor], or their pairsThe input featuresg : DGLGraph or NoneThe graph.If None, skip checking if the graph is a block.Returns-------tuple[Tensor, Tensor] or tuple[dict[str, Tensor], dict[str, Tensor]]The features for input and output nodes输入和输出节点的特性"""if isinstance(input_, tuple):return input_elif g is not None and g.is_block:if isinstance(input_, Mapping):input_dst = {k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))for k, v in input_.items()}else:input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())return input_, input_dstelse:return input_, input_

消息传递和聚合

聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() APIDGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。

        # 消息传递和聚合if self._aggre_type == 'mean':graph.srcdata['h'] = feat_srcgraph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))h_neigh = graph.dstdata['neigh']elif self._aggre_type == 'gcn':check_eq_shape(feat)graph.srcdata['h'] = feat_srcgraph.dstdata['h'] = feat_dstgraph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))# 除以入度degs = graph.in_degrees().to(feat_dst)h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)elif self._aggre_type == 'pool':graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))h_neigh = graph.dstdata['neigh']else:raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

如果是gcn聚合方式的话还需要用到它自身的特征,但是SAGE不需要,它只需要聚合邻居的特征,这里通过一条判断语句加以区分:

        # GraphSAGE中gcn聚合不需要fc_selfif self._aggre_type == 'gcn':rst = self.fc_neigh(h_neigh)else:rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

更新特征

聚合后,更新特征作为输出——forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

        # 更新特征作为输出# 激活函数if self.activation is not None:rst = self.activation(rst)# 归一化if self.norm is not None:rst = self.norm(rst)return rst

异构图GraphConv模块

DGL提供了 HeteroGraphConv,用于定义异构图上GNN模块。 实现逻辑与消息传递级别的API multi_update_all() 相同,它包括:

  • 每个关系上的DGL NN模块。
  • 聚合来自不同关系上的结果。
    其对应的数学公式为:(r表示关系)

在这里插入图片描述

__ init __函数

异构图的卷积操作接受一个字典类型参数 mods。这个字典的键为关系名,值为作用在该关系上NN模块对象。参数 aggregate 则指定了如何聚合来自不同关系的结果。

class HeteroGraphConv(nn.Module):def __init__(self, mods, aggregate='sum'):super(HeteroGraphConv, self).__init__()self.mods = nn.ModuleDict(mods)if isinstance(aggregate, str):# 获取聚合函数的内部函数self.agg_fn = get_aggregate_fn(aggregate)else:self.agg_fn = aggregate

nn.ModuleDict() 用于保存字典中的子模块。Pytorch官方也给出了对应的示例:

class MyModule(nn.Module):def __init__(self):super().__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})self.activations = nn.ModuleDict([['lrelu', nn.LeakyReLU()],['prelu', nn.PReLU()]])def forward(self, x, choice, act):x = self.choices[choice](x)x = self.activations[act](x)return x

forward函数

对于前向传播函数,除了需要输入图和输入张量以外,它还需要2个额外的字典参数mod_argsmod_kwargs。这2个字典与 self.mods 具有相同的键,值则为对应NN模块自定义参数
forward() 函数的输出结果也是一个字典类型的对象。其键为 nty,其值为每个目标节点类型 nty 的输出张量的列表, 表示来自不同关系的计算结果HeteroGraphConv 会对这个列表进一步聚合,并将结果返回给用户。聚合操作主要是:

if g.is_block:src_inputs = inputsdst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:src_inputs = dst_inputs = inputsfor stype, etype, dtype in g.canonical_etypes:rel_graph = g[stype, etype, dtype]if rel_graph.num_edges() == 0:continueif stype not in src_inputs or dtype not in dst_inputs:continuedstdata = self.mods[etype](rel_graph,(src_inputs[stype], dst_inputs[dtype]),*mod_args.get(etype, ()),**mod_kwargs.get(etype, {}))outputs[dtype].append(dstdata)

输入 g 可以是异构图或来自异构图的子图区块。和普通的NN模块一样,forward() 函数需要分别处理不同的输入图类型

上述代码中的for循环为处理异构图计算的主要逻辑

  • 首先我们遍历图中所有的关系(通过调用 canonical_etypes)。
  • 通过关系名,我们可以使用g[ stype, etype, dtype ]的语法将只包含该关系的子图( rel_graph )抽取出来。
  • 对于二分图,输入特征将被组织为元组 (src_inputs[stype], dst_inputs[dtype])
  • 接着调用用户预先注册在该关系上的NN模块,并将结果保存在outputs字典中。

最后,HeteroGraphConv 会调用用户注册的 self.agg_fn 函数聚合来自多个关系的结果。

rsts = {}
for nty, alist in outputs.items():if len(alist) != 0:rsts[nty] = self.agg_fn(alist, nty)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/171332.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯第四场双周赛(1~6)

1、水题 2、模拟题,写个函数即可 #define pb push_back #define x first #define y second #define int long long #define endl \n const LL maxn 4e057; const LL N 5e0510; const LL mod 1e097; const int inf 0x3f3f; const LL llinf 5e18;typedef pair…

vue3+ts 兄弟组件之间传值

父级&#xff1a; <template><div><!-- <A on-click"getFlag"></A><B :flag"Flag"></B> --><A></A><B></B></div> </template><script setup lang"ts"> i…

01、copilot+pycharm

之——free for student 目录 之——free for student 杂谈 正文 1.for student 2.pycharm 3.使用 杂谈 copilot是github推出的AI程序员&#xff0c;将chatgpt搬到了私人终端且无token限制&#xff0c;下面是使用方法。 GitHub Copilot 是由 GitHub 与 OpenAI 合作开发的…

2023年3月电子学会青少年软件编程 Python编程等级考试一级真题解析(判断题)

2023年3月Python编程等级考试一级真题解析 判断题(共10题,每题2分,共20分) 26、在Python编程中,print的功能是将print()小括号的内容输出到控制台,比如:在Python Shell中输入print(北京,你好)指令,小括号内容可以输出到控制台 答案:错 考点分析:考查python中print…

【【Linux编程介绍之关键配置和常用用法】】

Linux编程介绍之关键配置和常用用法 Hello World ! 我们所说的编写代码包括两部分&#xff1a;代码编写和编译&#xff0c;在Windows下可以使用Visual Studio来完成这两部&#xff0c;可以在 Visual Studio 下编写代码然后直接点击编译就可以了。但是在 Linux 下这两部分是分开…

2024年第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛正式卷任务书

2024年第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛正式卷任务书 2024年第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛正式卷A模块基础设施设置/安全加固&#xff08;200分&#xff09;A-1&#xff1a;登录安全加固&#xff08;Windows, Linux&am…

【Mybatis】预编译/即时sql 数据库连接池

回顾 Mybatis是一个持久层框架.有两种方式(这两种方式可以共存) 1.注解 2.xml 一.传递参数 以使用#{} 来接受参数为例 (以上两种方式一样适用的) 1)传递单个参数 #{} 可以为任意名称 2)多个参数 默认的参数名称就是接口方法声明的形参 3)参数为对象 默认给每个对象的每个属性都…

Linux内核中的overlay文件系统

一、简介 Docker 内核实现容器的功能用了linux 内核中的三个特性 Namespace、Cgroup、UnionFs&#xff0c;今天我们来说一下UnionFs。 linux UnionFs 实现的是overlay 文件系统 OverlayFs 文件系统分为三层&#xff0c; lower 是只读层 Upper 是可读写 Merged 是 lower 和U…

OD机考真题搜集:叠积木1

题目 有一堆长方体积木,它们的高度和宽度都相同,但长度不一。 小橙想把这堆积木叠成一面墙,墙的每层可以放一个积木,或将两个积木拼接起来,要求每层的长度相同。若必须用完这些积木,叠成的墙最多为多少层?如下是叠成的一面墙的图示,积木仅按宽和高所在的面进行拼接。 …

【数据结构】树与二叉树(廿六):树删除指定结点及其子树(算法DS)

文章目录 5.3.1 树的存储结构5. 左儿子右兄弟链接结构 5.3.2 获取结点的算法1. 获取大儿子、大兄弟结点2. 搜索给定结点的父亲3. 搜索指定数据域的结点4. 删除结点及其左右子树a. 逻辑删除与物理删除b. 算法DSTc. 算法解析d. 代码实现递归释放树算法DS e. 算法测试 5. 代码整合…

PPT 遇到问题总结(修改页码统计)

PPT常见问题 1. 修改页码自动计数 1. 修改页码自动计数 点击 视图——>幻灯片母版——>下翻找到计数页直接修改——>关闭母版视图

vue+springboot读取git的markdown文件并展示

前言 最近&#xff0c;在研究一个如何将我们git项目的MARKDOWN文档获取到&#xff0c;并且可以展示到界面通过检索查到&#xff0c;于是经过几天的摸索&#xff0c;成功的研究了出来 本次前端vue使用的是Markdown-it Markdown-it 是一个用于解析和渲染 Markdown 标记语言的 …

Cache学习(3):Cache地址映射(直接映射缓存组相连缓存全相连缓存)

1 Cache的与存储地址的映射 以一个Cache Size 为 128 Bytes 并且Cache Line是 16 Bytes的Cache为例。首先把这个Cache想象成一个数组&#xff0c;数组总共8个元素&#xff0c;每个元素大小是 16 Bytes&#xff0c;如下图&#xff1a; 现在考虑一个问题&#xff0c;CPU从0x0654…

城市生命线丨桥梁结构健康监测系统的作用

在城市建设当中&#xff0c;有非常多的城市基本建设&#xff0c;建设当中&#xff0c;桥梁作为不可忽视的一环&#xff0c;也需要有很多桥梁建设的智能监测系统&#xff0c;在这个桥梁结构健康监测系统中&#xff0c;桥梁的各个数值都能被监测得到。 WITBEE万宾使用城市生命线智…

高并发内存池

1.什么是内存池 内存池动态内存分配与管理技术&#xff0c;对于程序员来说&#xff0c;通常情况下&#xff0c;动态申请内存需要使用new,delete,malloc,free这些API来申请&#xff0c;这样导致的后果是&#xff0c;当程序长时间运行之后&#xff0c;由于程序运行时所申请的内存…

探索 Rollup:简化你的前端构建流程

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

Linux 命令vim(编辑器)

(一)vim编辑器的介绍 vim是文件编辑器&#xff0c;是vi的升级版本&#xff0c;兼容vi的所有指令&#xff0c;同时做了优化和延伸。vim有多种模式&#xff0c;其中常用的模式有命令模式、插入模式、末行模式&#xff1a;。 (二)vim编辑器基本操作 1 进入vim编辑文件 1 vim …

排序算法:归并排序、快速排序、堆排序

归并排序 要将一个数组排序&#xff0c;可以先将它分成两半分别排序&#xff0c;然后再将结果合并&#xff08;归并&#xff09;起来。这里的分成的两半&#xff0c;每部分可以使用其他排序算法&#xff0c;也可以仍然使用归并排序&#xff08;递归&#xff09;。 我看《算法》…

电源的纹波

电源纹波的产生 我们常见的电源有线性电源和开关电源&#xff0c;它们输出的直流电压是由交流电压经整流、滤波、稳压后得到的。由于滤波不干净&#xff0c;直流电平之上就会附着包含周期性与随机性成分的杂波信号&#xff0c;这就产生了纹波。 在额定输出电压、电流的情况下…

leetCode 1080.根到叶路径上的不足节点 + 递归 + 图解

给你二叉树的根节点 root 和一个整数 limit &#xff0c;请你同时删除树中所有 不足节点 &#xff0c;并返回最终二叉树的根节点。假如通过节点 node 的每种可能的 “根-叶” 路径上值的总和全都小于给定的 limit&#xff0c;则该节点被称之为 不足节点 &#xff0c;需要被删除…