PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类)

文章目录

  • model.py
  • main.py
  • 参数设置
  • 运行图

model.py

import torch.nn as nn
from torch_geometric.nn import GATConv
import torch.nn.functional as F
class gat_cls(nn.Module):def __init__(self,in_dim,hid_dim,out_dim,dropout_size=0.5):super(gat_cls,self).__init__()self.conv1 = GATConv(in_dim,hid_dim)self.conv2 = GATConv(hid_dim,hid_dim)self.fc = nn.Linear(hid_dim,out_dim)self.relu  = nn.ReLU()self.dropout_size = dropout_sizedef forward(self,x,edge_index):x = self.conv1(x,edge_index)x = F.dropout(x,p=self.dropout_size,training=self.training)x = self.relu(x)x = self.conv2(x,edge_index)x = self.relu(x)x = self.fc(x)return x

main.py

import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from model import gat_cls
import torch.optim as optim
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset[0])
cora_data = dataset[0]epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7net = gat_cls(cora_data.x.shape[1],hidden_dim,output_dim)
optimizer = optim.AdamW(net.parameters(),lr=lr,weight_decay=weight_decay)
#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=momentum)
criterion = nn.CrossEntropyLoss()
print("****************Begin Training****************")
net.train()
for epoch in range(epochs):out = net(cora_data.x,cora_data.edge_index)optimizer.zero_grad()loss_train = criterion(out[cora_data.train_mask],cora_data.y[cora_data.train_mask])loss_val   = criterion(out[cora_data.val_mask],cora_data.y[cora_data.val_mask])loss_train.backward()print('epoch',epoch+1,'loss-train {:.2f}'.format(loss_train),'loss-val {:.2f}'.format(loss_val))optimizer.step()net.eval()
out = net(cora_data.x,cora_data.edge_index)
loss_test = criterion(out[cora_data.test_mask],cora_data.y[cora_data.test_mask])
_,pred = torch.max(out,dim=1)
pred_label = pred[cora_data.test_mask]
true_label = cora_data.y[cora_data.test_mask]
acc = sum(pred_label==true_label)/len(pred_label)
print("****************Begin Testing****************")
print('loss-test {:.2f}'.format(loss_test),'acc {:.2f}'.format(acc))

参数设置

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

运行图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/82627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java学习--day6(数组)

文章目录 day5作业今天的内容1.数组1.1开发中为啥要有数组1.2在Java中如何定义数组1.3对第二种声明方式进行赋值1.4对数组进行取值1.5二维数组【了解】1.6数组可以当成一个方法的参数【重点】1.7数组可以当成一个方法的返回值1.8数组在内存中如何分配的【了解】 2.数组方法循环…

PFEA111–20 PFEA111–20 人工智能如何颠覆石油和天然气行业

PFEA111–20 PFEA111–20 人工智能如何颠覆石油和天然气行业 人工智能(AI)和机器学习(ML)等新技术的到来正在改变几十年来行业的运营方式。这些技术正在带来革命性的变革,影响着整个行业。石油和天然气行业在其运营过程中面临着许多挑战,如未连接的环境…

在SpringBoot项目中整合SpringSession,基于Redis实现对Session的管理和事件监听

1、SpringSession简介 SpringSession是基于Spring框架的Session管理解决方案。它基于标准的Servlet容器API,提供了Session的分布式管理解决方案,支持把Session存储在多种场景下,比如内存、MongoDB、Redis等,并且能够快速集成到Spr…

AI 标注终结人工标注,效率高 100 倍,成本为 14%

AI 标注终结人工标注,效率高 100 倍,成本为 14% 稀缺性如何解决数据标注廉价的新方法数据标注自动化AnthropicAutolabel 安装 (Python)AI 标注终结人工标注,效率高 100 倍,成本为 14% 稀缺性 大模型满天飞的时代,AI行业最缺的是什么 算力(显卡)高质量的数据OpenAI 正…

Java21 LTS版本

一、前言 除了众所周知的 JEP 之外,Java 21 还有更多内容。首先请确认 java 版本: $ java -version openjdk version "21" 2023-09-19 OpenJDK Runtime Environment (build 2135-2513) OpenJDK 64-Bit Server VM (build 2135-2513, mixed mo…

activiti7的数据表和字段的解释

activiti7的数据表和字段的解释 activiti7版本有25张表,而activiti6有28张表,activiti5有27张表,绝大部分的表和字段的含义都是一样的,所以本次整理的activiti7数据表和字段的解释,也同样适用于activiti6和5。 1、总览…

pcl--第五节 点云表面法线估算

估算点云表面法线 * 表面法线是几何表面的重要属性,在许多领域(例如计算机图形应用程序)中大量使用,以应用正确的光源以产生阴影和其他视觉效果。 给定一个几何表面,通常很难将表面某个点的法线方向推断为垂直于该点…

python execute() 使用%s 拼接sql 避免sql注入攻击 好于.format

1 execute(参数一:sql 语句) # 锁定当前查询结果行 cursor.execute("SELECT high, low, vol FROM table_name WHERE symbol %s FOR UPDATE;"% (symbol,)) 2 .format() cursor.execute("SELECT high, low, vol FROM table_name WHERE symbol {} FOR UPDATE;…

Pytorch实现LSTM预测模型并使用C++相应的ONNX模型推理

Pytorch实现RNN模型 代码 import torch import torch.nn as nnclass LSTM(nn.Module):def __init__(self, input_size, output_size, out_channels, num_layers, device):super(LSTM, self).__init__()self.device deviceself.input_size input_sizeself.hidden_size inpu…

CockroachDB集群部署

CockroachDB集群部署 1、CockroachDB简介 CockroachDB(有时简称为CRDB)是一个免费的、开源的分布式 SQL 数据库,它建立在一个事务性和强一致性的键 值存储之上。它由 PebbleDB(一个受 RocksDB/leveldb 启发的 K/B 存储库)支持,并使用 Raft 分布式共识…

TypeScript入门

目录 一:语言特性 二:TypeScript安装 NPM 安装 TypeScript 三:TypeScript基础语法 第一个 TypeScript 程序 四:TypeScript 保留关键字 空白和换行 TypeScript 区分大小写 TypeScript 注释 TypeScript 支持两种类型的注释 …

初识C语言——详细入门一(系统性学习day4)

目录 前言 一、C语言简单介绍、特点、基本构成 简单介绍: 特点: 基本构成: 二、认识C语言程序 标准格式: 简单C程序: 三、基本构成分类详细介绍 (1)关键字 (2&#xf…

fork函数

二.fork函数 2.1函数原型 fork()函数在 C 语言中的原型如下&#xff1a; #include <unistd.h>pid_t fork(void);其中pid_t是一个整型数据类型&#xff0c;用于表示进程ID。fork()函数返回值是一个pid_t类型的值&#xff0c;具体含义如下&#xff1a; 如果调用fork()的…

MyBatis中当实体类中的属性名和表中的字段名不一样,怎么办

方法1&#xff1a; 在mybatis核心配置文件中指定&#xff0c;springboot加载mybatis核心配置文件 springboot项目的一个特点就是0配置&#xff0c;本来就省掉了mybatis的核心配置文件&#xff0c;现在又加回去算什么事&#xff0c;总之这种方式可行但没人这样用 具体操作&…

MFC C++ 数据结构及相互转化 CString char * char[] byte PCSTR DWORE unsigned

CString&#xff1a; char * char [] BYTE BYTE [] unsigned char DWORD CHAR&#xff1a;单字节字符8bit WCHAR为Unicode字符:typedef unsigned short wchar_t TCHAR : 如果当前编译方式为ANSI(默认)方式&#xff0c;TCHAR等价于CHAR&#xff0c;如果为Unicode方式&#xff0c…

Python灰帽编程——错误异常处理与面向对象

文章目录 错误异常处理与面向对象1. 错误和异常1.1 基本概念1.1.1 Python 异常 1.2 检测&#xff08;捕获&#xff09;异常1.2.1 try except 语句1.2.2 捕获多种异常1.2.3 捕获所有异常 1.3 处理异常1.4 特殊场景1.4.1 with 语句 1.5 脚本完善 2. 内网主机存活检测程序2.1 scap…

Git从入门到起飞(详细)

Git从入门到起飞 Git从入门到起飞什么是Git&#xff1f;使用git前提(注册git)下载Git在Windows上安装Git在macOS上安装Git在Linux上安装Git 配置Git配置全局用户信息配置文本编辑器 创建第一个Git仓库初始化仓库拉取代码添加文件到仓库提交更改推送 Git基本操作查看提交历史比较…

【Java 基础篇】Java字符打印流详解:文本数据的输出利器

在Java编程中&#xff0c;我们经常需要将数据输出到文件或其他输出源中。Java提供了多种输出流来帮助我们完成这项任务&#xff0c;其中字符打印流是一个非常有用的工具。本文将详细介绍Java字符打印流的用法&#xff0c;以及如何在实际编程中充分利用它。 什么是字符打印流&a…

矩阵 m * M = c

文章目录 题1题2 题1 (2023江苏领航杯-prng) 题目来源&#xff1a;https://dexterjie.github.io/2023/09/12/%E8%B5%9B%E9%A2%98%E5%A4%8D%E7%8E%B0/2023%E9%A2%86%E8%88%AA%E6%9D%AF/ 题目描述&#xff1a; (没有原数据&#xff0c;自己生成的数据) from Crypto.Util.number…

DNG格式详解,DNG是什么?为何DNG可以取代RAW统一单反相机、苹果安卓移动端相机拍摄输出原始图像数据标准

返回图像处理总目录&#xff1a;《JavaCV图像处理合集总目录》 前言 在DNG格式发布之前&#xff0c;我们先了解一下之前单反相机、苹果和安卓移动端相机拍照输出未经处理的原始图像格式是什么&#xff1f; RAW 什么是RAW&#xff1f; RAW是未经处理、也未经压缩的格式。可以…