1、数据集的保存形式:每行一个样本,前面各列是输入特征,最后一列是标签(目标值)。
比如要预测两个数的加法 a+b=c,那么传入 Dataset 的数据应为:
a1,b1,c1
a2,b2,c2
...
an,bn,cn
2、代码
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset# 创建数据
# Build a toy dataset where every row is (a, b, a + b).
np.random.seed(2024)  # fix the global NumPy seed so the random draws are reproducible
data_rand = np.random.rand(10, 2)  # 10 samples, 2 input features drawn from [0, 1)
# Insert the row-wise sum at column index 2, i.e. append it as the last column (the label).
datas = np.insert(data_rand, 2, data_rand.sum(axis=1), axis=1)
print("\ndatas.shape=", datas.shape)
print("datas=\n", datas)

# The first 90% of the rows form the training split.
train_data = datas[:int(len(datas) * 0.9)]
# The remaining 10% of the rows form the test split (not used further in this script).
test_data = datas[int(len(datas) * 0.9):]

debug_flag = False  # set to True to trace how the DataLoader calls __getitem__ / __len__


class PreDataSet(Dataset):
    """Map-style dataset over a 2-D table of samples.

    Each row of ``_data`` is one sample: every column except the last is an
    input feature, and the last column is the target value.

    Args:
        _data: 2-D array-like with at least two columns.

    Raises:
        ValueError: if ``_data`` is not 2-D or has fewer than two columns.
    """

    def __init__(self, _data):
        arr = np.asarray(_data, dtype=np.float32)
        if arr.ndim != 2 or arr.shape[1] < 2:
            raise ValueError(
                f"expected a 2-D array with at least 2 columns, got shape {arr.shape}"
            )
        # torch.tensor copies the data; float32 matches what the legacy
        # torch.Tensor(...) constructor produced.
        self.x_data = torch.tensor(arr[:, :-1])  # features: all but the last column
        self.y_data = torch.tensor(arr[:, -1])  # labels: last column, 1-D
        if debug_flag:
            print(">>self.x_data.shape=", self.x_data.shape)
            print(">>self.y_data.shape=", self.y_data.shape)
        self.n_getitem = 0  # number of times __getitem__ has been called
        self.n_len = 0  # number of times __len__ has been called

    def __getitem__(self, index):
        """Return the (features, label) pair for sample ``index``."""
        self.n_getitem = self.n_getitem + 1
        if debug_flag:
            print(">>index=", index, "n_getitem=", self.n_getitem)
            print(">>x_data[index].shape=", self.x_data[index].shape)
            print(">>y_data[index].shape=", self.y_data[index].shape)
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples."""
        self.n_len = self.n_len + 1
        if debug_flag:
            print(">>len(self.x_data)=", len(self.x_data), "n_len=", self.n_len)
        return len(self.x_data)


train_dataset = PreDataSet(train_data)
# batch_size=1, shuffle=False: yield the training samples one at a time, in order.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)

# Step 2: print every batch to inspect what the DataLoader yields.
for x, y in train_dataloader:
    print("\nx=", x)
    print("y=", y)
    if debug_flag:
        print("x.shape=", x.shape)
        print("y.shape=", y.shape)
3、运行结果(训练集取前 90% 的数据,即 10 条中的 9 条;batch_size=1,所以共打印 9 组 x/y)
D:\SoftProgram\JetBrains\anaconda3_202303\python.exe E:\program\python\DKASCProject\10DistributedPV\tst_dataloader_end.py

datas.shape= (10, 3)
datas=
 [[0.58801452 0.69910875 1.28712327]
 [0.18815196 0.04380856 0.23196052]
 [0.20501895 0.10606287 0.31108183]
 [0.72724014 0.67940052 1.40664067]
 [0.4738457 0.44829582 0.92214153]
 [0.01910695 0.75259834 0.77170529]
 [0.60244854 0.96177758 1.56422611]
 [0.66436865 0.60662962 1.27099827]
 [0.44915131 0.22535416 0.67450548]
 [0.6701743 0.73576659 1.40594089]]

x= tensor([[0.5880, 0.6991]])
y= tensor([1.2871])

x= tensor([[0.1882, 0.0438]])
y= tensor([0.2320])

x= tensor([[0.2050, 0.1061]])
y= tensor([0.3111])

x= tensor([[0.7272, 0.6794]])
y= tensor([1.4066])

x= tensor([[0.4738, 0.4483]])
y= tensor([0.9221])

x= tensor([[0.0191, 0.7526]])
y= tensor([0.7717])

x= tensor([[0.6024, 0.9618]])
y= tensor([1.5642])

x= tensor([[0.6644, 0.6066]])
y= tensor([1.2710])

x= tensor([[0.4492, 0.2254]])
y= tensor([0.6745])

进程已结束,退出代码为 0
参考B站视频
【2、数据集加载(Dataset和DataLoader)】