PYG - Cora数据集加载 (自动加载+手动实现)

本文从Cora的例子来展示PYG如何加载图数据集。
Cora 是一个小型的有标注的图数据集,包含以下内容:

  • data.x:2708 个节点(即 2708 篇论文),每个节点有 1433 个特征,形状为 (2708, 1433)。
  • data.edge_index:5429 条边(即 5429 个引用关系),形状为 (2, 5429)。
  • data.y:节点标签,共 7 类,形状为 (2708,)。(共有 7 个类别,表示论文的研究领域)
  • data.train_mask:训练集掩码,布尔向量,表示哪些节点用于训练。
  • data.val_mask:验证集掩码,布尔向量,表示哪些节点用于验证。
  • data.test_mask:测试集掩码,布尔向量,表示哪些节点用于测试。

数据主要描述了论文之间的引用关系以及每篇论文的主题。可用于进行训练节点分类问题(即判断每篇论文属于哪个类别)

1.自动加载

1.1 数据加载操作详解

PYG库提供了自动加载数据集的方法:

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data/Planetoid', name='Cora')
dataset[0]
print(len(dataset))  # 输出: 1
print(data)

1
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

对于 Planetoid 类来说:

  • 它是一个专门为 Planetoid 系列数据集(Cora、CiteSeer、PubMed) 设计的类。
  • 这些数据集的主要特点是:它们实际上是单图数据集,即整个数据集中只包含一个图。

dataset 是一个包含 单个 Data 对象(图) 的数据集对象。


由于 Planetoid 类的数据集中只有一个图,因此:

  • dataset[0] 返回了这个唯一的图,类型是 Data 对象,表示整个 Cora 数据集的图。
  • Dataset 是一个可索引的对象,dataset[0] 的作用就是提取第一(也是唯一)个图。

  • dataset = Planetoid(root='data/Planetoid', name='Cora') 加载了 Cora 数据集,它是一个 单图数据集,包含一张图的节点特征、边索引、节点标签和数据集划分信息。
  • dataset[0] 提取了该图的数据,返回了一个 Data 对象,表示整个图。
  • dataset 本身是一个数据集管理器,帮助加载和存储数据,同时提供一些元信息和操作方法。

1. 2 数据加载的过程

  1. 下载数据:

    • 如果指定路径 'data/Planetoid' 下没有数据集文件,Planetoid 类会从 指定的远程服务器(由 PyG 维护)下载 Cora 数据集文件,并存储在 'data/Planetoid/Cora' 文件夹下。
    • 数据集下载地址为:
      • Cora 数据集原始文件
  2. 解压文件:

    • 下载的数据集是 .zip.tar 格式,会被自动解压为一系列文件,主要包括:
      • ind.cora.x:训练节点的特征矩阵;
      • ind.cora.tx:测试节点的特征矩阵;
      • ind.cora.allx:包含训练节点和一些验证节点的特征矩阵;
      • ind.cora.y:训练节点的标签;
      • ind.cora.ty:测试节点的标签;
      • ind.cora.ally:训练和验证节点的标签;
      • ind.cora.graph:节点的邻接表(图结构信息);
      • ind.cora.test.index:测试节点的索引。
        如图所示:
        请添加图片描述
  3. 解析数据:

    • PyG 将原始文件的内容解析为图数据格式(Data 对象),将以下内容整合起来:
      • 节点特征矩阵 x
      • 图的边信息 edge_index
      • 节点标签 y
      • 训练、验证和测试集的掩码(train_maskval_masktest_mask)。
  4. 数据存储:

    • 如果数据加载成功,解析后的数据将被缓存到指定路径(data/Planetoid/Cora)中,后续运行时会直接加载解析后的缓存文件,而不会重复下载和解析。

2. 数据集原始文件的形式

原始文件(以 ind.cora.* 为前缀)是以下几种内容的存储形式:

文件名内容描述
ind.cora.x稀疏矩阵,训练集中节点的特征矩阵,大小为 (num_train_nodes, num_features)
ind.cora.tx稀疏矩阵,测试集中节点的特征矩阵,大小为 (num_test_nodes, num_features)
ind.cora.allx稀疏矩阵,包含训练集和部分验证集中节点的特征矩阵,大小为 (num_allx_nodes, num_features)
ind.cora.y训练集的标签,大小为 (num_train_nodes, num_classes) 的独热编码矩阵。
ind.cora.ty测试集的标签,大小为 (num_test_nodes, num_classes) 的独热编码矩阵。
ind.cora.ally训练和验证集的标签,大小为 (num_allx_nodes, num_classes) 的独热编码矩阵。
ind.cora.graph字典格式,存储图的邻接表,键为节点 ID,值为该节点的邻居节点列表。
ind.cora.test.index列表形式,包含测试节点的索引。

3. 加载后的数据形式

加载后,数据以 torch_geometric.data.Data 对象的形式存储,主要包含以下内容:

属性描述形状
data.x节点的特征矩阵,每一行表示一个节点的特征向量。(num_nodes, num_features)
data.edge_index图的边信息,存储为 COO 格式的索引矩阵(两个一维数组,分别表示边的起始节点和结束节点)。(2, num_edges)
data.y节点的标签,每个节点对应一个整数,表示其所属类别的索引值。(num_nodes,)
data.train_mask训练节点的布尔掩码,值为 True 的位置表示该节点属于训练集。(num_nodes,)
data.val_mask验证节点的布尔掩码,值为 True 的位置表示该节点属于验证集。(num_nodes,)
data.test_mask测试节点的布尔掩码,值为 True 的位置表示该节点属于测试集。(num_nodes,)

4. 加载后的具体内容

Cora 数据集为例,加载后的数据具有以下具体特性:

  • 节点数num_nodes = 2708(共 2708 篇论文)。
  • 特征数num_features = 1433(每篇论文的特征是一个 1433 维向量,表示词袋模型中的单词出现情况)。
  • 边数num_edges = 10556(论文之间的引用关系,构成无向图)。
  • 类别数num_classes = 7(每篇论文属于 7 个主题之一)。
  • 掩码分布
    • 训练集:140 个节点;
    • 验证集:500 个节点;
    • 测试集:1000 个节点。

手动读取数据集

下面手动实现的 CoraData 类代码,经过修改后与 PyTorch Geometric (PyG) 的 Planetoid 类功能一致,可以直接生成标准的 Data 对象,用于图神经网络训练。


完整代码:CoraData

import os
import os.path as osp
import pickle
import numpy as np
import torch
from torch_geometric.data import Data
import scipy.sparse as sp
import urllib.requestclass CoraData(object):download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"filenames = ["ind.cora.{}".format(name) for name in['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]def __init__(self, data_root="cora", rebuild=False):"""Cora 数据加载器,包括下载、处理和缓存功能。处理后的数据可以通过属性 .data 获取,返回 PyG 标准的 Data 对象。Args:data_root: str, 数据存储的根目录rebuild: bool, 是否强制重新构建数据"""self.data_root = data_rootsave_file = osp.join(self.data_root, "processed_cora.pkl")if osp.exists(save_file) and not rebuild:print("Using Cached file: {}".format(save_file))self._data = pickle.load(open(save_file, "rb"))else:self.maybe_download()self._data = self.process_data()with open(save_file, "wb") as f:pickle.dump(self.data, f)print("Cached file: {}".format(save_file))@propertydef data(self):"""返回 PyG 标准的 Data 对象"""return self._datadef maybe_download(self):save_path = osp.join(self.data_root, "raw")for name in self.filenames:if not osp.exists(osp.join(save_path, name)):self.download_data("{}/{}".format(self.download_url, name), save_path)def process_data(self):"""处理数据并生成 PyG 标准的 Data 对象,包括以下属性:- x: 节点特征,(2708, 1433)- y: 节点标签,共 7 类,(2708,)- edge_index: 图边索引,(2, num_edges)- train_mask: 训练集掩码,(2708,)- val_mask: 验证集掩码,(2708,)- test_mask: 测试集掩码,(2708,)"""print("Processing data ...")# 读取原始数据x, tx, allx, y, ty, ally, graph, test_index = [self.read_data(osp.join(self.data_root, "raw", name)) for name in self.filenames]train_index = np.arange(y.shape[0])  # 训练集索引 [0, 1, ..., 139]val_index = np.arange(y.shape[0], y.shape[0] + 500)  # 验证集索引 [140, ..., 639]sorted_test_index = sorted(test_index)  # 排序后的测试集索引# 特征和标签拼接x = np.concatenate((allx, tx), axis=0)  # (2708, 1433)y = np.concatenate((ally, ty), axis=0).argmax(axis=1)  # (2708,)# 重新排序测试集数据x[test_index] = x[sorted_test_index]y[test_index] = y[sorted_test_index]# 创建训练、验证、测试掩码num_nodes = x.shape[0]train_mask = np.zeros(num_nodes, dtype=np.bool_)val_mask = np.zeros(num_nodes, dtype=np.bool_)test_mask = np.zeros(num_nodes, dtype=np.bool_)train_mask[train_index] = Trueval_mask[val_index] = Truetest_mask[test_index] = True# 构造 edge_indexedge_index = self.build_edge_index(graph)# 转换为 PyTorch 格式x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.long)edge_index = torch.tensor(edge_index, dtype=torch.long)train_mask = torch.tensor(train_mask, dtype=torch.bool)val_mask = torch.tensor(val_mask, dtype=torch.bool)test_mask = torch.tensor(test_mask, dtype=torch.bool)# 打印基本信息print("Node feature shape: ", x.shape)print("Node label shape: ", y.shape)print("Edge index shape: ", edge_index.shape)print("Number of training nodes: ", train_mask.sum().item())print("Number of validation nodes: ", val_mask.sum().item())print("Number of test nodes: ", test_mask.sum().item())# 返回 PyG 的 Data 对象return Data(x=x, y=y, edge_index=edge_index,train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)@staticmethoddef build_edge_index(graph):"""根据邻接表生成 edge_index 格式 (2, num_edges)。"""edge_index = []for src, dst in graph.items():edge_index.extend([[src, v] for v in dst])  # 正向边edge_index.extend([[v, src] for v in dst])  # 反向边edge_index = np.array(edge_index).T  # 转置为 (2, num_edges)return edge_index@staticmethoddef read_data(path):"""读取数据文件,根据文件名选择加载方式。"""name = osp.basename(path)if name == "ind.cora.test.index":out = np.genfromtxt(path, dtype="int64")return outelse:out = pickle.load(open(path, "rb"), encoding="latin1")out = out.toarray() if hasattr(out, "toarray") else outreturn out@staticmethoddef download_data(url, save_path):"""从指定 URL 下载数据,并保存到本地路径。"""if not os.path.exists(save_path):os.makedirs(save_path)data = urllib.request.urlopen(url)filename = os.path.split(url)[-1]with open(os.path.join(save_path, filename), 'wb') as f:f.write(data.read())return True

代码解析

  1. 下载和缓存功能

    • 如果处理后的数据已缓存 (processed_cora.pkl),直接加载缓存。
    • 如果未缓存,则从 GitHub 下载原始数据,处理后存储为缓存文件。
  2. 数据处理:process_data

    • 加载原始数据,并将训练、验证、测试节点特征拼接成完整矩阵。
    • 生成 PyG 格式的 edge_index(用于图神经网络的邻接表表示)。
    • 生成训练、验证和测试集掩码。
  3. 邻接表转换为边索引

    • build_edge_index 将邻接表 (graph) 转换为 edge_index 格式。
    • edge_index 是一个形状为 (2, num_edges) 的数组,列表示一条边的起点和终点。
  4. 返回 PyG 数据对象

    • 数据对象包括 xyedge_indextrain_maskval_masktest_mask

运行代码测试

要测试 CoraData 类,可以直接运行以下代码:

cora_data = CoraData(data_root="cora", rebuild=True)
data = cora_data.data  # 获取 PyG 的 Data 对象
print(data)

输出示例:

Processing data ...
Node feature shape:  torch.Size([2708, 1433])
Node label shape:  torch.Size([2708])
Edge index shape:  torch.Size([2, 10556])
Number of training nodes:  140
Number of validation nodes:  500
Number of test nodes:  1000
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

该类的功能与 PyTorch Geometric 的 Planetoid 类一致,支持加载 Cora 数据集,并生成标准的 PyG Data 对象,适用于图神经网络模型训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63754.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习基础算法 (二)-逻辑回归

python 环境的配置参考 从零开始:Python 环境搭建与工具配置 逻辑回归是一种用于解决二分类问题的机器学习算法,它可以预测输入数据属于某个类别的概率。本文将详细介绍逻辑回归的原理、Python 实现、模型评估和调优,并结合垃圾邮件分类案例进…

BiTCN-BiGRU基于双向时间卷积网络结合双向门控循环单元的数据多特征分类预测(多输入单输出)

Matlab实现BiTCN-BiGRU基于双向时间卷积网络结合双向门控循环单元的数据多特征分类预测(多输入单输出) 目录 Matlab实现BiTCN-BiGRU基于双向时间卷积网络结合双向门控循环单元的数据多特征分类预测(多输入单输出)分类效果基本描述…

51c大模型~合集94

我自己的原文哦~ https://blog.51cto.com/whaosoft/12897659 #D(R,O) Grasp 重塑跨智能体灵巧手抓取,NUS邵林团队提出全新交互式表征,斩获CoRL Workshop最佳机器人论文奖 本文的作者均来自新加坡国立大学 LinS Lab。本文的共同第一作者为上海交通大…

【大学英语】英语范文十八篇,书信,议论文,材料分析

关注作者了解更多 我的其他CSDN专栏 过程控制系统 工程测试技术 虚拟仪器技术 可编程控制器 工业现场总线 数字图像处理 智能控制 传感器技术 嵌入式系统 复变函数与积分变换 单片机原理 线性代数 大学物理 热工与工程流体力学 数字信号处理 光电融合集成电路…

一起学Git【第一节:Git的安装】

Git是什么? Git是什么?相信大家点击进来已经有了初步的认识,这里就简单的进行介绍。 Git是一个开源的分布式版本控制系统,由Linus Torvalds创建,用于有效、高速地处理从小到大的项目版本管理。Git是目前世界上最流行…

【day11】面向对象编程进阶(继承)

概述 本文深入探讨面向对象编程的核心概念,包括继承、方法重写、this和super关键字的使用,以及抽象类和方法的定义与实现。通过本文的学习,你将能够: 理解继承的优势。掌握继承的使用方法。了解继承后成员变量和成员方法的访问特…

随手记:小程序兼容后台的wangEditor富文本配置链接

场景&#xff1a; 在后台配置wangEditor富文本&#xff0c;可以文字配置链接&#xff0c;图片配置链接&#xff0c;产生的json格式为&#xff1a; 例子&#xff1a; <h1><a href"https://uniapp.dcloud.net.cn/" target"_blank"><span sty…

6.8 Newman自动化运行Postman测试集

欢迎大家订阅【软件测试】 专栏&#xff0c;开启你的软件测试学习之旅&#xff01; 文章目录 1 安装Node.js2 安装Newman3 使用Newman运行Postman测试集3.1 导出Postman集合3.2 使用Newman运行集合3.3 Newman常用参数3.4 Newman报告格式 4 使用定时任务自动化执行脚本4.1 编写B…

计算机网络之王道考研读书笔记-2

第 2 章 物理层 2.1 通信基础 2.1.1 基本概念 1.数据、信号与码元 通信的目的是传输信息。数据是指传送信息的实体。信号则是数据的电气或电磁表现&#xff0c;是数据在传输过程中的存在形式。码元是数字通信中数字信号的计量单位&#xff0c;这个时长内的信号称为 k 进制码…

法规标准-C-NCAP评测标准解析(2024版)

文章目录 什么是C-NCAP&#xff1f;C-NCAP 评测标准C-NCAP评测维度三大维度的评测场景及对应分数评星标准 自动驾驶相关评测场景评测方法及评测标准AEB VRU——评测内容(测什么&#xff1f;)AEB VRU——评测方法(怎么测&#xff1f;)车辆直行与前方纵向行走的行人测试场景&…

第十七届山东省职业院校技能大赛 中职组“网络安全”赛项任务书正式赛题

第十七届山东省职业院校技能大赛 中职组“网络安全”赛项任务书-A 目录 一、竞赛阶段 二、竞赛任务书内容 &#xff08;一&#xff09;拓扑图 &#xff08;二&#xff09;模块A 基础设施设置与安全加固(200分) &#xff08;三&#xff09;B模块安全事件响应/网络安全数据取证/…

Halcon例程代码解读:安全环检测(附源码|图像下载链接)

安全环检测核心思路与代码详解 项目目标 本项目的目标是检测图像中的安全环位置和方向。通过形状匹配技术&#xff0c;从一张模型图像中提取安全环的特征&#xff0c;并在后续图像中识别多个实例&#xff0c;完成检测和方向标定。 实现思路 安全环检测分为以下核心步骤&…

Java——多线程进阶知识

目录 一、常见的锁策略 乐观锁VS悲观锁 读写锁 重量级锁VS轻量级锁 总结&#xff1a; 自旋锁&#xff08;Spin Lock&#xff09; 公平锁VS非公平锁 可重入锁VS不可重入锁 二、CAS 何为CAS CAS有哪些应用 1&#xff09;实现原子类 2&#xff09;实现自旋锁 CAS的ABA…

达梦 本地编码:PG_GBK, 导入文件编码:PG_UTF8错误

问题 达梦 本地编码&#xff1a;PG_GBK, 导入文件编码&#xff1a;PG_UTF8错误 解决 右键管理服务器 查看配置 新建一个数据库实例&#xff0c;配置跟之前的保持一致 新建一个用户&#xff0c;跟以前的用户名一样 在用户上&#xff0c;右键导入&#xff0c;选择dmp的位置 导…

深度学习卷积神经网络CNN之MobileNet模型网络模型详解说明(超详细理论篇)

1.MobileNet背景 2.MobileNet V1论文 3. MobileNett改进史 4. MobileNet模型结构 5. 特点&#xff08;超详细创新、优缺点及新知识点&#xff09; 一、MobileNet背景 随着移动设备的普及&#xff0c;深度学习模型的应用场景逐渐扩展至移动端和嵌入式设备。然而&#xff0c;传统…

垂起固定翼无人机大面积森林草原巡检技术详解

垂起固定翼无人机大面积森林草原巡检技术是一种高效、精准的监测手段&#xff0c;以下是对该技术的详细解析&#xff1a; 一、垂起固定翼无人机技术特点 垂起固定翼无人机结合了多旋翼和固定翼无人机的优点&#xff0c;具备垂直起降、飞行距离长、速度快、高度高等特点。这种无…

kubernates实战

使用k8s来部署tomcat 1、创建一个部署&#xff0c;并指定镜像地址 kubectl create deployment tomcat6 --imagetomcat:6.0.53-jre82、查看部署pod状态 kubectl get pods # 获取default名称空间下的pods kubectl get pods --all-namespaces # 获取所有名称空间下的pods kubect…

数据挖掘之认识数据

在数据挖掘过程中&#xff0c;数据的认识是非常重要的一步&#xff0c;它为后续的数据分析、建模、特征选择等工作奠定基础。以鸢尾花数据集&#xff08;Iris Dataset&#xff09;数据集之鸢尾花数据集&#xff08;Iris Dataset&#xff09;-CSDN博客为例&#xff0c;下面将介绍…

统信UOS 1071 AI客户端接入本地大模型配置手册

文章来源&#xff1a;统信UOS 1071本地大模型配置手册 | 统信软件-知识分享平台 1. OS版本确认 1.1. 版本查看 要求&#xff1a;计算机&#xff0c;属性&#xff0c;查看版本&#xff08;1070,构建号> 101.100&#xff09; 2. UOS AI版本确认 UOS AI&#xff0c;设置&am…

定时任务——定时任务技术选型

摘要 本文深入探讨了定时任务调度系统的核心问题、技术选型&#xff0c;并对Quartz、Elastic-Job、XXL-Job、Spring Task/ScheduledExecutor、Apache Airflow和Kubernetes CronJob等开源定时任务框架进行了比较分析&#xff0c;包括它们的特点、适用场景和技术栈。文章还讨论了…