GCN,GraphSAGE 到底在训练什么呢?

根据DGL 来做的,按照DGL 实现来讲述

1. GCN Cora 训练代码:

import osos.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConvclass GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hdef train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]val_mask = g.ndata["val_mask"]test_mask = g.ndata["test_mask"]for e in range(100):# Forwardlogits = model(g, features)# Compute predictionpred = logits.argmax(1)# Compute loss# Note that you should only compute the losses of the nodes in the training set.loss = F.cross_entropy(logits[train_mask], labels[train_mask])# Compute accuracy on training/validation/testtrain_acc = (pred[train_mask] == labels[train_mask]).float().mean()val_acc = (pred[val_mask] == labels[val_mask]).float().mean()test_acc = (pred[test_mask] == labels[test_mask]).float().mean()# Save the best validation accuracy and the corresponding test accuracy.if best_val_acc < val_acc:best_val_acc = val_accbest_test_acc = test_acc# Backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print(f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})")if __name__ == "__main__" :dataset = dgl.data.CoraGraphDataset()# print(f"Number of categories: {dataset.num_classes}")g = dataset[0]g = g.to('cuda')model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to('cuda')train(g, model)

一些基础python torch.tensor语法概述:

1.  

if __name__ == "__main__" :XXXXXXXXXXXXXX

当我们直接执行这个脚本时,__name__属性被设置为__main__,因此满足if条件,语句块中的代码被调用。
但如果我们将该脚本作为模块导入到另一个脚本中,则__name__属性会被设置为模块的名称(例如"example"),语句块中的代码不会被执行。

2. 

# Compute prediction
pred = logits.argmax(1)    # 返回沿着第一个维度(即维度索引为1)的最大值的索引。# 即,加入有5个样本,每个样本有3个维度的评分,那么就会给出没个样本3中维度评分最高的哪个维度的索引序号

 

3. numpy 关于 tensor 的一个用法:

在DGL 中使用一串 True 或 False 组成的 一维tensor 来标识 这个节点到底是属于 train test val 哪一类

train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

而后,由于对于torch中的tensor来说:

就可以:select_label_tensor = labels[train_mask] 了

import torch# 定义一个Tensor
tensor = torch.tensor([1, 2, 3, 4, 5])# 定义一个布尔数组,选择索引为1和4的元素
mask = torch.tensor([False, True, False, False, True])# 通过布尔索引选择元素
selected_tensor = tensor[mask]print(selected_tensor)  # tensor([2, 5])

顺便,查看一个变量到底是什么类型可以使用 type() 函数:

train_mask = g.ndata["train_mask"]
print(type(train_mask))# 输出为:
# <class 'torch.Tensor'>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/197053.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux系统安装Python3环境

1、默认情况下&#xff0c;Linux会自带安装Python&#xff0c;可以运行python --version命令查看&#xff0c;如图&#xff1a; 我们看到Linux中已经自带了Python2.7.5。再次运行python命令后就可以使用python命令窗口了&#xff08;CtrlD退出python命令窗口&#xff09;。 2…

数据结构——二叉树(相关术语、性质、遍历过程)

遍历操作 二叉树的层次遍历-CSDN博客 二叉树的基本操作-CSDN博客 二叉树的先序遍历非递归实现-CSDN博客 后序遍历的非递归方式实现-CSDN博客 二叉树&#xff1a;已知先序中序求后序或者其他&#xff08;秒解&#xff09;-CSDN博客 因为之前发过一遍&#xff0c;我就不复制…

MES物料的动态批次管理漫谈

在制造企业中&#xff0c;原辅材料占产品制造总成本基本在60%以上&#xff0c;特殊材料加工企业可能达到80%以上&#xff0c;按“2/8管理原则”管理好物料就基本做好制造企业的成本管理&#xff0c;这也许是很多企业向“数字化转型”的一个主要原因&#xff0c;希望借助数字信息…

智能指针与动态内存

动态内存 new placement new 是 C 中的一种内存分配方式&#xff0c;它允许在给定的内存地址上构造对象&#xff0c;而不是在默认的堆上分配新的内存。这对于某些特殊的内存管理场景非常有用&#xff0c;例如在特定的内存池中分配对象。 C11 引入了 "new auto" 语法…

LiveGBS流媒体平台GB/T28181功能-概览中负载信息直播、回放、播放、录像、H265、级联查看负载会话列表

LiveGBS常见问题-概览中负载信息具体表示什么直播、回放、播放、录像、H265、级联等 1、负载信息2、负载信息说明3、会话列表查看3.1、会话列表 4、搭建GB28181视频直播平台 1、负载信息 实时展示直播、回放、播放、录像、H265、级联等使用数目 2、负载信息说明 直播&#x…

4.grid_sample理解与使用

pytorch中的grid_sample 文章目录 pytorch中的grid_samplegrid_samplegrid_sample函数原型实例 欢迎访问个人网络日志&#x1f339;&#x1f339;知行空间&#x1f339;&#x1f339; grid_sample 直译为网格采样&#xff0c;给定一个mask patch&#xff0c;根据在目标图像上的…

Http和WebSocket

客户端发送一次http请求&#xff0c;服务器返回一次http响应。 问题&#xff1a;如何在客户端没有发送请求的情况下&#xff0c;返回服务端的响应&#xff0c;网页可以得服务器数据&#xff1f; 1&#xff1a;http定时轮询 客户端定时发送http请求&#xff0c;eg&#…

2023经典软件测试面试题

1、问&#xff1a;你在测试中发现了一个bug&#xff0c;但是开发经理认为这不是一个bug&#xff0c;你应该怎样解决&#xff1f; 首先&#xff0c;将问题提交到缺陷管理库里面进行备案。 然后&#xff0c;要获取判断的依据和标准&#xff1a; 根据需求说明书、产品说明、设计…

AI浪潮下,非科班出身还有机会入行程序开发领域么?

前言 随着人工智能技术的快速发展和广泛应用&#xff0c;程序开发领域正迎来前所未有的挑战和机遇。但是对于非科班出身的个人而言&#xff0c;是否还有机会进入这个充满竞争的行业&#xff0c;成为一名程序员&#xff1f;那么本文就来聊聊AI浪潮下&#xff0c;分析当前程序员就…

整数和浮点数在内存中的存储

文章目录 每日一言整数在内存中的存储方式浮点数在内存中的存储结语 每日一言 You just can’t beat the person who never gives up. 你无法打败那位永不放弃的人。 整数在内存中的存储方式 整数在内存中的存储方式通常采用二进制形式&#xff0c;即将整数的数值转化为二进制…

ubuntu16.04升级openssl

Ubuntu16.04 默认带的openssl版本为1.0.2 查看&#xff1a;openssl version 1.下载openssl wget https://www.openssl.org/source/openssl-1.1.1.tar.gz 编译安装 tar xvf openssl-1.1.1.tar.gz cd openssl-1.1.1 ./config make sudo make install sudo ldconfig 删除旧版本 su…

XXL-Job详解(五):动态添加、启动任务

目录 前言XXL-Job API接口添加任务API动态添加任务动态启动任务 前言 看该文章之前&#xff0c;最好看一下之前的文章&#xff0c;比较方便我们理解 XXL-Job详解&#xff08;一&#xff09;&#xff1a;组件架构 XXL-Job详解&#xff08;二&#xff09;&#xff1a;安装部署 X…

沐风老师3DMAX随机变换工具RandomTransform插件使用方法详解

3DMAX随机变换工具RandomTransform插件使用方法 3dMax随机变换工具RandomTransform&#xff0c;是一款用MAXScript脚本语言开发的3dsMax小工具&#xff0c;可以随机变换选中的单个或多个对象的位置、角度及大小。 在3dMax中“变换”工具是最常用的工具&#xff08;移动、旋转和…

vue3+ts项目中导入组件时报错has no default export

下面这句会报错has no default export import Button from "./components/Button.vue";使用vetur这个插件&#xff08;我目前的版本是0.37.3&#xff0c;应该是这个版本之前的都不支持&#xff09;。但是依旧报错&#xff0c;所以我选择禁用了&#xff0c;就不报错了…

selenium自动化测试实战案例

Chrome DevTools 简介 Chrome DevTools 是一组直接内置在基于 Chromium 的浏览器&#xff08;如 Chrome、Opera 和 Microsoft Edge&#xff09;中的工具&#xff0c;用于帮助开发人员调试和研究网站。 借助 Chrome DevTools&#xff0c;开发人员可以更深入地访问网站&#xf…

8.4 Windows驱动开发:文件微过滤驱动入门

MiniFilter 微过滤驱动是相对于SFilter传统过滤驱动而言的&#xff0c;传统文件过滤驱动相对来说较为复杂&#xff0c;且接口不清晰并不符合快速开发的需求&#xff0c;为了解决复杂的开发问题&#xff0c;微过滤驱动就此诞生&#xff0c;微过滤驱动在编写时更简单&#xff0c;…

全网最牛最“刑”的Fiddler移动端抓包

本篇文章&#xff0c;博主想使用通俗易懂的话语&#xff0c;让大家明白以下内容&#xff1a; 什么是抓包哪些场景需要用到抓包Fiddler抓包的原理怎样使用Fiddler进行移动端抓包 抓包 包 (Packet) 是TCP/IP协议通信传输中的数据单位&#xff0c;一般也称“数据包”。 我们平常…

无人机智慧工地:助力工地管理的未来之选

在现代工地管理中&#xff0c;无人机凭借其小巧、轻便和多角度拍摄等特点得到广泛应用&#xff0c;尤其在智慧工地的现场管理中发挥着重要作用。 一、无人机代替人工巡检省时省力 以往&#xff0c;施工现场检查主要依赖人工巡检方式&#xff0c;需要较长时间。而现在&#xff…

链表【2】

文章目录 &#x1f95d;24. 两两交换链表中的节点&#x1f951;题目&#x1f33d;算法原理&#x1f96c;代码实现 &#x1f34e;143. 重排链表&#x1f352;题目&#x1f345;算法原理&#x1f353;代码实现 &#x1f95d;24. 两两交换链表中的节点 &#x1f951;题目 题目链接…

KMP字符串

试题传送门&#xff1a;831. KMP字符串 给定一个字符串 S&#xff0c;以及一个模式串 P&#xff0c;所有字符串中只包含大小写英文字母以及阿拉伯数字。 模式串 P 在字符串 S 中多次作为子串出现。 求出模式串 P 在字符串 S 中所有出现的位置的起始下标。 输入格式 第一行输入…