GCN,GraphSAGE 到底在训练什么呢?

根据DGL 来做的,按照DGL 实现来讲述

1. GCN Cora 训练代码:

import osos.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConvclass GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hdef train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]val_mask = g.ndata["val_mask"]test_mask = g.ndata["test_mask"]for e in range(100):# Forwardlogits = model(g, features)# Compute predictionpred = logits.argmax(1)# Compute loss# Note that you should only compute the losses of the nodes in the training set.loss = F.cross_entropy(logits[train_mask], labels[train_mask])# Compute accuracy on training/validation/testtrain_acc = (pred[train_mask] == labels[train_mask]).float().mean()val_acc = (pred[val_mask] == labels[val_mask]).float().mean()test_acc = (pred[test_mask] == labels[test_mask]).float().mean()# Save the best validation accuracy and the corresponding test accuracy.if best_val_acc < val_acc:best_val_acc = val_accbest_test_acc = test_acc# Backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print(f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})")if __name__ == "__main__" :dataset = dgl.data.CoraGraphDataset()# print(f"Number of categories: {dataset.num_classes}")g = dataset[0]g = g.to('cuda')model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to('cuda')train(g, model)

一些基础python torch.tensor语法概述:

1.  

if __name__ == "__main__" :XXXXXXXXXXXXXX

当我们直接执行这个脚本时,__name__属性被设置为__main__,因此满足if条件,语句块中的代码被调用。
但如果我们将该脚本作为模块导入到另一个脚本中,则__name__属性会被设置为模块的名称(例如"example"),语句块中的代码不会被执行。

2. 

# Compute prediction
pred = logits.argmax(1)    # 返回沿着第一个维度(即维度索引为1)的最大值的索引。# 即,加入有5个样本,每个样本有3个维度的评分,那么就会给出没个样本3中维度评分最高的哪个维度的索引序号

 

3. numpy 关于 tensor 的一个用法:

在DGL 中使用一串 True 或 False 组成的 一维tensor 来标识 这个节点到底是属于 train test val 哪一类

train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

而后,由于对于torch中的tensor来说:

就可以:select_label_tensor = labels[train_mask] 了

import torch# 定义一个Tensor
tensor = torch.tensor([1, 2, 3, 4, 5])# 定义一个布尔数组,选择索引为1和4的元素
mask = torch.tensor([False, True, False, False, True])# 通过布尔索引选择元素
selected_tensor = tensor[mask]print(selected_tensor)  # tensor([2, 5])

顺便,查看一个变量到底是什么类型可以使用 type() 函数:

train_mask = g.ndata["train_mask"]
print(type(train_mask))# 输出为:
# <class 'torch.Tensor'>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/197053.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux系统安装Python3环境

1、默认情况下&#xff0c;Linux会自带安装Python&#xff0c;可以运行python --version命令查看&#xff0c;如图&#xff1a; 我们看到Linux中已经自带了Python2.7.5。再次运行python命令后就可以使用python命令窗口了&#xff08;CtrlD退出python命令窗口&#xff09;。 2…

数据结构——二叉树(相关术语、性质、遍历过程)

遍历操作 二叉树的层次遍历-CSDN博客 二叉树的基本操作-CSDN博客 二叉树的先序遍历非递归实现-CSDN博客 后序遍历的非递归方式实现-CSDN博客 二叉树&#xff1a;已知先序中序求后序或者其他&#xff08;秒解&#xff09;-CSDN博客 因为之前发过一遍&#xff0c;我就不复制…

MES物料的动态批次管理漫谈

在制造企业中&#xff0c;原辅材料占产品制造总成本基本在60%以上&#xff0c;特殊材料加工企业可能达到80%以上&#xff0c;按“2/8管理原则”管理好物料就基本做好制造企业的成本管理&#xff0c;这也许是很多企业向“数字化转型”的一个主要原因&#xff0c;希望借助数字信息…

智能指针与动态内存

动态内存 new placement new 是 C 中的一种内存分配方式&#xff0c;它允许在给定的内存地址上构造对象&#xff0c;而不是在默认的堆上分配新的内存。这对于某些特殊的内存管理场景非常有用&#xff0c;例如在特定的内存池中分配对象。 C11 引入了 "new auto" 语法…

LiveGBS流媒体平台GB/T28181功能-概览中负载信息直播、回放、播放、录像、H265、级联查看负载会话列表

LiveGBS常见问题-概览中负载信息具体表示什么直播、回放、播放、录像、H265、级联等 1、负载信息2、负载信息说明3、会话列表查看3.1、会话列表 4、搭建GB28181视频直播平台 1、负载信息 实时展示直播、回放、播放、录像、H265、级联等使用数目 2、负载信息说明 直播&#x…

4.grid_sample理解与使用

pytorch中的grid_sample 文章目录 pytorch中的grid_samplegrid_samplegrid_sample函数原型实例 欢迎访问个人网络日志&#x1f339;&#x1f339;知行空间&#x1f339;&#x1f339; grid_sample 直译为网格采样&#xff0c;给定一个mask patch&#xff0c;根据在目标图像上的…

Http和WebSocket

客户端发送一次http请求&#xff0c;服务器返回一次http响应。 问题&#xff1a;如何在客户端没有发送请求的情况下&#xff0c;返回服务端的响应&#xff0c;网页可以得服务器数据&#xff1f; 1&#xff1a;http定时轮询 客户端定时发送http请求&#xff0c;eg&#…

2023经典软件测试面试题

1、问&#xff1a;你在测试中发现了一个bug&#xff0c;但是开发经理认为这不是一个bug&#xff0c;你应该怎样解决&#xff1f; 首先&#xff0c;将问题提交到缺陷管理库里面进行备案。 然后&#xff0c;要获取判断的依据和标准&#xff1a; 根据需求说明书、产品说明、设计…

AI浪潮下,非科班出身还有机会入行程序开发领域么?

前言 随着人工智能技术的快速发展和广泛应用&#xff0c;程序开发领域正迎来前所未有的挑战和机遇。但是对于非科班出身的个人而言&#xff0c;是否还有机会进入这个充满竞争的行业&#xff0c;成为一名程序员&#xff1f;那么本文就来聊聊AI浪潮下&#xff0c;分析当前程序员就…

JAVA泛型概念的理解

什么是泛型&#xff1f; 泛型是JAVA语言中一种增强类型安全性的机制&#xff0c;它允许程序员在类&#xff0c;接口和方法中使用类型参数&#xff0c;以便在编译时进行类型检查&#xff0c;并在运行时生成正确的代码。泛型的主要目的是提高代码的可重用性和可读性&#xff0c;…

整数和浮点数在内存中的存储

文章目录 每日一言整数在内存中的存储方式浮点数在内存中的存储结语 每日一言 You just can’t beat the person who never gives up. 你无法打败那位永不放弃的人。 整数在内存中的存储方式 整数在内存中的存储方式通常采用二进制形式&#xff0c;即将整数的数值转化为二进制…

zxjy002- 后端项目环境搭建

1、前后端分离的概念 1.1 前端开发思想 主要绑定数据显示在页面 1.2 后端开发思想 返回数据或者操作数据&#xff0c;开发controller、service、mapper过程 2、Java后端项目环境搭建 2.1 数据库及表 创建数据库名&#xff1a;guli_edu创建数据库表&#xff1a; 执行 gul…

vue3自定义指令-文本超出宽度滚动

fontScroll.ts 指令文件 import { Directive } from vuefunction randomInt(min, max) {return Math.floor(Math.random() * (max - min 1)) min; } export default {// 可控制滚动速度&#xff0c;默认滚动速度20px/s,最低动画时长2smounted: (el, binding, vNode): void &…

daterangepicker的使用

daterangepicker 是基于JQuery的时间日期选择插件&#xff0c;是一个配合 bootstrap 框架使用的时间范围选择 js 组件 引入 js 和 css <link href"bootstrap.min.css"> <link href"daterangepicker.css"> <script src"jquer…

ubuntu16.04升级openssl

Ubuntu16.04 默认带的openssl版本为1.0.2 查看&#xff1a;openssl version 1.下载openssl wget https://www.openssl.org/source/openssl-1.1.1.tar.gz 编译安装 tar xvf openssl-1.1.1.tar.gz cd openssl-1.1.1 ./config make sudo make install sudo ldconfig 删除旧版本 su…

XXL-Job详解(五):动态添加、启动任务

目录 前言XXL-Job API接口添加任务API动态添加任务动态启动任务 前言 看该文章之前&#xff0c;最好看一下之前的文章&#xff0c;比较方便我们理解 XXL-Job详解&#xff08;一&#xff09;&#xff1a;组件架构 XXL-Job详解&#xff08;二&#xff09;&#xff1a;安装部署 X…

沐风老师3DMAX随机变换工具RandomTransform插件使用方法详解

3DMAX随机变换工具RandomTransform插件使用方法 3dMax随机变换工具RandomTransform&#xff0c;是一款用MAXScript脚本语言开发的3dsMax小工具&#xff0c;可以随机变换选中的单个或多个对象的位置、角度及大小。 在3dMax中“变换”工具是最常用的工具&#xff08;移动、旋转和…

vue3+ts项目中导入组件时报错has no default export

下面这句会报错has no default export import Button from "./components/Button.vue";使用vetur这个插件&#xff08;我目前的版本是0.37.3&#xff0c;应该是这个版本之前的都不支持&#xff09;。但是依旧报错&#xff0c;所以我选择禁用了&#xff0c;就不报错了…

Redis的基本数据类型及常用命令

Redis通用命令 KEYS命令用于查找所有匹配给定模式 pattern 的 key 。生产环境下不建议使用KEYS命令&#xff0c;会影响效率。 匹配规则&#xff1a; h?llo 匹配 hello, hallo 和 hxlloh*llo 匹配 hllo 和 heeeelloh[ae]llo 匹配 hello and hallo, 不匹配 hilloh[^e]llo 匹配 …

selenium自动化测试实战案例

Chrome DevTools 简介 Chrome DevTools 是一组直接内置在基于 Chromium 的浏览器&#xff08;如 Chrome、Opera 和 Microsoft Edge&#xff09;中的工具&#xff0c;用于帮助开发人员调试和研究网站。 借助 Chrome DevTools&#xff0c;开发人员可以更深入地访问网站&#xf…