在PyTorch框架中,可以通过自定义数据集类来加载和处理数据
要自定义数据集类,需要继承 PyTorch提供的 torch.utils.data.Dataset
类,并实现两个主要方法:__len__
和 __getitem__
下面是一个示例,展示如何基于PyTorch框架来自定义数据集类以获取数据:
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):item = self.data[index]# 在这里对数据进行预处理、转换等操作# 返回一个样本(通常是一个字典)return item# 创建数据集实例
data = [...] # 数据列表,包含训练样本
dataset = CustomDataset(data)# 创建数据加载器
batch_size = 32
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)# 遍历数据加载器获取数据批次
for batch in dataloader:# 处理每个批次的数据inputs = batch['input']labels = batch['label']# 在这里进行模型训练、推理等操作
在此示例中,定义了一个名为 CustomDataset
的自定义数据集类,该类继承自torch.utils.data. Dataset
__init__
方法是构造函数,传入数据列表 data 并将其保存为类的属性 self.data
__len__
方法返回数据集的长度,即样本数量
__getitem__
方法通过索引获取单个样本
然后,创建了一个数据集实例 dataset
,并使用 torch.utils.data.DataLoader
创建了一个数据加载器 dataloader
通过遍历数据加载器可以获取每个批次 输入数据inputs 以及 标签数据labels,进行模型训练、推理等操作
注意:根据具体的应用需求,可以在__getitem__
方法中对数据进行预处理、转换等操作,并将处理后的样本作为字典或其他形式返回, 这样,在训练过程中可以方便地获取输入数据和标签数据 ,并进行相应的操作
下面再来看一个例子,该例通过在 __getitem__方法中对数据进行预处理,并最终返回一个包含图片数据、对应的标签数据以及图像文件名的字典
class BipedDataset(Dataset): # 定义了一个名为BipedDataset的类,它继承自PyTorch的Dataset类,用于自定义数据集'''用于构建一个自定义数据集,可以在训练神经网络时使用它提供了加载图像、预处理数据等功能,以便用于深度学习模型的训练'''def __init__(self,data_root, img_height,img_width,mean_bgr, # 图像的均值(以BGR通道顺序表示)train_mode='train', # 训练模式,可以是 'train' 或 'test' 之一,默认为 'traincrop_img=False,arg=None):'''这是类的构造函数,用于初始化对象的属性它接受许多参数,包括数据根目录 data_root、图像高度 img_height、图像宽度 img_width、均值 mean_bgr、训练模式 train_mode 等'''self.data_root = data_rootself.train_mode = train_modeself.img_height = img_heightself.img_width = img_widthself.mean_bgr = mean_bgrself.crop_img = crop_imgself.arg = argself.data_index = self._build_index()def _build_index(self): # 用于构建数据索引data_root = os.path.abspath(self.data_root)sample_indices = [] # 用于存储图像和标签的文件路径对# 构建图像和标签的文件路径,其中 images_path 和 labels_path 分别指向数据集中图像和标签的存储路径# 使用两层循环遍历图像目录中的所有文件,构建图像和标签的文件路径,并将其添加到 sample_indices 列表中images_path = os.path.join(data_root,'edges\\imgs',self.train_mode)labels_path = os.path.join(data_root,'edges\\labels',self.train_mode)for file_name_ext in os.listdir(images_path):file_name = os.path.splitext(file_name_ext)[0]sample_indices.append(( os.path.join(images_path, file_name + '.tif'),os.path.join(labels_path, file_name + '.tif'), ))return sample_indices # 返回构建好的图像和标签的文件路径对列表def __len__(self): # 返回数据集的长度,即样本的数量return len(self.data_index)def __getitem__(self, idx): # 用于获取指定索引处的数据样本,它接受一个索引 idx 作为参数# get data sample'''首先,根据索引获取图像路径和标签路径然后,使用OpenCV加载图像和标签接下来,调用self.transform方法进行数据变换最后,返回一个包含图像、对应标签以及图像文件名的字典'''image_path, label_path = self.data_index[idx]# load dataimage = cv2.imread(image_path, cv2.IMREAD_COLOR)label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)image, label = self.transform(img=image, gt=label) # transform方法:用于对图像和标签进行预处理img_name = os.path.basename(image_path)file_name = os.path.splitext(img_name)[0] + ".png"return dict(images=image, labels=label, file_names=file_name)def transform(self, img, gt):# 将标签转换为浮点型数组,并将其归一化到 [0, 1] 的范围内gt = np.array(gt, dtype=np.float32)if len(gt.shape) == 3:gt = gt[:, :, 0]gt /= 255. # 将图像转换为浮点型数组,并减去均值 self.mean_bgrimg = np.array(img, dtype=np.float32)img -= self.mean_bgri_h, i_w, _ = img.shape # 获取图像的高度、宽度和通道数# 根据设定的裁剪大小 crop_size 对图像进行裁剪或缩放crop_size = self.img_height if self.img_height == self.img_width else None # 对于裁剪过程,它会在图像中随机选择一个位置来裁剪if i_w > crop_size and i_h > crop_size:i = random.randint(0, i_h - crop_size)j = random.randint(0, i_w - crop_size)img = img[i:i + crop_size, j:j + crop_size]gt = gt[i:i + crop_size, j:j + crop_size]else: # 如果图像的尺寸小于 crop_size,则会使用双线性插值进行缩放# New addidingsimg = cv2.resize(img, dsize=(crop_size, crop_size))gt = cv2.resize(gt, dsize=(crop_size, crop_size))# 对标签gt进行一些额外的处理,然后将图像img和标签gt转换为PyTorch的张量形式gt[gt > 0.1] += 0.2 gt = np.clip(gt, 0., 1.)img = img.transpose((2, 0, 1))img = torch.from_numpy(img.copy()).float()gt = torch.from_numpy(np.array([gt])).float()return img, gt
在此处就定义完成了一个数据集类 BipedDataset
如何使用自定义的 BipedDataset 类来对数据进行加载呢?下面以加载验证集数据为例来进行说明
首先,对这个类进行实例化得到实例化后的数据集对象 dataset_val
dataset_val = BipedDataset(args.input_dir,img_width =args.img_width,img_height =args.img_height,mean_bgr =args.mean_pixel_values,train_mode ='test',arg =args)
其次,将该对象传入DataLoader中创建验证集数据加载器 dataloader_val
dataloader_val = DataLoader(dataset_val,batch_size=1,shuffle=False,num_workers=args.workers)
然后,将数据集加载器 dataloader_val 作为参数传入进行验证过程的函数 validate_one_epoch 中
val_precision,val_recall,val_IoU = validate_one_epoch(epoch,dataloader_val,model,device,img_test_dir,arg=args)
def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None):precision = 0.0recall = 0.0IoU = 0.0model.eval() with torch.no_grad():for _, sample_batched in enumerate(dataloader):images = sample_batched['images'].to(device)labels = sample_batched['labels'].to(device)file_names = sample_batched['file_names'] preds = model(images)labels = normalize_image(labels)preds = normalize_image(preds)precision += calculate_precision(preds, labels)recall += calculate_recall(preds, labels)IoU += calculate_iou(preds, labels)save_image_batch_to_disk(preds, output_dir, file_names,arg=arg)precision = precision / len(dataloader)recall = recall / len(dataloader)IoU = IoU / len(dataloader)print(time.ctime(), '[Val_Epoch]: {0} Precision:{1} Recall:{2} IoU:{3} '.format(epoch, precision, recall, IoU))print(f"第{epoch}次迭代的验证精确度为{precision},验证召回率为{recall},验证交并比为{IoU}")return precision, recall, IoU
最后,我们可以看到将 dataloader_val验证集数据加载器 传入 函数validate_one_epoch 中,通过遍历 dataloader 中的数据,可以通过 自定义类BipedDataset 返回的包含三个元素的字典来获取图像数据、对应的标签数据以及图像文件名,如下图所示
images = sample_batched['images'].to(device)labels = sample_batched['labels'].to(device)file_names = sample_batched['file_names']
综上所述, 就是关于如何基于PyTorch深度学习框架自定义数据集来获取数据的详细步骤了,如果你觉得有用,麻烦点赞关注一下哈,谢谢!