Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战

1.GAT简介

GAT(Graph Attention Network)模型是一种用于图数据的深度学习模型,由Veličković等人在2018年提出。它通过自适应地在图中计算节点之间的注意力来学习节点之间的关系,并在节点表示中捕捉全局和局部信息。

GAT模型的核心思想是通过注意力机制,对图中的节点进行加权聚合。与传统的图卷积网络(GCN)模型不同,GAT不仅考虑节点本身的特征信息,还考虑了节点与其邻居节点之间的关系。每个节点在聚合邻居节点的特征时,会分配不同的注意力权重,以捕捉不同邻居节点对该节点的贡献程度。

GAT模型具有以下特点和优势:

  1. 自适应学习的注意力机制:GAT模型能够根据数据自动学习节点之间的注意力权重,从而捕捉到不同节点之间的重要性和关系。
  2. 并行计算效率高:由于注意力权重是节点间独立计算的,可以高效地并行计算,适用于大规模图数据。
  3. 稀疏性:GAT模型引入了注意力系数,可以将注意力集中在有用的邻居节点上,减小计算量和存储需求。
  4. 灵活性:GAT模型可以根据任务需求设计不同的注意力权重计算方式,适应不同的图学习任务。

2.代码实战

模型架构分为两部分:GAT主体部分,GAT的注意力计算部分

注意力机制:首先输入参数为(节点的特征表示hi,邻接矩阵),注意这个hi可以来源于上一层,也可以是原始的;先计算每个节点到中心节点的权值,也可以称为权重或者系数,然后对所有的权值进行归一化,最后对每个邻居节点与对应的权值相乘,然后相加就得到了中心节点的最终表示,注意求权值的时候是要考虑中心节点本身的;

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fclass GATLayer(nn.Module):def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GATLayer, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.dropout = dropoutself.alpha = alphaself.concat = concatself.W = nn.Linear(in_features, out_features)self.a = nn.Linear(2*out_features, 1)def forward(self, h, adj):Wh = self.W(h)  # W*hN = h.size()[0]  # Number of nodesa_input = torch.cat([Wh.repeat(1, N).view(N*N, -1), Wh.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)e = F.leaky_relu(self.a(a_input).squeeze(2), negative_slope=self.alpha)zero_vec = -9e15*torch.ones_like(e)attention = torch.where(adj > 0, e, zero_vec)attention = F.softmax(attention, dim=1)attention = F.dropout(attention, p=self.dropout, training=self.training)h_prime = torch.matmul(attention, Wh)if self.concat:return F.elu(h_prime)else:return h_primeclass GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):super(GAT, self).__init__()self.dropout = dropoutself.hidden = nn.ModuleList([GATLayer(nfeat, nhid, dropout, alpha, concat=True) for _ in range(nheads)])self.out_att = GATLayer(nhid*nheads, nclass, dropout, alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.hidden], dim=1)x = F.dropout(x, self.dropout, training=self.training)x = F.sigmoid(self.out_att(x, adj))return F.log_softmax(x, dim=1)# 创建示例数据和邻接矩阵
adj = torch.tensor([[0, 1, 1, 0],[1, 0, 1, 1],[1, 1, 0, 1],[0, 1, 1, 0]])  # 邻接矩阵
features = torch.randn(4, 5)  # 特征矩阵# 创建GAT模型
model = GAT(nfeat=5, nhid=8, nclass=2, dropout=0.6, alpha=0.2, nheads=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型
for epoch in range(100):optimizer.zero_grad()output = model(features, adj)# 假设这里有标签数据yy = torch.LongTensor([0, 1, 0, 1])  # 标签loss = criterion(output, y)loss.backward()optimizer.step()# 测试模型
output = model(features, adj)
_, predictions = output.max(dim=1)
correct = (predictions == y).sum().item()
accuracy = correct / len(y)
print("准确率:", accuracy)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/21462.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI文章互评:得分最高的竟然不是GPT-4!

大家好,我是木易,一个持续关注AI领域的互联网技术产品经理,国内Top2本科,美国Top10 CS研究生,MBA。我坚信AI是普通人变强的“外挂”,所以创建了“AI信息Gap”这个公众号,专注于分享AI全维度知识…

Rspack + vue 修改css代码不能热更新

环境 rspack/cli0.6.5rspack/core0.6.5vue3.4.27 原因 使用了 rspack.config 中的 experiments.css:true(默认) export defualt {//...module: {rules:[{test: /\.css/,type: css},{test: /\.less/,use:[less-loader],type: css},]},//... } 自带的…

Dubbo的路由策略剖析

1 概述 Dubbo中的路由策略的作用是服务消费端使用路由策略对服务提供者列表进行过滤和选择,最终获取符合路由规则的服务提供者。 Dubbo中的路由策略主要分为两类,StateRouter和普通Router。 StateRouter (如TagStateRouter、ConditionStat…

实力!云起无垠晋级“第九届安全创客汇”年度10强

2024年5月28日,第九届“安全创客汇”复赛在重庆圆满落幕。在本次国内最具影响力的网络安全创业大赛中,云起无垠凭借其技术的创新性和巨大市场价值,成功跻身年度十强。 随着人工智能技术的不断发展,特别是在大模型技术的推动下&…

【图像处理与机器视觉】XJTU期末考点

题型 选择:1 分10 填空:1 分15 简答题(也含有计算和画图):10 分*4 计算题:15 分20 分 考点 选择题(部分) 数字图像处理基础 p(x,y),q(s,t)两个像素之间的距离由公式&#xff1a…

湖南(品牌调研)源点咨询 企业品牌调研侧重点分析

本文由湖南长沙(市场调研)源点咨询编辑发布 企业建立品牌,往往都需进行科学性的品牌调研。因为只有这样,才能让企业更好的把握市场的发展趋势,进而为品牌的建立和发展提供更有价值的数据参考!那么品牌的调…

江淮集团分享:江淮集团数据管理实践

下文为江淮集团信息化管理部副部长丁志海的演讲全文: 大家下午好。我是来自江淮汽车的丁志海,我做IT、做信息化做这一块有二十多年了。这次得帆邀请我来讲讲数据管理的实践经验。我就想说一说我的感受,为什么我们当初选择得帆,和一…

微信小程序计算器

微信小程序计算器 index.wxml <view classscreen>{{result}}</view><view classtest-bg><view classbtnGroup><view classitem grey bindtapclickButton id"{{C}}">AC</view><view classitem grey bindtapclickButton id&q…

AI精选付费资料包【37GB】

课程介绍 一、人工智能论文合集 二、AI必读经典书籍 三、超详细人工智能学习大纲 四、机器学习基础算法教程 五、深度学习神经网络基础教程 六、计算机视觉实战项目 课程获取 资料&#xff1a;AI精选付费资料包&#xff08;37.4GB&#xff09;获取&#xff1a;扫码关注公z号…

esp8266阿里云上线(小程序控制)

此wechatproject已上传在页面最上方 由图可见&#xff0c;项目只有两个页面&#xff0c;一个是获取该产品下的设备信息列表&#xff0c;一个是某设备对应的详情控制页面&#xff0c;由于这个项目只利用esp8266板子上自带的led&#xff0c;功能简单&#xff0c;只需要控制开关即…

leetcode 575.分糖果

思路&#xff1a;开两个数组&#xff0c;一个用来存储非负数的糖果个数&#xff0c;一个用来存储负数的糖果个数&#xff0c;这两个数组都是状态数组&#xff0c;而不是计数数组 如果当前能够吃的种类大于现有的种类&#xff0c;现有的种类个数就是答案&#xff1b; 如果当前…

Update! 基于RockyLinux9.3离线安装Zabbix6.0

链接&#xff1a; Ansible离线部署 之 Zabbixhttp://mp.weixin.qq.com/s?__bizMzk0NTQ3OTk3MQ&mid2247487434&idx1&sn3128800a0219c5ebc5a3f89d2c8ccf50&chksmc3158786f4620e90afe440bb32fe68541191cebbabc2d2ef196f7300e84cde1e1b57383c521a&scene21#we…

YOLOv9改进策略 | Conv篇 | 利用YOLOv10提出的SCDown魔改YOLOv9进行下采样(附代码 + 结构图 + 添加教程)

一、本文介绍 本文给大家带来的改进机制是利用YOLOv10提出的SCDown魔改YOLOv9进行下采样,其是更高效的下采样。具体而言,其首先利用点卷积调整通道维度,然后利用深度卷积进行空间下采样。这将计算成本减少到和参数数量减少到。同时,这最大限度地保留了下采样过程中的信息,…

创新指南|提高人才回报率的重要举措和指标

员工是组织最大的投资&#xff0c;也是最深层的价值源泉。人才系统必须同时强调生产力和价值创造。让合适的人才担任合适的职位&#xff0c;并为员工提供成功所需的支持和机会&#xff0c;这是实现回报的关键。本文将介绍组织可以采取的五项行动&#xff0c;以最大化企业的人才…

postgresql常用命令#postgresql认证

PostgreSQL 是一个功能强大的开源关系数据库管理系统&#xff0c;提供了一系列命令行工具来管理和操作数据库。以下是一些常用的 PostgreSQL 命令&#xff0c;涵盖数据库和用户管理、数据操作以及查询和维护等方面。 #PostgreSQL培训 #postgresql认证 #postgreSQL考试 #PG考试…

汽车识别项目

窗口设计 这里的代码放在py文件最前面或者最后面都无所谓 # 创建主窗口 window tk.Tk() window.title("图像目标检测系统") window.geometry(1000x650) # 设置窗口大小# 创建背景画布并使用grid布局管理器 canvas_background tk.Canvas(window, width1000, height…

【Hive SQL 每日一题】统计各个商品今年销售额与去年销售额的增长率及排名变化

文章目录 测试数据需求说明需求实现分步解析 测试数据 -- 创建商品表 DROP TABLE IF EXISTS products; CREATE TABLE products (product_id INT,product_name STRING );INSERT INTO products VALUES (1, Product A), (2, Product B), (3, Product C), (4, Product D), (5, Pro…

英码科技推出鸿蒙边缘计算盒子:提升国产化水平,增强AI应用效能,保障数据安全

当前&#xff0c;随着国产化替代趋势的加强&#xff0c;鸿蒙系统Harmony OS也日趋成熟和完善&#xff0c;各行各业都在积极拥抱鸿蒙&#xff1b;那么&#xff0c;边缘计算要加快实现全面国产化&#xff0c;基于鸿蒙系统开发AI应用势在必行。 关于鸿蒙系统及其优势 鸿蒙系统是华…

Linux 问题定位查看日志文件常用命令

Linux 问题定位查看日志文件常用命令 查看日志文件的前100行中是否包含关键词&#xff1a; head -n 100 /var/log/file.log | grep "keyword"查看日志文件的最后100行中是否包含关键词&#xff1a; tail -n 100 /var/log/file.log | grep "keyword"使用l…

ROS2从入门到精通4-3:全局路径规划插件开发案例(以A*算法为例)

目录 0 专栏介绍1 路径规划插件的意义2 全局规划插件编写模板2.1 构造规划插件类2.2 注册并导出插件2.3 编译与使用插件 3 全局规划插件开发案例(A*算法)常见问题 0 专栏介绍 本专栏旨在通过对ROS2的系统学习&#xff0c;掌握ROS2底层基本分布式原理&#xff0c;并具有机器人建…