PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类)

文章目录

  • model.py
  • main.py
  • 参数设置
  • 运行图

model.py

import torch.nn as nn
from torch_geometric.nn import GATConv
import torch.nn.functional as F
class gat_cls(nn.Module):def __init__(self,in_dim,hid_dim,out_dim,dropout_size=0.5):super(gat_cls,self).__init__()self.conv1 = GATConv(in_dim,hid_dim)self.conv2 = GATConv(hid_dim,hid_dim)self.fc = nn.Linear(hid_dim,out_dim)self.relu  = nn.ReLU()self.dropout_size = dropout_sizedef forward(self,x,edge_index):x = self.conv1(x,edge_index)x = F.dropout(x,p=self.dropout_size,training=self.training)x = self.relu(x)x = self.conv2(x,edge_index)x = self.relu(x)x = self.fc(x)return x

main.py

import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from model import gat_cls
import torch.optim as optim
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset[0])
cora_data = dataset[0]epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7net = gat_cls(cora_data.x.shape[1],hidden_dim,output_dim)
optimizer = optim.AdamW(net.parameters(),lr=lr,weight_decay=weight_decay)
#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=momentum)
criterion = nn.CrossEntropyLoss()
print("****************Begin Training****************")
net.train()
for epoch in range(epochs):out = net(cora_data.x,cora_data.edge_index)optimizer.zero_grad()loss_train = criterion(out[cora_data.train_mask],cora_data.y[cora_data.train_mask])loss_val   = criterion(out[cora_data.val_mask],cora_data.y[cora_data.val_mask])loss_train.backward()print('epoch',epoch+1,'loss-train {:.2f}'.format(loss_train),'loss-val {:.2f}'.format(loss_val))optimizer.step()net.eval()
out = net(cora_data.x,cora_data.edge_index)
loss_test = criterion(out[cora_data.test_mask],cora_data.y[cora_data.test_mask])
_,pred = torch.max(out,dim=1)
pred_label = pred[cora_data.test_mask]
true_label = cora_data.y[cora_data.test_mask]
acc = sum(pred_label==true_label)/len(pred_label)
print("****************Begin Testing****************")
print('loss-test {:.2f}'.format(loss_test),'acc {:.2f}'.format(acc))

参数设置

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

运行图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/82627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java学习--day6(数组)

文章目录 day5作业今天的内容1.数组1.1开发中为啥要有数组1.2在Java中如何定义数组1.3对第二种声明方式进行赋值1.4对数组进行取值1.5二维数组【了解】1.6数组可以当成一个方法的参数【重点】1.7数组可以当成一个方法的返回值1.8数组在内存中如何分配的【了解】 2.数组方法循环…

PFEA111–20 PFEA111–20 人工智能如何颠覆石油和天然气行业

PFEA111–20 PFEA111–20 人工智能如何颠覆石油和天然气行业 人工智能(AI)和机器学习(ML)等新技术的到来正在改变几十年来行业的运营方式。这些技术正在带来革命性的变革,影响着整个行业。石油和天然气行业在其运营过程中面临着许多挑战,如未连接的环境…

在SpringBoot项目中整合SpringSession,基于Redis实现对Session的管理和事件监听

1、SpringSession简介 SpringSession是基于Spring框架的Session管理解决方案。它基于标准的Servlet容器API,提供了Session的分布式管理解决方案,支持把Session存储在多种场景下,比如内存、MongoDB、Redis等,并且能够快速集成到Spr…

Java21 LTS版本

一、前言 除了众所周知的 JEP 之外,Java 21 还有更多内容。首先请确认 java 版本: $ java -version openjdk version "21" 2023-09-19 OpenJDK Runtime Environment (build 2135-2513) OpenJDK 64-Bit Server VM (build 2135-2513, mixed mo…

activiti7的数据表和字段的解释

activiti7的数据表和字段的解释 activiti7版本有25张表,而activiti6有28张表,activiti5有27张表,绝大部分的表和字段的含义都是一样的,所以本次整理的activiti7数据表和字段的解释,也同样适用于activiti6和5。 1、总览…

pcl--第五节 点云表面法线估算

估算点云表面法线 * 表面法线是几何表面的重要属性,在许多领域(例如计算机图形应用程序)中大量使用,以应用正确的光源以产生阴影和其他视觉效果。 给定一个几何表面,通常很难将表面某个点的法线方向推断为垂直于该点…

CockroachDB集群部署

CockroachDB集群部署 1、CockroachDB简介 CockroachDB(有时简称为CRDB)是一个免费的、开源的分布式 SQL 数据库,它建立在一个事务性和强一致性的键 值存储之上。它由 PebbleDB(一个受 RocksDB/leveldb 启发的 K/B 存储库)支持,并使用 Raft 分布式共识…

TypeScript入门

目录 一:语言特性 二:TypeScript安装 NPM 安装 TypeScript 三:TypeScript基础语法 第一个 TypeScript 程序 四:TypeScript 保留关键字 空白和换行 TypeScript 区分大小写 TypeScript 注释 TypeScript 支持两种类型的注释 …

初识C语言——详细入门一(系统性学习day4)

目录 前言 一、C语言简单介绍、特点、基本构成 简单介绍: 特点: 基本构成: 二、认识C语言程序 标准格式: 简单C程序: 三、基本构成分类详细介绍 (1)关键字 (2&#xf…

fork函数

二.fork函数 2.1函数原型 fork()函数在 C 语言中的原型如下&#xff1a; #include <unistd.h>pid_t fork(void);其中pid_t是一个整型数据类型&#xff0c;用于表示进程ID。fork()函数返回值是一个pid_t类型的值&#xff0c;具体含义如下&#xff1a; 如果调用fork()的…

MyBatis中当实体类中的属性名和表中的字段名不一样,怎么办

方法1&#xff1a; 在mybatis核心配置文件中指定&#xff0c;springboot加载mybatis核心配置文件 springboot项目的一个特点就是0配置&#xff0c;本来就省掉了mybatis的核心配置文件&#xff0c;现在又加回去算什么事&#xff0c;总之这种方式可行但没人这样用 具体操作&…

Python灰帽编程——错误异常处理与面向对象

文章目录 错误异常处理与面向对象1. 错误和异常1.1 基本概念1.1.1 Python 异常 1.2 检测&#xff08;捕获&#xff09;异常1.2.1 try except 语句1.2.2 捕获多种异常1.2.3 捕获所有异常 1.3 处理异常1.4 特殊场景1.4.1 with 语句 1.5 脚本完善 2. 内网主机存活检测程序2.1 scap…

【Java 基础篇】Java字符打印流详解:文本数据的输出利器

在Java编程中&#xff0c;我们经常需要将数据输出到文件或其他输出源中。Java提供了多种输出流来帮助我们完成这项任务&#xff0c;其中字符打印流是一个非常有用的工具。本文将详细介绍Java字符打印流的用法&#xff0c;以及如何在实际编程中充分利用它。 什么是字符打印流&a…

DNG格式详解,DNG是什么?为何DNG可以取代RAW统一单反相机、苹果安卓移动端相机拍摄输出原始图像数据标准

返回图像处理总目录&#xff1a;《JavaCV图像处理合集总目录》 前言 在DNG格式发布之前&#xff0c;我们先了解一下之前单反相机、苹果和安卓移动端相机拍照输出未经处理的原始图像格式是什么&#xff1f; RAW 什么是RAW&#xff1f; RAW是未经处理、也未经压缩的格式。可以…

Rust通用编程概念(3)

Rust通用编程概念 1.变量和可变性1.执行cargo run2.变量3.变量的可变性4.常量5.遮蔽5.1遮蔽与mut区别1.遮蔽2.mut 2.数据类型1.标量类型1.1整数类型1.2浮点数类型1.3数字运算1.4布尔类型1.5字符类型 2.复合类型2.1元组类型2.2数组类型1.访问数组2.无效的数组元素访问 3.函数3.1…

如何解决 503 Service Temporarily Unavailable?

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f405;&#x1f43e;猫头虎建议程序员必备技术栈一览表&#x1f4d6;&#xff1a; &#x1f6e0;️ 全栈技术 Full Stack: &#x1f4da…

想要精通算法和SQL的成长之路 - 填充书架

想要精通算法和SQL的成长之路 - 填充书架 前言一. 填充书架1.1 优化 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 填充书架 原题链接 题目中有一个值得注意的点就是&#xff1a; 需要按照书本顺序摆放。每一层当中&#xff0c;只要厚度不够了&#xff0c;当前层最高…

【考研数学】高等数学第六模块 —— 空间解析几何(1,向量基本概念与运算)

文章目录 引言一、空间解析几何的理论1.1 基本概念1.2 向量的运算 写在最后 引言 我自认空间想象能力较差&#xff0c;所以当初学这个很吃力。希望现在再接触&#xff0c;能好点。 一、空间解析几何的理论 1.1 基本概念 1.向量 —— 既有大小&#xff0c;又有方向的量称为向…

C语言指针,深度长文全面讲解

指针对于C来说太重要。然而&#xff0c;想要全面理解指针&#xff0c;除了要对C语言有熟练的掌握外&#xff0c;还要有计算机硬件以及操作系统等方方面面的基本知识。所以本文尽可能的通过一篇文章完全讲解指针。 为什么需要指针&#xff1f; 指针解决了一些编程中基本的问题。…

spring aop源码解析

spring知识回顾 spring的两个重要功能&#xff1a;IOC、AOP&#xff0c;在ioc容器的初始化过程中&#xff0c;会触发2种处理器的调用&#xff0c; 前置处理器(BeanFactoryPostProcessor)后置处理器(BeanPostProcessor)。 前置处理器的调用时机是在容器基本创建完成时&#xff…