【PYG】dataloader和densedataloader

DenseDataLoader 是专门用于处理稠密图数据的,而 DataLoader 通常用于处理稀疏图数据。两者的主要区别在于它们的输入数据格式和处理方式。DenseDataLoader 适合处理固定大小的邻接矩阵和节点特征矩阵的数据,而 DataLoader 更加灵活,可以处理稀疏表示的图数据。

主要区别

  • DataLoader:

    • 适合处理稀疏图数据。
    • 通常与 torch_geometric.data.Data 一起使用,其中边索引是稀疏表示的。
    • 更加灵活,适合处理各种不同形状和大小的图。
  • DenseDataLoader:

    • 适合处理稠密图数据。
    • 通常与固定大小的邻接矩阵和节点特征矩阵一起使用。
    • 更高效地处理固定大小的图数据。

使用示例

使用 DenseDataLoader

如果你有固定大小的邻接矩阵和节点特征矩阵,可以直接使用 DenseDataLoader 加载数据:

1. 导入必要的库
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader
2. 定义数据集类
class MyDenseDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresself.adj_matrix = self.create_adj_matrix(num_nodes)def create_adj_matrix(self, num_nodes):# 创建环形图的邻接矩阵adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1return adj_matrixdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 创建随机特征和标签x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1))  # 每个节点一个标签return Data(x=x, adj=self.adj_matrix, y=y)
3. 创建数据集和封装数据
# 参数设置
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数# 创建数据集
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)
4. 使用 DenseDataLoader
# 使用 DenseDataLoader 加载数据
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:print("Batch node features shape:", data.x.shape)  # 期望输出形状为 (32, 10, 8)print("Batch adjacency matrix shape:", data.adj.shape)  # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", data.y.shape)  # 期望输出形状为 (32, 10, 1)break  # 仅查看第一个批次的形状

解释

  1. 导入库

    • 导入 torchtorch_geometric.data 中的 Datatorch_geometric.loader 中的 DenseDataLoader
  2. 定义 MyDenseDataset

    • __init__ 方法初始化数据集参数,并创建邻接矩阵。
    • create_adj_matrix 方法创建环形图的邻接矩阵。
    • __len__ 方法返回数据集的样本数量。
    • __getitem__ 方法生成每个样本的随机节点特征和标签,并返回节点特征矩阵、邻接矩阵和标签。
  3. 创建数据集

    • 使用 MyDenseDataset 类创建一个包含 100 个样本的数据集,每个样本包含 10 个节点,每个节点有 8 个特征。
  4. 使用 DenseDataLoader

    • 使用 DenseDataLoader 加载 dataset,设置批次大小为 32,并进行随机打乱。
    • 在获取一个批次的数据时,检查 xadjy 的形状,以确保其符合期望的三维形状。

通过这个完整的示例代码,你可以生成、封装和加载稠密图数据,并确保每个批次的数据形状保持正确。这种方法适合处理节点数和边数固定的图数据,提高数据加载和处理的效率。

定义数据集类并使用 DenseDataLoader

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader  # 更新导入路径class MyDenseDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresself.adj_matrix = self.create_adj_matrix(num_nodes)def create_adj_matrix(self, num_nodes):# 创建环形图的邻接矩阵adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1print(adj_matrix)return adj_matrixdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 创建随机特征和标签x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1))  # 每个节点一个标签return Data(x, self.adj_matrix, y=y)# 创建数据集
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)# 使用 DenseDataLoader 加载数据
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:print("Batch node features shape:", data.x.shape)  # 期望输出形状为 (32, 10, 8)# print("Batch adjacency matrix shape:", data.adj.shape)  # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", data.y.shape)  # 期望输出形状为 (32, 10, 1)break  # 仅查看第一个批次的形状

使用 DataLoader

如果你使用的是 DataLoader,则数据应当是 torch_geometric.data.Data 对象,并将数据封装在列表中:

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # 更新导入路径class MyDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresdef __len__(self):return self.num_samplesdef __getitem__(self, idx):x = torch.randn(self.num_nodes, self.num_node_features)edge_index = torch.tensor([[i, (i + 1) % self.num_nodes] for i in range(self.num_nodes)], dtype=torch.long).t().contiguous()y = torch.randn(self.num_nodes, 1)return Data(x=x, edge_index=edge_index, y=y)# 创建数据集
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数
dataset = MyDataset(num_samples, num_nodes, num_node_features)# 使用 DataLoader 加载数据
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 迭代加载数据
for batch in loader:print("Batch node features shape:", batch.x.shape)  # 期望输出形状为 (320, 8)print("Batch edge index shape:", batch.edge_index.shape)

总结

  • DenseDataLoader:处理固定大小的邻接矩阵和节点特征矩阵的数据,__getitem__ 返回Data(x, adj, y)。
  • DataLoader:处理 torch_geometric.data.Data 对象,__getitem__ 返回一个 Data 对象。

确保数据格式与使用的加载器相匹配,以避免属性错误和其他兼容性问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/39930.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

flask中解决图片不显示的问题(很细微的点)

我在编写flask项目的时候,在编写html的时候,发现不管我的图片路径如何变化,其就是显示不出来。如下图我框中的地方。 我尝试过使用浏览器打开,是可以的。 一旦运行这个flask项目,就无法显示了。 我查阅资料后。发现…

简易版async/await

参考:https://juejin.cn/post/7007031572238958629?searchId20240704101813568E9B5B1013C881A239#heading-15 总结一下async/await的知识点 1、 await只能在async函数中使用,不然会报错 2、 async函数返回的是一个Promise对象,有无值看有…

泛微开发修炼之旅--29用计划任务定时发送邮件提醒

文章链接:29用计划任务定时发送邮件提醒

[单master节点k8s部署]17.监控系统构建(二)Prometheus安装

prometheus server安装 创建sa账号,对prometheus server进行授权。因为Prometheus是安装在pod里面,以pod的形式去运行的,因此需要创建sa,并对他做rbac授权。 apiVersion: v1 kind: ServiceAccount metadata:name: monitornamesp…

k8s-第九节-命名空间

命名空间 如果一个集群中部署了多个应用,所有应用都在一起,就不太好管理,也可以导致名字冲突等。 我们可以使用 namespace 把应用划分到不同的命名空间,跟代码里的 namespace 是一个概念,只是为了划分空间。 # 创建命…

LeetCode热题100刷题4:76. 最小覆盖子串、239. 滑动窗口最大值、53. 最大子数组和、56. 合并区间

76. 最小覆盖子串 滑动窗口解决字串问题。 labuladong的算法小抄中关于滑动窗口的算法总结&#xff1a; class Solution { public:string minWindow(string s, string t) {unordered_map<char,int> need,window;for(char c : t) {need[c];}int left 0, right 0;int …

2.8亿东亚五国建筑数据分享

数据是GIS的血液&#xff01; 我们现在为你分享东亚5国的2.8亿条建筑轮廓数据&#xff0c;该数据包括中国、日本、朝鲜、韩国和蒙古5个东亚国家完整、高质量的建筑物轮廓数据&#xff0c;你可以在文末查看领取方法。 数据介绍 虽然开源的全球的建筑数据已经有微软的建筑数据…

elementUI中table组件固定列时会渲染两次模板内容问题

今天在使用elementUI的table组件时&#xff0c;由于业务需要固定表格的前几项列&#xff0c;然后获取表格对象时发现竟然有两个对象。 查阅资料发现&#xff0c;elementUI的固定列的实现原理是将两个表格拼装而成&#xff0c;因此获取的对象也是两个。对于需要使用对象的方法的…

vxe-table的序号一样

使用vxe-table的时候&#xff0c;有的时候会出现序号相同的现象&#xff0c;这种现象一般出现在我们后面自己添加的行中&#xff0c;就像这种 此时的这三个序号是相同的&#xff0c;我来说一下原因&#xff0c;这是在添加新的一行的时候&#xff0c;有的时候数据很多&#xff0…

Mac 运行 Windows 软件,Parallels Desktop 19和 CrossOver 24全面对比

Parallels Desktop 和 CrossOver 都是能满足你「在 Mac 上运行 Windows 软件」需求的工具。可能很多人都已经知道 Parallels Desktop 是「虚拟机」&#xff0c;但 CrossOver 其实并不是「虚拟机」。这两款软件有相同的作用&#xff0c;但由于实现原理的不同&#xff0c;两者也有…

系统提示我未定义与 ‘double‘ 类型的输入参数相对应的函数 ‘finverse‘,如何解决?

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

Kubernetes 部署简单的应用

Kubernetes 部署简单的应用 Kubernetes 是一个强大的容器编排平台&#xff0c;它可以帮助我们自动化应用程序的部署、扩展和管理。在本期文章中&#xff0c;我们将学习如何使用 Kubernetes 部署一个简单的应用程序。 1. 环境准备 确保你已经安装了 Kubernetes 集群&#xff…

【python模块】argparse

文章目录 argparse模块介绍基本用法add_argument() argparse模块介绍 argparse 模块是 Python 标准库中的一个用于编写用户友好的命令行接口&#xff08;CLI&#xff09;的模块。它允许程序定义它所需要的命令行参数&#xff0c;然后 argparse 会自动从 sys.argv 解析出那些参…

TCP粘包解决方法

一. 产生原因及解决方法 产生原因&#xff1a;TCP是面向连接、基于字节流的协议&#xff0c;其无边界标记。当服务端处理速度比不其接收速度时&#xff0c;就很容易产生粘包现象。 解决方法&#xff1a;目前主要有两种解决方法&#xff0c;一个是在内容中添加分割标识&#xf…

人脸识别考勤系统

人脸识别考勤系统是一种利用生物识别技术进行自动身份验证的现代解决方案&#xff0c;它通过分析和比对人脸特征来进行员工的出勤记录。这种系统不仅提升了工作效率&#xff0c;还大大减少了人为错误和欺诈行为的可能性。 一、工作原理 人脸识别考勤系统的核心在于其生物识别…

深入剖析Python中的Pandas库:通过实战案例全方位解读数据清洗与预处理艺术

引言 随着大数据时代的到来&#xff0c;数据的质量直接影响到最终分析结果的可靠性和有效性。在这个背景下&#xff0c;Python凭借其灵活强大且易于上手的特点&#xff0c;在全球范围内被广泛应用于数据科学领域。而在Python的数据处理生态中&#xff0c;Pandas库无疑是最耀眼…

高级策略:解读 SQL 中的复杂连接

了解基本连接 在深入研究复杂连接之前&#xff0c;让我们先回顾一下基本连接的基础知识。 INNER JOIN&#xff1a;根据指定的连接条件检索两个表中具有匹配值的记录。LEFT JOIN&#xff1a;从左表检索所有记录&#xff0c;并从右表中检索匹配的记录&#xff08;如果有&#x…

管道支架安装

工程结构施工完毕后&#xff0c;系统管道安装完毕后的第一步任务就是管道支架的制作安装&#xff0c;作为对管道固定和承重作用至关重要的支、托、吊架&#xff0c;有些项目部在施工中却往往因为对它们的重要性认识不足&#xff0c;因存在侥幸心里或经验主义&#xff0c;导致支…

NIO为什么会导致CPU100%?

1. Java IO 类型概览 BIO&#xff1a;阻塞I/O&#xff0c;每个连接一个线程&#xff0c;简单但遇到高并发时性能瓶颈明显。NIO&#xff1a;非阻塞I/O&#xff0c;JDK 1.4引入&#xff0c;一个线程处理多个IO操作&#xff0c;提高资源利用率和系统吞吐量。AIO&#xff1a;异步I…

技术探索:利用Python库wxauto实现Windows微信客户端的全面自动化管理

项目地址&#xff1a;github-wxauto 点击即可访问 项目官网&#xff1a;wxauto 点击即可访问 &#x1f602;什么是wxauto? wxauto 是作者在2020年开发的一个基于 UIAutomation 的开源 Python 微信自动化库&#xff0c;最初只是一个简单的脚本&#xff0c;只能获取消息和发送…