RootNeighboursDataset(helpers.dataset_classes文件中的root_neighbours_dataset.py)

任务类型:回归
用途:在 `RootNeighboursDataset` 中,任务是给定一棵根树,预测根节点度数为6的邻居的特征平均值。因此,模型需要基于根节点的结构,找到度为6的邻居,并计算其特征的平均值。这属于回归问题,因为目标是预测连续值(特征的平均值)

from helpers.dataset_classes.root_neighbours_dataset import RootNeighboursDataset

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensorclass RootNeighboursDataset(object):def __init__(self, seed: int, print_flag: bool = False):super().__init__()self.seed = seedself.plot_flag = print_flagself.generator = torch.Generator().manual_seed(seed)self.constants_dict = self.initialize_constants()self._data = self.create_data()def get(self) -> Data:return self._datadef create_data(self) -> Data:# train, val, testdata_list = []for num in range(self.constants_dict['NUM_COMPONENTS']):data_list.append(self.generate_component())return Batch.from_data_list(data_list)def mask_task(self, num_nodes_per_fold: List[int]) -> Tuple[Tensor, Tensor, Tensor]:num_nodes = sum(num_nodes_per_fold)train_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)val_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)test_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)train_mask[0] = Trueval_mask[num_nodes_per_fold[0]] = Truetest_mask[num_nodes_per_fold[0] + num_nodes_per_fold[1]] = Truereturn train_mask, val_mask, test_maskdef generate_component(self) -> Data:data_per_fold, num_nodes_per_fold = [], []for fold_idx in range(3):data = self.generate_fold(eval=(fold_idx != 0))num_nodes_per_fold.append(data.x.shape[0])data_per_fold.append(data)train_mask, val_mask, test_mask = self.mask_task(num_nodes_per_fold=num_nodes_per_fold)batch = Batch.from_data_list(data_per_fold)return Data(x=batch.x, edge_index=batch.edge_index, y=batch.y, train_mask=train_mask, val_mask=val_mask,test_mask=test_mask)def initialize_constants(self) -> Dict[str, int]:return {'NUM_COMPONENTS': 1000, 'MAX_HUBS': 3, 'MAX_1HOP_NEIGHBORS': 10, 'ADD_HUBS': 2, 'HUB_NEIGHBORS': 5,'MAX_2HOP_NEIGHBORS': 3, 'NUM_FEATURES': 5}def generate_fold(self, eval: bool) -> Data:constant_dict = self.initialize_constants()MAX_HUBS, MAX_1HOP_NEIGHBORS, ADD_HUBS, HUB_NEIGHBORS, MAX_2HOP_NEIGHBORS, NUM_FEATURES =\[constant_dict[key] for key in ['MAX_HUBS', 'MAX_1HOP_NEIGHBORS', 'ADD_HUBS', 'HUB_NEIGHBORS','MAX_2HOP_NEIGHBORS', 'NUM_FEATURES']]assert MAX_HUBS + ADD_HUBS <= MAX_1HOP_NEIGHBORSadd_hubs = ADD_HUBS if eval else 0num_hubs = torch.randint(1, MAX_HUBS + 1, size=(1,), generator=self.generator).item() + add_hubsnum_1hop_neighbors = torch.randint(MAX_HUBS + add_hubs, MAX_1HOP_NEIGHBORS + 1, size=(1,),generator=self.generator).item()assert num_hubs <= num_1hop_neighborslist_num_2hop_neighbors = torch.randint(1, MAX_2HOP_NEIGHBORS, size=(num_1hop_neighbors - num_hubs,),generator=self.generator).tolist()list_num_2hop_neighbors = [HUB_NEIGHBORS] * num_hubs + list_num_2hop_neighbors# 2 hop edge indexnum_nodes = 1  # root node is 0idx_1hop_neighbors = []list_edge_index = []for num_2hop_neighbors in list_num_2hop_neighbors:idx_1hop_neighbors.append(num_nodes)if num_2hop_neighbors > 0:clique_edge_index = torch.tensor([[0] * num_2hop_neighbors, list(range(1, num_2hop_neighbors + 1))])# clique_edge_index = torch.combinations(torch.arange(num_2hop_neighbors), r=2).Tlist_edge_index.append(clique_edge_index + num_nodes)num_nodes += num_2hop_neighbors + 1# 1 hop edge indexidx_0hop = torch.tensor([0] * num_1hop_neighbors)idx_1hop_neighbors = torch.tensor(idx_1hop_neighbors)hubs = idx_1hop_neighbors[:num_hubs]list_edge_index.append(torch.stack((idx_0hop, idx_1hop_neighbors), dim=0))edge_index = torch.cat(list_edge_index, dim=1)# undirectedge_index_other_direction = torch.stack((edge_index[1], edge_index[0]), dim=0)edge_index = torch.cat((edge_index_other_direction, edge_index), dim=1)# featuresx = 4 * torch.rand(size=(num_nodes, NUM_FEATURES), generator=self.generator) - 2# labelsy = torch.zeros_like(x)y[0] = torch.mean(x[hubs], dim=0)return Data(x=x, edge_index=edge_index, y=y)if __name__ == '__main__':data = RootNeighboursDataset(seed=0, print_flag=True)

这个 RootNeighboursDataset通过随机生成的树状图数据来模拟一种节点关系,并基于图结构生成特征和标签。代码使用了 PyTorchPyTorch Geometric 的功能来处理图数据。下面逐块详细解释该代码实现:

1. RootNeighboursDataset 类构造器

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensorclass RootNeighboursDataset(object):def __init__(self, seed: int, print_flag: bool = False):super().__init__()self.seed &#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/57380.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++ 抛异常

目录 一.抛异常与运行崩溃的区别 1.运行崩溃 2.抛异常 二.抛异常机制存在的意义 1.清晰的处理错误 2.结构化的错误管理 3.跨函数传递错误信息 4.异常对象多态性 三.抛异常的使用方法 1.抛出异常 (throw) 2.捕获异常 (catch) 3.标准异常类 四.抛异常的处理机制 1.抛…

【MySQL备份】Percona XtraBackup

这份文档针对的是最新发布的版本&#xff1a;Percona XtraBackup 2.4.29&#xff08;发布说明&#xff09;。 Percona XtraBackup是一款针对MySQL系列服务器的开源热备份工具&#xff0c;在备份过程中不会锁定您的数据库。它能够对MySQL 5.1、5.5、5.6和5.7服务器以及带有Xtra…

UDP传输协议Linux C语言实战

文章目录 1.UDP简介1.1特点1.2 UDP协议头部格式1.2.1 **UDP头部**&#xff1a;1.2.2 **头部意义**&#xff1a;1.2.3 **头部参数**&#xff1a; 1.3 UDP数据长度控制1.4 UDP协议建立框架 2. 函数介绍2.1 sendto函数2.2 recvform函数2.3 其他函数 3.实例3.1 通用结构体、IPV4结构…

转置卷积的一些理解

转置卷积 当图像输入到卷积网络中&#xff0c;最终生成的特征图的宽高会减小 在语义分割中标签和原始图像大小一致&#xff0c;若输出宽高减小&#xff0c;不利于标签比对 于是使用转置卷积将图像宽高还原 在卷积的时候&#xff0c;通常输入大于输出&#xff0c;可根据输入大小…

如何通过 Service Mesh 构建高效、安全的微服务系统

1. 引言 1.1.什么是 Service Mesh&#xff1f; Service Mesh 是一种基础架构层&#xff0c;负责处理微服务之间的通信&#xff0c;它通过在每个服务旁边部署代理&#xff08;通常称为 Sidecar&#xff09;来捕获和管理服务间的网络流量。这种方式解耦了微服务的业务逻辑和基础…

【Linux】waitpid函数 及其 非阻塞等待和阻塞等待

父进程等待子进程结束可以通过两种方式实现&#xff1a;阻塞等待和非阻塞等待。这两种方式各有优缺点&#xff0c;适用于不同的场景。 简单来说&#xff1a; 阻塞等待&#xff1a;先等你&#xff0c;我再继续 非阻塞等待&#xff1a;不等你&#xff0c;我继续做自己的事&…

使用Python实现某易云音乐歌曲下载

前言 在这篇文章中,我们将探讨如何通过Python结合JavaScript代码来逆向网易云音乐的API接口,以获取并下载指定歌曲。请注意,本文仅用于技术学习与交流目的,实际使用时请遵守相关法律法规及服务条款。 目标网站 1. 准备工作 首先,我们需要安装一些必要的库: execjs:用…

NVIDIA RTX 5080移动版GPU真身首曝!全系要用GDDR7

英伟达下一代移动版GPU的神秘面纱似乎正在揭开&#xff0c;Moore’s Law is Dead的最新视频首次曝光了疑似RTX 5080移动版GPU的工程样品照片。 这款工程样品印有N22W-ES-A1&#xff0c;与Clevo的下一代笔记本主板规格表相匹配&#xff0c;表明该芯片确实基于NVIDIA的下一代芯片…

java 提示 避免用Apache Beanutils进行属性的copy。

避免用Apache Beanutils进行属性的copy。 Inspection info: 避免用Apache Beanutils进行属性的copy。 说明&#xff1a;Apache BeanUtils性能较差&#xff0c;可以使用其他方案比如Spring BeanUtils, Cglib BeanCopier。 TestObject a new TestObject(); TestObject b new Te…

Cadence元件A属性和B属性相互覆盖

最近在使用第三方插件集成到Cadence,协助导出BOM到平台上&#xff0c;方便对BOM进行管理和修改&#xff0c;结果因为属性A和属性B不相同&#xff0c;导致导出的BOM错误。如下图&#xff1a; ​​ 本来我们需要导出Q12&#xff0c;结果给我们导出了Q13&#xff0c;或者反之&…

【Python】基础语法错误和异常

在Python中&#xff0c;语法错误和异常是两个常见的问题。下面对它们进行简要介绍。 1.语法错误 (Syntax Error) 语法错误是指代码的语法不符合Python的语言规则。当Python解释器读取程序代码时&#xff0c;如果发现语法不正确&#xff0c;就会抛出语法错误。这种错误通常在代…

SpringBoot实现的高效民宿预订平台

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

AWD的复现

学习awd的相关资料&#xff1a;速成AWD并获奖的学习方法和思考记录- Track 知识社区 - 掌控安全在线教育 - Powered by 掌控者&#xff08;包含使用脚本去批量修改密码&#xff09; 在复现之前去了解了以下AWD的相关脚本 资料&#xff1a;AWD批量攻击脚本使用教程-CSDN博客 …

cfg80211-- 修复添加EHT的capabilities能力供驱动使用

需求: 添加支持以检索信标模板中用户空间传递的EHT功能和EHT操作元素,并将指针存储在结构cfg80211_ap_settings中供驱动程序使用 在nl80211_calculate_ap_params方法种 ---------------------------- net/wireless/nl80211.c ---------------------------- index f3ad5f2.…

etcd入门到实战

概述&#xff1a;本文将介绍etcd特性、使用场景、基本原理以及Linux环境下的实战操作 入门 什么是etcd&#xff1f; etcd是一个分布式键值存储数据库 关键字解析&#xff1a; 键值存储&#xff1a;存储协议是 key—value 的形式&#xff0c;类似于redis分布式&#xff1a;…

13_渲染器的设计

目录 渲染器与响应式系统的结合渲染器的基本概念自定义渲染器 渲染器与响应式系统的结合 渲染器与响应式系统是相辅相成的&#xff0c;渲染器负责将响应式系统中的响应式数据渲染到视图中&#xff0c;而响应式系统则负责监听数据的变化并通知渲染器进行更新。 渲染器在浏览器…

大数据-184 Elasticsearch - 原理剖析 - DocValues 机制原理 压缩与禁用

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&#xff09;MapReduce&#xff08;已更完&am…

在 Docker 中搭建 PostgreSQL16 主从同步环境

1. 环境搭建 本文介绍了如何在同一台机器上使用 Docker 容器搭建 PostgreSQL 的主从同步环境。通过创建互联网络和配置主库及从库&#xff0c;详细讲解了数据库初始化、角色创建、数据同步和验证步骤。主要步骤包括设置主库的连接信息、创建用于复制的角色、使用 pg_basebacku…

2024系统分析师考试---论区块链技术及其应用

试题三论区块链技术及其应用 区块链作为一种分布式记账技术,目前已经被应用到了资产管理、物联网、医疗管理、政务监管等多个领域,从网络层面来讲,区块链是一个对等网络(Peer to Peer,P2P),网络中的节点地位对等,每个节点都保存完整的账本数据,系统的运行不依赖中心化节…

成都跃享未来教育咨询有限公司抖音小店新生态

在数字化浪潮席卷全球的今天&#xff0c;教育行业正经历着前所未有的变革与升级。作为一座历史悠久而又充满活力的城市&#xff0c;成都凭借其深厚的文化底蕴和前瞻性的发展眼光&#xff0c;孕育了众多创新型企业。其中&#xff0c;成都跃享未来教育咨询有限公司&#xff08;以…