图神经网络教程之GAT(pyG)

图神经网络-pyG-GAT

在上一章节介绍了pyG-GCN的使用,除了GCN,还有一些像GAT、GraphSage等等一些,本文将介绍GAT模型的构建

实现了一个使用Graph Attention Network(GAT)的节点分类模型,该模型在Cora数据集上进行训练和测试。

  1. 首先,导入所需的库和模块:

    • torch_geometric.datasets.Planetoid:用于加载Cora数据集。
    • torch:PyTorch的主要库。
    • torch.nn.functional as F:PyTorch的神经网络函数模块,用于定义神经网络的层和操作。
    • torch_geometric.nn.GATConv:PyTorch Geometric库中的图注意力网络层(Graph Attention Network,GATConv)。
    • torch_geometric.nn.GATConv:PyTorch Geometric库中的图注意力网络层(Graph Attention Network,GATConv)。
  2. 加载Cora数据集:

    dataset = Planetoid(root='./tmp/Cora', name='Cora')
    

    这行代码加载了Cora数据集,该数据集包括节点特征、图的边缘信息以及节点的真实标签。

  3. 定义一个名为GAT_Net的神经网络类:

    class GAT_Net(torch.nn.Module):
    

    这个类继承自PyTorch的torch.nn.Module基类,表示它是一个神经网络模型。

  4. GAT_Net类的构造函数中,定义了两个GAT层:

    def __init__(self, features, hidden, classes, heads=1):super(GAT_Net, self).__init__()self.gat1 = GATConv(features, hidden, heads=heads)self.gat2 = GATConv(hidden * heads, classes)
    
    • GATConv层是图注意力网络层,用于从图数据中提取特征。
    • self.gat1是第一个GATConv层,它将输入特征的维度设置为features,输出hidden维特征,同时可以指定heads的数量。
    • self.gat2是第二个GATConv层,将hidden * heads维特征映射到classes个类别。
  5. forward方法中定义了前向传播过程:

    def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.gat1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.gat2(x, edge_index)return F.log_softmax(x, dim=1)
    
    • 输入数据data包括节点特征x和边索引edge_index
    • self.gat1self.gat2分别表示第一层和第二层的图注意力网络操作。
    • 使用ReLU激活函数进行非线性变换。
    • 使用Dropout层进行正则化。
    • 最后,通过F.log_softmax对输出进行softmax操作,以得到每个节点属于不同类别的概率分布。
  6. 检查并设置GPU或CPU设备:

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    

    这段代码会检查你的系统是否有可用的GPU,并将device设置为GPU或CPU,以便在相应的设备上运行模型。

  7. 创建并将模型移动到所选设备上:

    model = GAT_Net(dataset.num_node_features, 16, dataset.num_classes, heads=4).to(device)
    

    这将实例化GAT_Net模型,并将模型的参数和计算移动到GPU或CPU上。heads参数指定了GAT中的注意力头数量。

  8. 加载Cora数据集的第一个图数据实例:

    data = dataset[0]
    

    这将加载Cora数据集的第一个图数据实例,包括节点特征、图的边缘信息以及节点的真实标签。

  9. 定义优化器(这里使用Adam优化器):

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    

    这行代码创建一个Adam优化器,并将模型的参数传递给它,用于模型参数的更新。学习率为0.01。

  10. 将模型设置为训练模式:

    model.train()
    

    这行代码将模型切换到训练模式,以启用训练时的特定操作,如Dropout。

  11. 开始训练循环,训练模型200个epoch:

    for epoch in range(200):
    

    这是一个训练循环,将模型训练200次。

  12. 在每个epoch中,首先将优化器的梯度清零:

    optimizer.zero_grad()
    

    这行代码用于清除之前的梯度信息,以准备计算新的梯度。

  13. 通过模型前向传播计算预测结果:

    out = model(data)
    

    这会将数据传递给你的GAT模型,然后返回模型的预测结果。

  14. 计算损失函数,这里使用负对数似然损失(Negative Log-Likelihood Loss):

    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    

    这行代码计算了在训练节点子集上的负对数似然损失。data.train_mask指定了用于训练的节点子集,data.y是节点的真实标签。

  15. 反向传播和参数更新:

    loss.backward()
    optimizer.step()
    

    这两行代码用于计算梯度并执行梯度下降,更新模型的参数,以最小化损失函数。

from torch_geometric.datasets import Planetoid
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConvdataset = Planetoid(root='./tmp/Cora',name='Cora')
class GAT_Net(torch.nn.Module):def __init__(self, features, hidden, classes, heads=1):super(GAT_Net, self).__init__()self.gat1 = GATConv(features, hidden, heads=heads)self.gat2 = GATConv(hidden*heads, classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.gat1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.gat2(x, edge_index)return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT_Net(dataset.num_node_features, 16, dataset.num_classes, heads=4).to(device)
data = dataset[0]
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)model.train()
for epoch in range(200):optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()model.eval()
_, pred = model(data).max(dim=1)
correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum()
acc = int(correct)/ int(data.test_mask.sum())
print('GAT',acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/64120.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《Kali渗透基础》15. WEB 渗透

kali渗透 1:WEB 技术1.1:WEB 攻击面1.2:HTTP 协议基础1.3:AJAX1.4:WEB Service 2:扫描工具2.1:HTTrack2.2:Nikto2.3:Skipfish2.4:Arachni2.5:OWAS…

前端面试必备 | uni-app 篇(P1-15)

文章目录 1. 请简述一下uni-app的定义和特点。2. uni-app兼容哪些前端框架?请列举几个。3. 请简述一下uni-app的跨平台工作原理。4. 什么是条件编译?在uni-app中如何实现条件编译?5. uni-app中的页面生命周期有哪些?请简要介绍。6…

UG\NX CAM二次开发 插入工序 UF_OPER_create

文章作者:代工 来源网站:NX CAM二次开发专栏 简介: UG\NX CAM二次开发 插入工序 UF_OPER_create 效果: 代码: void MyClass::do_it() {tag_t setup_tag=NULL_TAG;UF_SETUP_ask_setup(&setup_tag);if (setup_tag==NULL_TAG){uc1601("请先初始化加工环境…

flinkcdc同步完全量数据就不同步增量数据了

flinkcdc同步完全量数据就不同步增量数据了 使用flinkcdc同步mysql数据,使用的是全量采集模型 startupOptions(StartupOptions.earliest()) 全量阶段同步完成之后,发现并不开始同步增量数据,原因有以下两个: 原因1: …

WPF数据视图

将集合绑定到ItemsControl控件时,会不加通告的在后台创建数据视图——位于数据源和绑定的控件之间。数据视图是进入数据源的窗口,可以跟踪当前项,并且支持各种功能,如排序、过滤、分组。 这些功能和数据对象本身是相互独立的&…

【计算机网络】OSI 七层网络参考模型

OSI(Open Systems Interconnection)七层网络参考模型是一种用于描述计算机网络通信的框架,将网络通信划分为七个不同的层次,每个层次负责不同的功能。 以下为 OSI 七层网络参考模型的简单表格: --------------------…

maven的依赖下载不下来的几种解决方法

前言 每次部署测试环境,从代码库拉取代码,都会出现缺少包的情况。然后找开发一通调试,到处拷包。 方案一:pom文件注释/取消注释 注释掉pom.xml里的报红色的依赖(同时可以把本地maven库repo里对应的包删除)&…

一款不能错过的Git客户端:Fork for Mac,让你的代码管理更便捷

Fork for Mac是一款强大的Git客户端,让用户在Mac电脑上更方便地进行版本控制和代码管理。它具有以下特点: 易用性:Fork for Mac界面简洁明了,操作简单易懂,让用户可以快速上手。功能强大:支持各种Git功能&…

Mac软件删除方法?如何删除不会有残留

Mac电脑如果有太多无用的应用程序,很有可能会拖垮Mac系统的运行速度。因此,卸载电脑中无用的软件是优化Mac系统运行速度的最佳方式之一。Mac卸载应用程序的方式是和Windows有很大的区别,特别对于Mac新用户来说,如何无残留的卸载删…

Java8异步类CompletableFuture详解

1、前言 学习java基础时候多线程使用我们首先学习的 Runable 、Future 、 Thread 、ExecutorService、Callable等相关类,在我们日常工作或者学习中有些场景并不满足我们需求,JDK8引入了一个新的类 CompletableFuture 来解决之前得问题, Comp…

【Latex】使用技能站:(三)使用 Vscode 配置 LaTeX

使用 Vscode 配置 LaTeX 引言1 安装texlive2 安装vscode2.1 插件安装2.2 配置 3 安装SumatraPdf3.1 vscode配置3.2 配置反向搜索 引言 安装texlive 安装vscode 安装SumatraPdf 1 安装texlive 在线LaTeX编辑器:https://www.overleaf.com TeX Live下载:h…

使用 v-for 指令和数组来实现在 Uni-app 中动态增减表单项并渲染多个数据

在 data 中定义一个数组&#xff0c;用于存储表单项的数据&#xff1a; data() {return {formItems: []} } 在模板中使用 v-for 指令渲染表单项&#xff1a; <template><div><div v-for"(item, index) in formItems" :key"index"><…

【LeetCode】《LeetCode 101》第十二章:字符串

文章目录 12.1 字符串比较242 . 有效的字母异位词&#xff08;简单&#xff09;205. 同构字符串&#xff08;简单&#xff09;647. 回文子串&#xff08;中等&#xff09;696 . 计数二进制子串&#xff08;简单&#xff09; 12.2 字符串理解224. 基本计算器&#xff08;困难&am…

mysql之存储引擎

目录 存储引擎概念 MyISAM MyISAM特点 MyISAM 表的存储格式 MyISAM适用的生产场景 InnoDB InnoDB特点 选择存储引擎依据 MyISAM 和 INNODB区别 命令 查看系统支持的存储引擎 查看表使用的存储引擎 修改存储引擎 存储引擎概念 MySQL中的数据用各种不同的技术存…

通过python 获取当前局域网内存在的IP和MAC

通过python 获取当前局域网内存在的ip 通过ipconfig /all 命令获取局域网所在的网段 通过arp -d *命令清空当前所有的arp映射表 循环遍历当前网段所有可能的ip与其ping一遍建立arp映射表 for /L %i IN (1,1,254) DO ping -w 1 -n 1 192.168.3.%i 通过arp -a命令读取缓存的映射表…

Java的23种设计模式

Java的23种设计模式 一、创建型设计模式1.单例模式 singleton1.1.静态属性单例模式1.2 静态属性变种1.3 基础的懒汉模式1.4 线程安全的懒加载单例1.5 线程安全的懒加载 单例-改进1.6 双重检查锁1.7 静态内部类1.8 枚举单例1.9 注册表单例 2.工厂方法模式 factory3.抽象工厂模式…

通过chatgpt 学习React的useEffect

定义&#xff1a; useEffect 是 React 中的一个 Hook&#xff0c;它用于处理函数组件中的副作用操作。副作用操作可以包括数据获取、订阅事件、定时器等。 useEffect 接受两个参数&#xff1a;第一个参数是一个回调函数&#xff0c;用于执行副作用操作&#xff1b;第二个参数…

IE浏览器攻击:MS11-003_IE_CSS_IMPORT

目录 概述 利用过程 漏洞复现 概述 MS11-003_IE_CSS_IMPORT是指Microsoft Security Bulletin MS11-003中的一个安全漏洞&#xff0c;影响Internet Explorer&#xff08;IE&#xff09;浏览器。这个漏洞允许攻击者通过在CSS文件中使用import规则来加载外部CSS文件&#xff0…

SMU学习

SMU学习 1.参考资料 1.参考资料 TC3xx-SMU_EMS分析 英飞凌基础学习笔记&#xff08;SMU&#xff09;Safety Management Unit 为什么需要外部看门狗&#xff1f; ISO 26262 - Software Level of Functional Safety 简要概括下infeneon 芯片的SMU模块 ChatGPT 英飞凌&#xff0…

计算一个区间时间差值,时间保留剩下的差值

解决目的 begin end&#xff0c;去除集合类的其他区间差值List<rang> r1 new ArrayList(); 得到差值package com.jowoiot.wmzs.utils.date;import com.google.common.collect.Lists; import com.google.common.collect.Range; import org.apache.commons.lang.time.Dat…