一、dataset和dataloader要点说明
在我们搭建自己的网络时,往往需要定义自己的dataset
和dataloader
,将图像和标签数据送入模型。
(1)在我们定义dataset
时,需要继承torch.utils.data.dataset
,再重写三个方法:
init
方法,主要用来定义数据的预处理getitem
方法,数据增强;返回数据的item和labellen
方法,返回数据数量
(2)在我们定义dataloader
时,需要考虑下面几个参数:
dataset
:使用哪个数据集batch_size
:将数据集拆成一组多少个进行训练shuffle
:是否需要打乱数据num_workers
:几个mini_batch并行计算,一般<=你的电脑cpu数目collect_fn
:数据打包方式
(3)通过迭代的方式,按批次,获取dataloader
中的数据
(4)关系图
二、核心代码框架
import os
import cv2
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader# -------------------------------------------------------------#
# 自定义dataset需要继承torch.utils.data.dataset,
# 再重写def __init__,def __len__,def __getitem__三个方法
# -------------------------------------------------------------#
class YourDataset(Dataset):def __init__(self, root_path):super(YourDataset, self).__init__()self.root_path = root_path#-------------------------------------------------------------------------## 获取样本名,以jpg原始图片为参考,修改后缀名为json,png,获取json,png标签文件路径#-------------------------------------------------------------------------#self.sample_names = []jpg_path = os.path.join(os.path.join(self.root_path, "images"),)for file in os.listdir(jpg_path):if file.endswith(".jpg"):self.sample_names.append(os.path.splitext(file)[0]) # 去掉.jsondef __len__(self):#----------------------## 返回数据数量#----------------------#return len(self.sample_names)def __getitem__(self, index):name = self.sample_names[index]# ----------------------## 读取图像# ----------------------#img_path = os.path.join(os.path.join(self.root_path, "images"), name + '.jpg')image = cv2.imread(img_path)# ----------------------## 读取标签# ----------------------#label_path = os.path.join(os.path.join(self.root_path, "jsons"), name + '.json')with open(label_path) as label_file:points = self.get_data_from_json(label_file)#----------------------## 图像数据增强#----------------------#image = self.random_color(image)#----------------------## 标签归一化#----------------------#labels = self.convert_labels(points)return image, labels# -------------------------------------#
# 图片和标签格式转换后,按批次(batch)打包
# -------------------------------------#
def dataloader_collate_fn(batch):images = []labels = []for img, label in batch:images.append(transforms.ToTensor()(img))labels.append(label)return images, labelsif __name__ == '__main__':# -------------------------------------## 构建dataset# -------------------------------------#path = './data/train'train_dataset = YourDataset(path)# -------------------------------------## 构建Dataloader# -------------------------------------#dataset = train_datasetbatch_size = 32shuffle = Truenum_workers = 0collate_fn = dataloader_collate_fnsampler = Nonetrain_gen = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,drop_last=True, collate_fn=collate_fn, sampler=sampler)# ---------------------------------------------## 通过迭代的方式,一批一批读取训练集中的图像和标签数据# ---------------------------------------------#for iter, batch in enumerate(train_gen):images, labels = batch