本文从Cora的例子来展示PYG如何加载图数据集。
Cora 是一个小型的有标注的图数据集,包含以下内容:
- data.x:2708 个节点(即 2708 篇论文),每个节点有 1433 个特征,形状为 (2708, 1433)。
- data.edge_index:5429 条边(即 5429 个引用关系),形状为 (2, 5429)。
- data.y:节点标签,共 7 类,形状为 (2708,)。(共有 7 个类别,表示论文的研究领域)
- data.train_mask:训练集掩码,布尔向量,表示哪些节点用于训练。
- data.val_mask:验证集掩码,布尔向量,表示哪些节点用于验证。
- data.test_mask:测试集掩码,布尔向量,表示哪些节点用于测试。
数据主要描述了论文之间的引用关系以及每篇论文的主题。可用于进行训练节点分类问题(即判断每篇论文属于哪个类别)
1.自动加载
1.1 数据加载操作详解
PYG库提供了自动加载数据集的方法:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data/Planetoid', name='Cora')
dataset[0]
print(len(dataset)) # 输出: 1
print(data)
1
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
对于 Planetoid
类来说:
- 它是一个专门为 Planetoid 系列数据集(Cora、CiteSeer、PubMed) 设计的类。
- 这些数据集的主要特点是:它们实际上是单图数据集,即整个数据集中只包含一个图。
dataset
是一个包含 单个 Data
对象(图) 的数据集对象。
由于 Planetoid
类的数据集中只有一个图,因此:
dataset[0]
返回了这个唯一的图,类型是Data
对象,表示整个 Cora 数据集的图。Dataset
是一个可索引的对象,dataset[0]
的作用就是提取第一(也是唯一)个图。
dataset = Planetoid(root='data/Planetoid', name='Cora')
加载了 Cora 数据集,它是一个 单图数据集,包含一张图的节点特征、边索引、节点标签和数据集划分信息。dataset[0]
提取了该图的数据,返回了一个Data
对象,表示整个图。dataset
本身是一个数据集管理器,帮助加载和存储数据,同时提供一些元信息和操作方法。
1. 2 数据加载的过程
-
下载数据:
- 如果指定路径
'data/Planetoid'
下没有数据集文件,Planetoid
类会从 指定的远程服务器(由 PyG 维护)下载 Cora 数据集文件,并存储在'data/Planetoid/Cora'
文件夹下。 - 数据集下载地址为:
- Cora 数据集原始文件
- 如果指定路径
-
解压文件:
- 下载的数据集是
.zip
或.tar
格式,会被自动解压为一系列文件,主要包括:ind.cora.x
:训练节点的特征矩阵;ind.cora.tx
:测试节点的特征矩阵;ind.cora.allx
:包含训练节点和一些验证节点的特征矩阵;ind.cora.y
:训练节点的标签;ind.cora.ty
:测试节点的标签;ind.cora.ally
:训练和验证节点的标签;ind.cora.graph
:节点的邻接表(图结构信息);ind.cora.test.index
:测试节点的索引。
如图所示:
- 下载的数据集是
-
解析数据:
- PyG 将原始文件的内容解析为图数据格式(
Data
对象),将以下内容整合起来:- 节点特征矩阵
x
; - 图的边信息
edge_index
; - 节点标签
y
; - 训练、验证和测试集的掩码(
train_mask
、val_mask
、test_mask
)。
- 节点特征矩阵
- PyG 将原始文件的内容解析为图数据格式(
-
数据存储:
- 如果数据加载成功,解析后的数据将被缓存到指定路径(
data/Planetoid/Cora
)中,后续运行时会直接加载解析后的缓存文件,而不会重复下载和解析。
- 如果数据加载成功,解析后的数据将被缓存到指定路径(
2. 数据集原始文件的形式
原始文件(以 ind.cora.*
为前缀)是以下几种内容的存储形式:
文件名 | 内容描述 |
---|---|
ind.cora.x | 稀疏矩阵,训练集中节点的特征矩阵,大小为 (num_train_nodes, num_features) 。 |
ind.cora.tx | 稀疏矩阵,测试集中节点的特征矩阵,大小为 (num_test_nodes, num_features) 。 |
ind.cora.allx | 稀疏矩阵,包含训练集和部分验证集中节点的特征矩阵,大小为 (num_allx_nodes, num_features) 。 |
ind.cora.y | 训练集的标签,大小为 (num_train_nodes, num_classes) 的独热编码矩阵。 |
ind.cora.ty | 测试集的标签,大小为 (num_test_nodes, num_classes) 的独热编码矩阵。 |
ind.cora.ally | 训练和验证集的标签,大小为 (num_allx_nodes, num_classes) 的独热编码矩阵。 |
ind.cora.graph | 字典格式,存储图的邻接表,键为节点 ID,值为该节点的邻居节点列表。 |
ind.cora.test.index | 列表形式,包含测试节点的索引。 |
3. 加载后的数据形式
加载后,数据以 torch_geometric.data.Data
对象的形式存储,主要包含以下内容:
属性 | 描述 | 形状 |
---|---|---|
data.x | 节点的特征矩阵,每一行表示一个节点的特征向量。 | (num_nodes, num_features) |
data.edge_index | 图的边信息,存储为 COO 格式的索引矩阵(两个一维数组,分别表示边的起始节点和结束节点)。 | (2, num_edges) |
data.y | 节点的标签,每个节点对应一个整数,表示其所属类别的索引值。 | (num_nodes,) |
data.train_mask | 训练节点的布尔掩码,值为 True 的位置表示该节点属于训练集。 | (num_nodes,) |
data.val_mask | 验证节点的布尔掩码,值为 True 的位置表示该节点属于验证集。 | (num_nodes,) |
data.test_mask | 测试节点的布尔掩码,值为 True 的位置表示该节点属于测试集。 | (num_nodes,) |
4. 加载后的具体内容
以 Cora 数据集为例,加载后的数据具有以下具体特性:
- 节点数:
num_nodes = 2708
(共 2708 篇论文)。 - 特征数:
num_features = 1433
(每篇论文的特征是一个 1433 维向量,表示词袋模型中的单词出现情况)。 - 边数:
num_edges = 10556
(论文之间的引用关系,构成无向图)。 - 类别数:
num_classes = 7
(每篇论文属于 7 个主题之一)。 - 掩码分布:
- 训练集:140 个节点;
- 验证集:500 个节点;
- 测试集:1000 个节点。
手动读取数据集
下面手动实现的 CoraData
类代码,经过修改后与 PyTorch Geometric (PyG
) 的 Planetoid
类功能一致,可以直接生成标准的 Data
对象,用于图神经网络训练。
完整代码:CoraData
import os
import os.path as osp
import pickle
import numpy as np
import torch
from torch_geometric.data import Data
import scipy.sparse as sp
import urllib.requestclass CoraData(object):download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"filenames = ["ind.cora.{}".format(name) for name in['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]def __init__(self, data_root="cora", rebuild=False):"""Cora 数据加载器,包括下载、处理和缓存功能。处理后的数据可以通过属性 .data 获取,返回 PyG 标准的 Data 对象。Args:data_root: str, 数据存储的根目录rebuild: bool, 是否强制重新构建数据"""self.data_root = data_rootsave_file = osp.join(self.data_root, "processed_cora.pkl")if osp.exists(save_file) and not rebuild:print("Using Cached file: {}".format(save_file))self._data = pickle.load(open(save_file, "rb"))else:self.maybe_download()self._data = self.process_data()with open(save_file, "wb") as f:pickle.dump(self.data, f)print("Cached file: {}".format(save_file))@propertydef data(self):"""返回 PyG 标准的 Data 对象"""return self._datadef maybe_download(self):save_path = osp.join(self.data_root, "raw")for name in self.filenames:if not osp.exists(osp.join(save_path, name)):self.download_data("{}/{}".format(self.download_url, name), save_path)def process_data(self):"""处理数据并生成 PyG 标准的 Data 对象,包括以下属性:- x: 节点特征,(2708, 1433)- y: 节点标签,共 7 类,(2708,)- edge_index: 图边索引,(2, num_edges)- train_mask: 训练集掩码,(2708,)- val_mask: 验证集掩码,(2708,)- test_mask: 测试集掩码,(2708,)"""print("Processing data ...")# 读取原始数据x, tx, allx, y, ty, ally, graph, test_index = [self.read_data(osp.join(self.data_root, "raw", name)) for name in self.filenames]train_index = np.arange(y.shape[0]) # 训练集索引 [0, 1, ..., 139]val_index = np.arange(y.shape[0], y.shape[0] + 500) # 验证集索引 [140, ..., 639]sorted_test_index = sorted(test_index) # 排序后的测试集索引# 特征和标签拼接x = np.concatenate((allx, tx), axis=0) # (2708, 1433)y = np.concatenate((ally, ty), axis=0).argmax(axis=1) # (2708,)# 重新排序测试集数据x[test_index] = x[sorted_test_index]y[test_index] = y[sorted_test_index]# 创建训练、验证、测试掩码num_nodes = x.shape[0]train_mask = np.zeros(num_nodes, dtype=np.bool_)val_mask = np.zeros(num_nodes, dtype=np.bool_)test_mask = np.zeros(num_nodes, dtype=np.bool_)train_mask[train_index] = Trueval_mask[val_index] = Truetest_mask[test_index] = True# 构造 edge_indexedge_index = self.build_edge_index(graph)# 转换为 PyTorch 格式x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.long)edge_index = torch.tensor(edge_index, dtype=torch.long)train_mask = torch.tensor(train_mask, dtype=torch.bool)val_mask = torch.tensor(val_mask, dtype=torch.bool)test_mask = torch.tensor(test_mask, dtype=torch.bool)# 打印基本信息print("Node feature shape: ", x.shape)print("Node label shape: ", y.shape)print("Edge index shape: ", edge_index.shape)print("Number of training nodes: ", train_mask.sum().item())print("Number of validation nodes: ", val_mask.sum().item())print("Number of test nodes: ", test_mask.sum().item())# 返回 PyG 的 Data 对象return Data(x=x, y=y, edge_index=edge_index,train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)@staticmethoddef build_edge_index(graph):"""根据邻接表生成 edge_index 格式 (2, num_edges)。"""edge_index = []for src, dst in graph.items():edge_index.extend([[src, v] for v in dst]) # 正向边edge_index.extend([[v, src] for v in dst]) # 反向边edge_index = np.array(edge_index).T # 转置为 (2, num_edges)return edge_index@staticmethoddef read_data(path):"""读取数据文件,根据文件名选择加载方式。"""name = osp.basename(path)if name == "ind.cora.test.index":out = np.genfromtxt(path, dtype="int64")return outelse:out = pickle.load(open(path, "rb"), encoding="latin1")out = out.toarray() if hasattr(out, "toarray") else outreturn out@staticmethoddef download_data(url, save_path):"""从指定 URL 下载数据,并保存到本地路径。"""if not os.path.exists(save_path):os.makedirs(save_path)data = urllib.request.urlopen(url)filename = os.path.split(url)[-1]with open(os.path.join(save_path, filename), 'wb') as f:f.write(data.read())return True
代码解析
-
下载和缓存功能:
- 如果处理后的数据已缓存 (
processed_cora.pkl
),直接加载缓存。 - 如果未缓存,则从 GitHub 下载原始数据,处理后存储为缓存文件。
- 如果处理后的数据已缓存 (
-
数据处理:
process_data
:- 加载原始数据,并将训练、验证、测试节点特征拼接成完整矩阵。
- 生成 PyG 格式的
edge_index
(用于图神经网络的邻接表表示)。 - 生成训练、验证和测试集掩码。
-
邻接表转换为边索引:
build_edge_index
将邻接表 (graph
) 转换为edge_index
格式。edge_index
是一个形状为(2, num_edges)
的数组,列表示一条边的起点和终点。
-
返回 PyG 数据对象:
- 数据对象包括
x
、y
、edge_index
、train_mask
、val_mask
和test_mask
。
- 数据对象包括
运行代码测试
要测试 CoraData
类,可以直接运行以下代码:
cora_data = CoraData(data_root="cora", rebuild=True)
data = cora_data.data # 获取 PyG 的 Data 对象
print(data)
输出示例:
Processing data ...
Node feature shape: torch.Size([2708, 1433])
Node label shape: torch.Size([2708])
Edge index shape: torch.Size([2, 10556])
Number of training nodes: 140
Number of validation nodes: 500
Number of test nodes: 1000
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
该类的功能与 PyTorch Geometric 的 Planetoid
类一致,支持加载 Cora
数据集,并生成标准的 PyG Data
对象,适用于图神经网络模型训练。