使用dgl库实现GCN【官方案例】

学习目的

学习使用gnn进行节点分类的基本工作流程,即预测图中节点的类别。

关于GCN节点分类的综述

在图数据上最流行和广泛采用的任务之一是节点分类,其中模型需要预测每个节点的真实类别。
在图神经网络之前,许多被提出的方法要么单独使用连通性(如DeepWalk或node2vec),要么简单地结合连通性和节点自身的特征。相比之下,gnn通过结合局部邻域的连通性和特征提供了获得节点表示的机会。Kipf等人将节点分类问题表述为半监督节点分类任务的例子。图神经网络(GNN)仅借助一小部分标记节点就能准确预测其他节点的类别。笔记将展示如何在Cora数据集上仅使用少量标签构建半监督节点分类的GNN,一个以论文为节点,引用为边的引文网络。任务是预测给定论文的类别。每个论文节点包含一个单词计数向量作为其特征,标准化后它们的总和为1。

导包

import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv

os.environ[“DGLBACKEND”] = "pytorch"该行代码十分重要,因为它指定了DGL所运行的后端环境。

加载Cora数据集

dataset = dgl.data.CoraGraphDataset()
print(f"Number of categories: {dataset.num_classes}")

一个DGL数据集对象可以包含一个或多个图。本教程中使用的Cora数据集只包含一个图。

查看节点和边

"""
一个DGL数据集对象可以包含一个或多个图。本教程中使用的Cora数据集只包含一个图。
"""
g = dataset[0]print("Node features")
print(g.ndata)
print("Edge features")
print(g.edata)

定义图卷积网络(GCN)

# 定义图卷积网络(GCN)
class GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return h

创建给定维度的模型

model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)

模型训练

使用DGL库训练网络类似于神经网络的训练,对于神经网络的训练步骤可以参考本博客:【转载】Pytorch模型实现的四部曲

# 模型训练
def train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]val_mask = g.ndata["val_mask"]test_mask = g.ndata["test_mask"]for epoch in range(100):# 前向传播logits = model(g, features)# 计算预测值pred = logits.argmax(1)# 计算损失# 只计算训练集中节点的损失loss = F.cross_entropy(logits[train_mask], labels[train_mask])# Compute accuracy on training/validation/testtrain_acc = (pred[train_mask] == labels[train_mask]).float().mean()val_acc = (pred[val_mask] == labels[val_mask]).float().mean()test_acc = (pred[test_mask] == labels[test_mask]).float().mean()# Save the best validation accuracy and the corresponding test accuracy.if best_val_acc < val_acc:best_val_acc = val_accbest_test_acc = test_acc# Backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print(f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})")

在cuda上训练GCN

g = g.to('cuda') #在cuda上训练GCN
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

省流版:完整代码

"""学习使用gnn进行节点分类的基本工作流程,即预测图中节点的类别。"""
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
"""GNN节点分类综述
在图数据上最流行和广泛采用的任务之一是节点分类,其中模型需要预测每个节点的真实类别。
在图神经网络之前,许多被提出的方法要么单独使用连通性(如DeepWalk或node2vec),
要么简单地结合连通性和节点自身的特征。相比之下,gnn通过结合局部邻域的连通性和特征提供了获得节点表示的机会。
Kipf等人将节点分类问题表述为半监督节点分类任务的例子。图神经网络(GNN)
仅借助一小部分标记节点就能准确预测其他节点的类别。本教程将展示如何在Cora数据集上仅使用少量标签构建半监督节点分类的GNN,一个以论文为节点,
引用为边的引文网络。任务是预测给定论文的类别。每个论文节点包含一个单词计数向量作为其特征,
标准化后它们的总和为1
"""# 加载Cora数据集¶
dataset = dgl.data.CoraGraphDataset()
print(f"Number of categories: {dataset.num_classes}")"""
一个DGL数据集对象可以包含一个或多个图。本教程中使用的Cora数据集只包含一个图。
"""
g = dataset[0]print("Node features")
print(g.ndata)
print("Edge features")
print(g.edata)# 定义图卷积网络(GCN)
class GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return h# 创建具有给定维度的模型
model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)# 模型训练
def train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]val_mask = g.ndata["val_mask"]test_mask = g.ndata["test_mask"]for epoch in range(100):# 前向传播logits = model(g, features)# 计算预测值pred = logits.argmax(1)# 计算损失# 只计算训练集中节点的损失loss = F.cross_entropy(logits[train_mask], labels[train_mask])# Compute accuracy on training/validation/testtrain_acc = (pred[train_mask] == labels[train_mask]).float().mean()val_acc = (pred[val_mask] == labels[val_mask]).float().mean()test_acc = (pred[test_mask] == labels[test_mask]).float().mean()# Save the best validation accuracy and the corresponding test accuracy.if best_val_acc < val_acc:best_val_acc = val_accbest_test_acc = test_acc# Backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print(f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})")
g = g.to('cuda') #在cuda上训练GCN
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/141386.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spark3.0中的AOE、DPP和Hint增强

1 Spark3.0 AQE Spark 在 3.0 版本推出了 AQE&#xff08;Adaptive Query Execution&#xff09;&#xff0c;即自适应查询执行。AQE 是 Spark SQL 的一种动态优化机制&#xff0c;在运行时&#xff0c;每当 Shuffle Map 阶段执行完毕&#xff0c;AQE 都会结合这个阶段的统计信…

【KVM-7】KVM管理工具

前言 大家好&#xff0c;我是秋意零。 &#x1f47f; 简介 &#x1f3e0; 个人主页&#xff1a; 秋意零&#x1f525; 账号&#xff1a;全平台同名&#xff0c; 秋意零 账号创作者、 云社区 创建者&#x1f9d1; 个人介绍&#xff1a;在校期间参与众多云计算相关比赛&#x…

微信小程序广告banner、滚动屏怎么做?

使用滑块视图容器swiper和swiper-item可以制作滚动屏&#xff0c;代码如下&#xff1a; wxml: <swiper indicator-dots indicator-color"rgba(255,255,255,0.5)" indicator-active-color"white" autoplay interval"3000"><swiper-ite…

【人工智能Ⅰ】6-机器学习之分类

【人工智能Ⅰ】6-机器学习之分类 6-1 机器学习在人工智能中的地位 学习能力是智能的本质 人工智能 > 机器学习 > 深度学习 什么是机器学习&#xff1f; baidu&#xff1a;多领域交叉学科&#xff08;做什么&#xff09; wiki&#xff1a;the study of algorithms and…

OpenCV踩坑笔记使用笔记入门笔记整合SpringBoot笔记大全

springboot开启摄像头抓拍照片并上传实现&问题记录 NotAllowedErrot: 请求的媒体源不能使用&#xff0c;以下情况会返回该错误: 当前页面内容不安全&#xff0c;没有使用HTTPS没有通过用户授权NotFoundError: 没有找到指定的媒体通道NoReadableError: 访问硬件设备出错Ov…

计算机基础知识50

数据的增删改查(insert update delete select) # 用户列表的展示&#xff1a; # 把数据表中得用户数据都给查询出来展示在页面上 1. 查询 from app01 import models models.UserInfo.objects.all() # 查询所有的字段信息和数据 resmodels.UserInfo.objects.first() # 查询…

什么是状态机?

什么是状态机&#xff1f; 定义 我们先来给出状态机的基本定义。一句话&#xff1a; 状态机是有限状态自动机的简称&#xff0c;是现实事物运行规则抽象而成的一个数学模型。 先来解释什么是“状态”&#xff08; State &#xff09;。现实事物是有不同状态的&#xff0c;例…

Spark SQL 每年的1月1日算当年的第一个自然周, 给出日期,计算是本年的第几周

一、问题 按每年的1月1日算当年的第一个自然周 (遇到跨年也不管&#xff0c;如果1月1日是周三&#xff0c;那么到1月5号&#xff08;周日&#xff09;算是本年的第一个自然周, 如果按周一是一周的第一天) 计算是本年的第几周&#xff0c;那么 spark sql 如何写 ? 二、分析 …

P6入门:项目初始化9-项目详情之资源 Resource

前言 使用项目详细信息查看和编辑有关所选项目的详细信息&#xff0c;在项目创建完成后&#xff0c;初始化项目是一项非常重要的工作&#xff0c;涉及需要设置的内容包括项目名&#xff0c;ID,责任人&#xff0c;日历&#xff0c;预算&#xff0c;资金&#xff0c;分类码等等&…

npm install导致的OOM解决方案

文章目录 问题记录解决方法Linux重启排查方法 如何排查Linux自动重启的原因 问题记录 我在华为云服务器配置npm开发环境的时候&#xff0c; SSH远程连接一直掉线&#xff0c;无奈提了工单&#xff0c;被告知是NPM install导致的OOM问题。无语了&#xff0c;破NPM还有这个问题呢…

SOME/IP学习笔记2

1. SOME/IP 协议 SOME/IP目前支持UDP&#xff08;用户传输协议&#xff09;和TCP&#xff08;传输控制协议&#xff09;&#xff0c; PS:UDP和TCP区别如下 TCP面向连接的&#xff0c;可靠的数据传输服务&#xff1b;UDP面向无连接的&#xff0c;尽最大努力的数据传输服务&…

详细推导MOSFET的跨导、小信号模型、输出阻抗、本征增益

目录 前言 什么是跨导 什么是小信号模型 什么是输入阻抗和输出阻抗 什么是MOS管的输出阻抗 什么是MOS管的本征增益 共源极放大电路的输入和输出阻抗 一些其它MOS拓扑电路的增益 负载为恒流源 负载为二极管 前言 相信很多人在学习集成电路领域的时候 都对MOS管的…

Python 框架学习 Django篇 (十) Redis 缓存

开发服务器系统的时候&#xff0c;程序的性能是至关重要的。经过我们前面框架的学习&#xff0c;得知一个请求的处理基本分为接受http请求、数据库处理、返回json数据&#xff0c;而这3个部分中就属链接数据库请求的响应速度最慢&#xff0c;因为数据库操作涉及到数据库服务处理…

怎么在uni-app中使用Vuex(第一篇)

Vuex简介 vuex的官方网址如下 https://vuex.vuejs.org/zh/ 阅读官网请带着几个问题去阅读&#xff1a; vuex用于什么场景&#xff1f;vuex能给我们带来什么好处&#xff1f;我们为什么要用vuex?vuex如何实现状态集中管理&#xff1f; Vuex用于哪些场景&#xff1f; 组件之…

[量化投资-学习笔记012]Python+TDengine从零开始搭建量化分析平台-策略回测

上一章节《MACD金死叉策略回测》中&#xff0c;对平安银行这只股票&#xff0c;按照金死叉策略进行了回测。 但通常我们的股票池中有许多股票&#xff0c;每完成一个交易策略都需要对整个股票池进行回测。 下面使用简单的轮询&#xff0c;对整个股票池进行回测。 # 计算单只…

动态规划-构建乘积数组

** 描述 给定一个数组 A[0,1,…,n-1] ,请构建一个数组 B[0,1,…,n-1] ,其中 B 的元素 B[i]A[0]A[1]…*A[i-1]A[i1]…*A[n-1]&#xff08;除 A[i] 以外的全部元素的的乘积&#xff09;。程序中不能使用除法。&#xff08;注意&#xff1a;规定 B[0] A[1] * A[2] * … * A[n-1…

.Net 6 Nacos日志控制台疯狂发输出+Log4Net日志过滤

我们的项目配置了Log4Net 作为日志输出工具&#xff0c;在引入Nacos后&#xff0c;控制台和日志里疯狂输出nacos心跳日志和其他相关信息&#xff0c;导致自己记录的信息被淹没了&#xff0c;找了很多解决办法&#xff1a; 1、提高nacos日志级别&#xff0c;然后再屏蔽相应级别…

RK3568平台开发系列讲解(Linux系统篇)Linux内核定时器详解

🚀返回专栏总目录 文章目录 一、系统节拍率二、内核定时器简介三、内核定时器API四、延时函数沉淀、分享、成长,让自己和他人都能有所收获!😄 📢 Linux 内核中有大量的函数需要时间管理,比如周期性的调度程序、延时程序、对于我们驱动编写者来说最常用的定时器。硬件定…

win11下安装odoo17(conda python11)

win11下安装odoo17 odoo17发行了&#xff0c;据说&#xff0c;UI做了很大改进&#xff0c;今天有空&#xff0c;体验一下 打开官方仓库&#xff1a; https://github.com/odoo/odoo 默认的版本已经变成17了 打开odoo/odoo/init.py&#xff0c;发现对python版本的要求也提高了…

GCN代码讲解

这里写的有点抽象&#xff0c;所以具体的可以参照下面代码块中的注释&#xff1a; def load_data(path"../data/cora/", dataset"cora"):"""Load citation network dataset (cora only for now)"""print(Loading {} datase…