为了使用你提到的封装方式来创建一个包含多个 Data
对象的列表并使用 DataLoader
来加载这些数据,我们可以按照以下步骤进行:
- 创建数据:生成节点特征矩阵、边索引矩阵和标签。
- 封装数据:使用
Data
对象将这些数据封装起来。 - 使用
DataLoader
:确保批次数据的形状符合期望。
具体步骤
1. 创建数据
首先,我们创建节点特征矩阵、边索引矩阵和标签数据。
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader # 更新导入路径# 参数设置
num_samples = 100 # 样本数
num_nodes = 10 # 每个图中的节点数
num_node_features = 8 # 每个节点的特征数# 生成数据
features = [torch.randn((num_nodes, num_node_features)) for _ in range(num_samples)]
labels = [torch.randn((num_nodes, 1)) for _ in range(num_samples)]
adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1
print(adj_matrix)
2. 封装数据
使用 Data
对象将每个样本的数据封装起来。
data_list = [Data(x=features[i], adj=adj_matrix, y=labels[i]) for i in range(num_samples)]
3. 使用 DataLoader
# 创建 DataLoader
loader = DenseDataLoader(data_list, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:print("Batch node features shape:", data.x.shape) # 期望输出形状为 (32, 10, 8)print("Batch adjacency matrix shape:", data.adj.shape) # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", data.y.shape) # 期望输出形状为 (32, 10, 1)break # 仅查看第一个批次的形状
总结
- 生成数据:我们生成了包含节点特征、边索引和标签的样本数据。
- 封装数据:我们使用
Data
对象将每个样本的数据封装起来。
完整代码
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader # 更新导入路径# 参数设置
num_samples = 100 # 样本数
num_nodes = 10 # 每个图中的节点数
num_node_features = 8 # 每个节点的特征数# 生成数据
features = [torch.randn((num_nodes, num_node_features)) for _ in range(num_samples)]
labels = [torch.randn((num_nodes, 1)) for _ in range(num_samples)]
adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1
print(adj_matrix)data_list = [Data(x=features[i], adj=adj_matrix, y=labels[i]) for i in range(num_samples)]# 创建 DataLoader
loader = DenseDataLoader(data_list, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:print("Batch node features shape:", data.x.shape) # 期望输出形状为 (32, 10, 8)print("Batch adjacency matrix shape:", data.adj.shape) # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", data.y.shape) # 期望输出形状为 (32, 10, 1)break # 仅查看第一个批次的形状