图神经网络教程之HAN-异构图模型

异构图

包含不同类型节点和链接的异构图

异构图的定义:节点类别数量和边的类别数量加起来大于2就叫异构图。

meta-path元路径的定义:连接两个对象的复合关系,比如,节点类型A和节点类型B,A-B-A和B-A-B都是一种元路径。

meta-path下的邻居节点的定义:如下图所示。

在这里插入图片描述

其中m1-a1-m2,m1-a3-m3都是一种meta-path,所以m1的邻居有m2、m3以及本身m1

在这里插入图片描述

节点级别的attention和语义级别的attention

节点级别:简单来说就是单种meta-path求得节点embeddings,比如对于M-D-M,Terminator2的embeddings通过M-D-M的元路径即可求的另一个M(Termintor)的embeddings。

语义级别:对于Terminator的embeddings不再是根据一种meta-path进行获取,而是根据两种meta-path进行权重的分配相加得到。

节点级别:

举例子:

在这里插入图片描述

如上图所示,对于异构图,一种meta-path为蓝-黄-蓝,对于节点x1-xa-x2,所以x1与x2通过meta-path元路径,同理每一对节点,构成上图中的第二个图的连接方式。

在这里插入图片描述

对于节点x1,与节点x2、x3、x6相连,所以x2、x3、x6都是节点x1的邻居节点,也就是公式2。

对于公式三,分子将i和j节点拼接在一起以后乘以一个可学习的参数然后再通过激活函数,再通过exp。分母就是他的邻居节点的。

对后求的节点级别下的embeddings。

语义级别:

简单来说语义级别就是多种meta-path呗,只需要把每种meta-path下面的求出来进行加权就可以了。

在这里插入图片描述

如上图所示,通过节点级别的求解方法,求出来对于每一种metapath下面的embeddings,然后最后进行加权求和。

知道了上面的HAN的原理,下面讲解一下model代码。

在讲解原理的时候分为语义级别和节点级别,在代码的时候会分为给定已经处理好的邻接矩阵和直接输入异构图。

异构图直接输入(异构图模型。):

需要将meta-path转化为邻接矩阵即元组形式。

实现了Heterogeneous Graph Attention Network(HAN)模型,用于处理异构图数据。HAN是一种深度学习模型,用于在异构图中进行节点分类任务

import torch
import torch.nn as nn
import torch.nn.functional as Ffrom dgl.nn.pytorch import GATConv

首先,导入了PyTorch库以及用于图神经网络的相关模块。

class SemanticAttention(nn.Module):def __init__(self, in_size, hidden_size=128):super(SemanticAttention, self).__init__()# input:[Node, metapath, in_size]; output:[None, metapath, 1]; 所有节点在每个meta-path上的重要性值self.project = nn.Sequential(nn.Linear(in_size, hidden_size),nn.Tanh(),nn.Linear(hidden_size, 1, bias=False))

这里定义了一个名为SemanticAttention的PyTorch模型类,它用于计算每个节点在不同元路径(metapath)上的重要性。SemanticAttention类有以下成员:

  • __init__方法:初始化模型。它接受输入特征的维度in_size以及可选的隐藏层维度hidden_size。在初始化过程中,它创建了一个神经网络模块self.project,该模块包括两个线性层和一个Tanh激活函数,最后一个线性层没有偏差。
    def forward(self, z):w = self.project(z).mean(0)#每个节点在metapath维度的均值; mean(0): 每个meta-path上的均值(/|V|); (MetaPath, 1)beta = torch.softmax(w, dim=0)       # 归一化   # (M, 1)beta = beta.expand((z.shape[0],) + beta.shape) #  拓展到N个节点上的metapath的值   (N, M, 1)return (beta * z).sum(1)#(beta*z)=>所有节点,在metapath上的attention值;(beta*z).sum(1)=>节点最终的值(N,D*K)
  • forward方法:用于计算每个节点在不同元路径上的重要性。首先,将输入特征z通过self.project模块传递,然后计算每个元路径上的重要性均值w。接着,使用softmax函数对这些均值进行归一化,以获得每个元路径上的注意力权重beta。最后,将注意力权重与输入特征相乘,并对所有元路径求和,得到最终的节点表示。

这个SemanticAttention模块的目的是计算每个节点在不同元路径上的权重,以便后续的元路径级别的注意力聚合。

接下来,定义了另一个模型类HANLayer

class HANLayer(nn.Module):def __init__(self, num_meta_paths, in_size, out_size, layer_num_heads, dropout):super(HANLayer, self).__init__()self.gat_layers = nn.ModuleList()for i in range(num_meta_paths):  # meta-path Layers; 两个meta-path的维度是一致的self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,dropout, dropout, activation=F.elu))self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)  # 语义attention; out-size*layersself.num_meta_paths = num_meta_paths

HANLayer类代表了HAN模型中的一个层次。每个HANLayer层包括以下成员:

  • __init__方法:初始化层。它接受以下参数:
    • num_meta_paths:元路径的数量。
    • in_size:输入特征的维度。
    • out_size:输出特征的维度。
    • layer_num_heads:每个GAT层中的注意力头的数量。
    • dropout:用于正则化的dropout率。

在初始化过程中,它首先创建了多个GATConv层,每个GATConv层对应一个元路径,这些层将用于图注意力聚合。然后,创建了一个SemanticAttention模块,用于计算每个节点在不同元路径上的语义级别的注意力。

接下来,定义了整个HAN模型类HAN

class HAN(nn.Module):def __init__(self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout):super(HAN, self).__init__()self.layers = nn.ModuleList()self.layers.append(HANLayer(num_meta_paths, in_size, hidden_size, num_heads[0], dropout)) # meta-path数量 + semantic_attentionfor l in range(1, len(num_heads)): # 多层多头,目前是没有self.layers.append(HANLayer(num_meta_paths, hidden_size * num_heads[l-1],hidden_size, num_heads[l], dropout))self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)  # hidden*heads, classes; HAN->classes

HAN类是整个HAN模型的定义。它接受以下参数:

  • num_meta_paths:元路径的数量。
  • in_size:输入特征的维度。
  • hidden_size:隐藏层的维度。
  • out_size:输出特征的维度(通常是类别数量)。
  • num_heads:一个列表,指定每个HANLayer层中的注意力头数量。
  • dropout:用于正则化的dropout率。

在初始化过程中,它首先创建了多个HANLayer层,每个HANLayer层包括一个或多个GATConv层和一个SemanticAttention层。

输入处理好的异构图,即邻接矩阵(普通图模型。):

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GATConv

首先,导入了必要的库和模块。

class SemanticAttention(nn.Module):def __init__(self, in_size, hidden_size=128):super(SemanticAttention, self).__init__()self.project = nn.Sequential(nn.Linear(in_size, hidden_size),nn.Tanh(),nn.Linear(hidden_size, 1, bias=False))

这里定义了一个名为SemanticAttention的PyTorch模型类,它用于计算每个节点在不同元路径上的语义级别的重要性。和第一个代码段的SemanticAttention类相似,这个类也包括以下成员:

  • __init__方法:初始化模型。它接受输入特征的维度in_size以及可选的隐藏层维度hidden_size。在初始化过程中,它创建了一个神经网络模块self.project,该模块包括两个线性层和一个Tanh激活函数,最后一个线性层没有偏差。
    def forward(self, z):w = self.project(z).mean(0)                    # (M, 1)beta = torch.softmax(w, dim=0)                 # (M, 1)beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)return (beta * z).sum(1)                       # (N, D * K)
  • forward方法:用于计算每个节点在不同元路径上的语义级别的重要性。首先,将输入特征z通过self.project模块传递,然后计算每个元路径上的语义级别的均值权重w。接着,使用softmax函数对这些均值进行归一化,得到每个元路径上的注意力权重beta,将这些权重与输入特征相乘,并对所有元路径求和,得到最终的节点表示。

接下来,定义了另一个模型类HANLayer,它代表HAN模型中的一个层次。

class HANLayer(nn.Module):def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout):super(HANLayer, self).__init__()# One GAT layer for each meta path based adjacency matrixself.gat_layers = nn.ModuleList()for i in range(len(meta_paths)):self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,dropout, dropout, activation=F.elu,allow_zero_in_degree=True))self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)  # 将meta-path转换成元组形式self._cached_graph = Noneself._cached_coalesced_graph = {}def forward(self, g, h):semantic_embeddings = []if self._cached_graph is None or self._cached_graph is not g:  # 第一次,建立一张metapath下的异构图self._cached_graph = gself._cached_coalesced_graph.clear()for meta_path in self.meta_paths:self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph(g, meta_path)  # 构建异构图的邻居;# self._cached_coalesced_graph 多个metapath下的异构图for i, meta_path in enumerate(self.meta_paths):new_g = self._cached_coalesced_graph[meta_path]  # meta-path下的节点邻居图semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))   # 图attentionsemantic_embeddings = torch.stack(semantic_embeddings, dim=1)                  # (N, M, D * K)return self.semantic_attention(semantic_embeddings)                            # (N, D * K)

HANLayer类包括以下主要部分:

  • __init__方法:初始化HAN层,它包括多个GATConv层以及一个语义注意力模块。每个GATConv层对应一个元路径,用于处理节点在该元路径上的信息。语义注意力模块用于计算节点在不同元路径上的语义级别的注意力。

  • forward方法:执行HAN层的前向传播。对于每个元路径,首先获取该元路径的邻居图,然后通过GATConv层计算节点的注意力表示。最后,通过语义注意力模块将不同元路径上的表示进行加权求和,得到最终的节点表示。

最后,定义了整个HAN模型类HAN

class HAN(nn.Module):def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout):super(HAN, self).__init__()self.layers = nn.ModuleList()self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout))for l in range(1, len(num_heads)):self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l-1],hidden_size, num_heads[l], dropout))self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

HAN类定义了整个HAN模型,包括多个HANLayer层以及最后的预测层。

  • __init__方法:初始化HAN模型,它包括多个HANLayer层,每个HANLayer层用于处理一个元路径的信息。最后,添加一个线性预测层,将最终的节点表示映

射到输出特征(通常是类别数量)。

  • forward方法:执行HAN模型的前向传播。它依次通过多个HANLayer层来计算最终的输出,每个HANLayer层都包括元路径信息的处理和注意力聚合。

训练代码train

训练代码就是常规的套路。

  1. 引入必要的库和模块:

    • 导入了PyTorch库和sklearn库,用于深度学习和评估模型性能。
    • 导入了自定义的load_dataEarlyStopping函数,以及其他必要的模块。
  2. score函数:

    • 这个函数用于计算模型的性能指标,包括准确率(accuracy)、微平均F1分数(micro_f1),和宏平均F1分数(macro_f1)。
    • 它接受模型的预测结果(logits)和真实标签(labels),然后计算这些性能指标。
    • 准确率表示正确分类的样本比例,微平均F1分数和宏平均F1分数是一种综合的评估指标,用于度量分类模型的性能。
  3. evaluate函数:

    • 这个函数用于评估模型在验证集上的性能。
    • 它接受模型(model)、图数据(g)、特征数据(features)、标签数据(labels)、掩码数据(mask),以及损失函数(loss_func)作为输入。
    • 在评估过程中,模型处于评估模式(model.eval()),不会更新梯度。
    • 通过模型预测验证集上的结果,并计算损失、准确率、微平均F1分数和宏平均F1分数。
    • 最后返回这些评估指标。
  4. main函数:

    • 这是主要的训练和评估逻辑所在的函数。
    • 首先,加载数据(包括图数据、特征数据、标签数据等)并将其移动到指定的计算设备(CPU或GPU)上。
    • 根据参数args中的'hetero'标志,选择不同的模型和数据处理方式。如果'hetero'为True,则使用异构图模型;否则,使用普通图模型。
    • 定义了模型的损失函数、优化器和早停(EarlyStopping)对象。
    • 开始训练循环,每个epoch进行一次训练和验证。在训练过程中,计算损失、准确率和F1分数等指标,并打印出来。如果验证集上的性能不再提升,会触发早停(early stopping)。
    • 最后,在测试集上评估模型的性能,并打印出测试集上的损失、准确率、微平均F1分数和宏平均F1分数。
  5. if __name__ == '__main__': 部分:

    • 这个部分用于设置命令行参数,并调用main函数来运行训练和评估过程。
    • 可以通过命令行传递参数来配置模型的训练和数据处理方式。

rlyStopping)对象。

  • 开始训练循环,每个epoch进行一次训练和验证。在训练过程中,计算损失、准确率和F1分数等指标,并打印出来。如果验证集上的性能不再提升,会触发早停(early stopping)。
  • 最后,在测试集上评估模型的性能,并打印出测试集上的损失、准确率、微平均F1分数和宏平均F1分数。
  1. if __name__ == '__main__': 部分:

    • 这个部分用于设置命令行参数,并调用main函数来运行训练和评估过程。
    • 可以通过命令行传递参数来配置模型的训练和数据处理方式。

总体来说,这段代码实现了一个用于异构图数据或普通图数据的节点分类任务的训练和评估流程。它加载数据、选择模型、进行训练和验证,最后在测试集上评估模型性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/66146.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux上git的简单使用

git的作用:版本控制多人协作 客户端 磁盘上的文件-->本地仓库-->远端仓库 服务端 gitee和GitHub是基于git的商业化网站 git的命令行如何使用? 1、新建一个仓库 .git ignore 是忽略带有某些后缀的文件的上传。 例如:里面有 .sln …

实战:基于卷积的MNIST手写体分类

前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。 1. 数据的准备 在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章…

vmware 16增加硬盘容量并在Ubuntu 18.04上边格式化并挂载

参考了《增加 VM虚拟机硬盘容量》 《Linux学习之分区挂载》中有给VMWare 16虚拟机添加一块硬盘的内容,需要先参考添加硬盘。 sudo mkfs.ext4 /dev/sda4给/dev/sda4进行ext4格式化。 sudo mkdir /mountsda4新建一个挂载目录。 sudo mount -t ext4 /dev/sda4 /mo…

AMBEO 双声道空间音频现已迈进直播制作领域

图片来源:Unsplash,作者:Bence Balla-Schottner AMBEO 双声道空间音频现已迈进直播制作领域 为所有观众解锁更加身临其境的听觉体验 森海塞尔将功能强大的 AMBEO 双声道空间音频技术引入了广播电视直播应用领域,对所有体育赛事广…

AD16 基础应用技巧(一些 “偏好“ 设置)

1. 修改铺铜后自动更新铺铜 AD16 铺铜 复制 自动变形 偏好设置 将【DXP】中的【参数选择】。 将【PCB Editor】中的【General】,然后勾选上【Repour Polygons After Modification】。 2. PCB直角走线处理与T型滴泪 一些没用的AD技巧——AD PCB直角走线处理与…

seq2seq与引入注意力机制的seq2seq

1、什么是 seq2seq? 就是字面意思,“句子 到 句子”。比如翻译。 2、seq2seq 有一些特点 seq2seq 的整体架构是 “编码器-解码器”。 其中,编码器是 RNN,并将 最后一个hidden state(隐藏状态)【即&…

【自用】西门子s7-200连接显示屏和物联网盒子完整配置过程

总览 1.PLC配置 2.显示屏配置 3.物联网盒子配置 一、PLC配置 1.连接PLC软件 STEP-7MicroWIN V4.0 SP9完整版 链接:https://pan.baidu.com/s/17LMEXnbkQZMPI8Bte24Eug?pwdjsi3 提取码:jsi3 2.PLC配置 打开 PLC 上面的小盖子,把红色按钮…

前端:html实现页面切换、顶部标签栏,类似于浏览器的顶部标签栏(完整版)

效果 代码 <!DOCTYPE html> <html><head><style>/* 左侧超链接列表 */.link {display: block;padding: 8px;background-color: #f2f2f2;cursor: pointer;}/* 顶部标签栏 */#tabsContainer {width:98%;display: flex;align-items: center;overflow-x: …

简易虚拟培训系统-UI控件的应用4

目录 Slider组件的常用参数 示例-使用Slider控制主轴 示例-Slider控制溜板箱的移动 本文以操作面板为例&#xff0c;介绍使用Slider控件控制开关和速度。 Slider组件的常用参数 Slider组件下面包含了3个子节点&#xff0c;都是Image组件&#xff0c;负责Slider的背景、填充区…

Ubuntu下的QT开发

ubuntu安装QT的组件如下&#xff1a; 若要在ubuntu下启动QT有两种方案&#xff0c;一种是在菜单栏搜索qt双QT Create&#xff1b;另一种则是使用命令&#xff1a;/opt/Qt5.12.9/Tools/QtCreator/bin/qtcreator.sh

pytest自动化测试两种执行环境切换的解决方案

目录 一、痛点分析 方法一&#xff1a;Hook方法pytest_addoption注册命令行参数 1、Hook方法注解 2、使用方法 方法二&#xff1a;使用插件pytest-base-url进行命令行传参 一、痛点分析 在实际企业的项目中&#xff0c;自动化测试的代码往往需要在不同的环境中进行切换&am…

C语言每日一练--------Day(11)

本专栏为c语言练习专栏&#xff0c;适合刚刚学完c语言的初学者。本专栏每天会不定时更新&#xff0c;通过每天练习&#xff0c;进一步对c语言的重难点知识进行更深入的学习。 今日练习题关键字&#xff1a;找到数组中消失的数字 哈希表 &#x1f493;博主csdn个人主页&#xff…

Kao框架学习

中间件&#xff1a;洋葱模型 这是官网上给出的示例&#xff0c;从logger依次往下执行&#xff0c;执行到最底层的response往回退&#xff0c;结构很像同心圆的洋葱从外层向内层再由内层向外层。 next表示暂停当前层的代码进入下一层&#xff0c; 当最后一层执行完毕开始回溯&a…

Jenkins清理构建(自动)

需求背景实现方法 Dashboard-->Project-->配置-->General-->Discard old builds # 注意&#xff1a;自动清理构建历史将在下次构建时进行

【Java 动态数据统计图】动态X轴二级数据统计图思路Demo(动态,排序,动态数组(重点推荐:难)九(131)

需求&#xff1a; 1.有一组数据集合&#xff0c;数据集合中的数据为动态&#xff1b; 举例如下&#xff1a; [{province陕西省, city西安市}, {province陕西省, city咸阳市}, {province陕西省, city宝鸡市}, {province陕西省, city延安市}, {province陕西省, city汉中市}, {pr…

【网络安全带你练爬虫-100练】第19练:使用python打开exe文件

目录 一、目标1&#xff1a;调用exe文件 二、目标2&#xff1a;调用exe打开文件 一、目标1&#xff1a;调用exe文件 1、subprocess 模块允许在 Python 中启动一个新的进程&#xff0c;并与其进行交互 2、subprocess.run() 函数来启动exe文件 3、subprocess.run(["文件路…

企业应用系统 PHP项目支持管理系统Dreamweaver开发mysql数据库web结构php编程计算机网页

一、源码特点 PHP 项目支持管理系统是一套完善的web设计系统 应用于企业项目管理&#xff0c;从企业内部的各个业务环境总体掌握&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 php项目支撑管理系统2 二、功能介绍 (1)权限管理&#xff1…

手写RPC——数据序列化工具protobuf

手写RPC——数据序列化工具protobuf Protocol Buffers&#xff08;protobuf&#xff09;是一种用于结构化数据序列化的开源库和协议。下面是 protobuf 的一些优点和缺点&#xff1a; 优点&#xff1a; 高效的序列化和反序列化&#xff1a;protobuf 使用二进制编码&#xff0c…

力扣:随即指针138. 复制带随机指针的链表

复制带随机指针的链表 OJ链接 分析&#xff1a; 该题的大致题意就是有一个带随机指针的链表&#xff0c;复制这个链表但是不能指向原链表的节点&#xff0c;所以每一个节点都要复制一遍 大神思路&#xff1a; ps:我是学来的 上代码&#xff1a; struct Node* copyRandomList(s…

TuyaOS开发学习笔记(1)——NB-IoT开发搭建环境、编译烧写(MT2625)

一、搭建环境 1.1 官方资料 TuyaOS 1.2 安装VMware 官网下载&#xff1a;https://customerconnect.vmware.com/en/downloads/info/slug/desktop_end_user_computing/vmware_workstation_pro/16_0 百度网盘&#xff1a;https://pan.baidu.com/s/1oN7H81GV0g6cD9zsydg6vg 提取…