GAT知识总结

《GRAPH ATTENTION NETWORKS》
解决GNN聚合邻居节点的时候没有考虑到不同的邻居节点重要性不同的问题,GAT借鉴了Transformer的idea,引入masked self-attention机制, 在计算图中的每个节点的表示的时候,会根据邻居节点特征的不同来为其分配不同的权值。

1.论文公式:

注意力系数:代表节点j的特征对i的重要性。
经过归一化:
得到最终注意力系数:
   a是 注意力机制权重, h是 节点特征,W是 节点特征的权重
如图:
将多头注意力得到的系数整合:
 一共k个注意力头,求i节点聚合后的特征。
如果是到了最后一层:
 使用avg去整合
如图:

2.论文导图:

3.核心代码:

从输入h节点特征 到 输出h`的过程:
  • 数据处理方式和GCN一样,对边、点、labels编码,并得到h和adj矩阵。
  • GraphAttention层的具体实现:
定义权重W,作用是将input的维度转换,且是可学习参数。
self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) #一开始定义的是空的
nn.init.xavier_uniform_(self.W.data, gain=1.414)  #随机初始化,然后在学习的过程中不断更新学习参数。
定义 注意力机制权重a,用来求注意力分数,也是可学习参数。
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1))) 
nn.init.xavier_uniform_(self.a.data, gain=1.414)  #随机初始化
求得注意力分数
Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
e = Wh1 + Wh2.T
然后进行softmax、dropout。
zero_vec = -9e15*torch.ones_like(e)
##如果如果两个节点有边链接则使其注意力得分为e,如果没有则使用-9e15的mask作为其得分
attention = torch.where(adj > 0, e, zero_vec)
##使用 softmax 函数对注意力得分进行归一化
attention = F.softmax(attention, dim=1)
##dropout,为模型的训练增加了一些正则化
attention = F.dropout(attention, self.dropout, training=self.training)
  • 整体GAT实现:
根据头数n,定义n个GraphAttention 层(每个头都具有相同的参 数,并且 共享输入特征 ),和一个 GraphAttention 输出层。
n个层各输出一个注意力分数,最后拼接成一个。 features( features.shape[1] --> hidden(自定义) --*n(自定义) --> n*hidden
self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)   #将层命名然后加到模块里
输出层将 n*hidden --> class( int(labels.max()) + 1)
self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
forward(x, adj)函数:
x = F.dropout(x, self.dropout, training=self.training)
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)   #遍历每个层对象,传参去运行。重复n次后得到 n * hidden 维的拼接后的注意力分数
x = F.dropout(x, self.dropout, training=self.training)
x = F.elu(self.out_att(x, adj))
return F.log_softmax(x, dim=1)   #输出层最后得到class维的分类概率矩阵。
整体是 通过 x 在层之间传递参数的,x一直在变,不断的dropout。

4.采取的正则化:

正则化通过对算法的修改来减少泛化误差,目前在深度学习中使用较多的策略有参数范数惩罚,提前终止,DropOut等。
  • dropout:通过使其它隐藏层神经网络单元不可靠从而 阻止了共适应的发生。因此,一个隐藏层神经元不能依赖其它特定神经元去纠正其错误。因为dropout程序导致两个神经元不一定每次都在一个dropout网络中出现。这样权值的更新不再依赖于有固定关系的隐含节点的共同作用,阻止了某些特征仅仅在其它特定特征下才有效果的情况 。迫使网络去学习更加鲁棒的特征 ,这些特征在其它的神经元的随机子集中也存在。
  • 提前停止:是将一部分训练集作为验证集。 当 验证集的性能越来越差时或者性能不再提升,则立即停止对该模型的训练。 这被称为提前停止。
在当前best的loss对应的epoch之后后,又跑了patience轮,但是新的loss没有比之前best loss小的,也就是loss越来越大了,出现过拟合了。这时就应该停止训练。
if loss_values[-1] < best:
        best = loss_values[-1]
        best_epoch = epoch
        bad_counter = 0
    else:
        bad_counter += 1
    if bad_counter == args.patience:
        break

5.和GCN对比adj的用法:

GCN通过变换(归一、正则)的邻接矩阵和节点特征矩阵相乘,得到了考虑邻接节点后的节点特征矩阵。通过两层卷积,得到output。
GAT利用Wh和注意力权重a计算出注意力分数,再用adj矩阵把不相邻的节点的分数mask掉。即,可以给不同节点分配不同权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/49086.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解开基于大模型的Text2SQL的神秘面纱

你好&#xff0c;我是 shengjk1&#xff0c;多年大厂经验&#xff0c;努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注&#xff01;你会有如下收益&#xff1a; 了解大厂经验拥有和大厂相匹配的技术等 希望看什么&#xff0c;评论或者私信告诉我&#xff01; 文章目录 一…

JAVA基础 - 对象

目录 一. 简介 二. 空对象 三. 构造方法 四. 析构方法 五. this关键字 六. 对象销毁 一. 简介 在 Java 中&#xff0c;对象&#xff08;Object&#xff09;是面向对象编程的核心概念。 对象是类的实例化&#xff0c;它将数据&#xff08;属性&#xff09;和操作这些数据…

【运算放大器】输入失调电压和输入偏置电流(2)实例计算

概述 根据上一篇文章的理论&#xff0c;分别计算没有输入电阻和有输入电阻两种情况下的运放总输出误差。例题来自于TI高精度实验室系列课程。 目录 概述实例计算 1&#xff1a;没有输入电阻实例计算 2&#xff1a;有输入电阻总结 实例计算 1&#xff1a;没有输入电阻 要求&am…

通过IEC104转MQTT网关对接阿里云、华为云、亚马逊AWS、ThingsBoard、Ignition、Zabbix

随着工业互联网的快速发展&#xff0c;传统电力系统中的IEC 104协议设备正逐步向更加开放、灵活的物联网架构转型。MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;作为一种轻量级的消息传输协议&#xff0c;因其低带宽消耗、高可靠性和广泛的支持性&#xf…

vue3前端开发-小兔鲜项目-路由拦截器增加token的携带

vue3前端开发-小兔鲜项目-路由拦截器增加token的携带&#xff01;实际开发中&#xff0c;很多业务接口的请求&#xff0c;都要求必须是登录状态&#xff01;为此&#xff0c;这个token信息就会频繁的被加入到了请求头部信息中。request请求头内既然需要频繁的携带这个token.我们…

集团ERP信息化项目实施方案(82页PPT)

集团ERP信息化项目实施方案的82页PPT详尽阐述了企业资源规划&#xff08;ERP&#xff09;系统实施的全过程&#xff0c;旨在帮助集团整合多个业务流程于一个统一的平台。方案从当前市场环境分析入手&#xff0c;解释了ERP系统对于提升集团运营效率、降低成本和优化资源配置的必…

【OpenCV C++20 学习笔记】图片融合

图片融合 原理实现结果展示完整代码 原理 关于OpenCV的配置和基础用法&#xff0c;请参阅本专栏的其他文章&#xff1a;垚武田的OpenCV合集 这里采用的图片熔合的算法来自Richard Szeliski的书《Computer Vision: Algorithms and Applications》&#xff08;《计算机视觉&#…

STM32是使用的内部时钟还是外部时钟

STM32是使用的内部时钟还是外部时钟&#xff0c;经常会有人问这个问题。 1、先了解时钟树&#xff0c;见下图&#xff1a; 2、在MDK中&#xff0c;使用的是HSEPLL作为SYSCLK&#xff0c;因此需要对时钟配置寄存器&#xff08;RCC_CFGR&#xff09;进行配置&#xff0c;寄存器内…

Eaton伊顿触摸屏维修XV-303-15-C00-A00-1C

伊顿触摸屏维修,工业触摸屏维修,主板维修,坏高故障,损坏显示,不损坏,运行稳定,不花屏,无反应慢等故障维修,维修有保障,资费低.,触摸屏主板坏,高压板故障,按键损坏等均可修理。 伊顿触摸屏维修 EATON触摸屏维修 伊顿工控机维修 EATON工控机维修 伊顿人机界面维修 EATON触摸屏维…

深度解读大语言模型中的Transformer架构

一、Transformer的诞生背景 传统的循环神经网络&#xff08;RNN&#xff09;和长短期记忆网络&#xff08;LSTM&#xff09;在处理自然语言时存在诸多局限性。RNN 由于其递归的结构&#xff0c;在处理长序列时容易出现梯度消失和梯度爆炸的问题。这导致模型难以捕捉长距离的依…

【机器学习】用Jupyter Notebook实现并探索单变量线性回归的代价函数以及遇到的一些问题

引言 在机器学习中&#xff0c;代价函数&#xff08;Cost Function&#xff09;是一个用于衡量模型预测值与实际值之间差异的函数。在监督学习中&#xff0c;代价函数是评估模型性能的关键工具&#xff0c;它可以帮助我们了解模型在训练数据上的表现&#xff0c;并通过优化过程…

数据结构——排序大汇总(建议收藏)

这篇文章将为大家详细讲解各大排序的基本思想与实现代码~ 内有动图 首先&#xff0c;我们来看常见的排序有以下几大类&#xff1a; 1.插入排序 插入排序的主要思想是将每个位置的元素插入到前面已具备顺序的数组中 实际中我们玩扑克牌时&#xff0c;就用了插入排序的思想 …

快手可灵视频生成大模型全方位测评

快手视频生成大模型“可灵”&#xff08;Kling&#xff09;&#xff0c;是全球首个真正用户可用的视频生成大模型&#xff0c;自面世以来&#xff0c;凭借其无与伦比的视频生成效果&#xff0c;在全球范围内赢得了用户的热烈追捧与高度评价。截至目前&#xff0c;申请体验其内测…

人工智能:大语言模型提示注入攻击安全风险分析报告下载

大语言模型提示注入攻击安全风险分析报告下载 今天分享的是人工智能AI研究报告&#xff1a;《大语言模型提示注入攻击安全风险分析报告》。&#xff08;报告出品方&#xff1a;大数据协同安全技术国家工程研究中心安全大脑国家新一代人工智能开放创新平台&#xff09; 研究报告…

stats 监控 macOS 系统

Stats 监控 macOS 系统 CPU 利用率GPU 利用率内存使用情况磁盘利用率网络使用情况电池电量 brew install stats参考 stats github

59、mysql存储过程

存储过程 一、存储过程&#xff1a; 1.1、存储过程的概念 概念&#xff1a;完成特定功能的sql语句的集合。把定义好的sql集合在一个特定的sql的函数当中 每次执行调用函数即可。还可以实现传参的调用。 1.2、存储过程的语法&#xff1a; delimiter $$ ##delimiter开始和结…

支持4K高分辨率,PixArt-Sigma最新文生图落地经验

PixArt-Sigma是由华为诺亚方舟实验室、大连理工大学和香港大学的研究人员共同开发的一个先进的文本到图像&#xff08;Text-to-Image&#xff0c;T2I&#xff09;生成模型。 PixArt-Sigma是在PixArt-alpha的基础上进一步改进的模型&#xff0c;旨在生成高质量的4K分辨率图像。…

2024牛客暑期多校第四场

A-LCT 带权并查集&#xff0c;维护一下每个点在当前树的深度和以它为根能找到的最深的深度。‘ #include<bits/stdc.h>using namespace std; typedef long long ll; const int N 1e6 100;int fa[N],ans[N],val[N];int find(int x){if(fa[x]x)return x;int tfa[x];fa[x…

C++初学(3)

面向对象编程&#xff08;OOP&#xff09;的本质是设计并拓展自己的数据类型&#xff0c;设计自己的数据类型就是让类型与数据匹配。内置的C类型分为两组&#xff1a;基本类型和复合类型。这里我们将介绍基本类型的整数和浮点数 3.1、简单变量 3.1.1、变量名 C必须遵循几种简…

场外期权如何报价?名义本金是什么?

今天带你了解场外期权如何报价&#xff1f;名义本金是什么&#xff1f;投资者首先需要挑选自己想要进行期权交易的沪深上市公司股票。选出股票后&#xff0c;需要将股票信息、预期的操作时间&#xff08;如期限&#xff09;、看涨或看跌的选择以及预计的交易金额等信息报给场外…