之前写过加载数据集的一些小笔记,这里详细内容就不再叙述了
详细学习可以参考该博文二、PyTorch加载数据
一、分析
因为U-net网络架构是输入1通道,大小为(572,572)的灰度图,图片大小无所谓,我的思路是将三通道的图像使用OpenCV进行相关的处理,转换为单通道突图片,之后再送入网络模型中
二、准备数据集
数据集的采集和制作可以参考该篇博文
四、采集和制作数据集
三、完整加载数据集代码
test_dataset
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import randomclass Beyond_loader(Dataset):def __init__(self, data_path):# 初始化函数,读取所有data_path下的图片self.data_path = data_pathself.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))def augment(self, image, flipCode):# 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转flip = cv2.flip(image, flipCode)return flipdef __getitem__(self, index):# 根据index读取图片image_path = self.imgs_path[index]# 根据image_path生成label_pathlabel_path = image_path.replace('image', 'label')# 读取训练图片和标签图片image = cv2.imread(image_path)label = cv2.imread(label_path)# 将数据转为单通道的图片image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)image = image.reshape(1, image.shape[0], image.shape[1])label = label.reshape(1, label.shape[0], label.shape[1])# 处理标签,将像素值为255的改为1if label.max() > 1:label = label / 255# 随机进行数据增强,为2时不做处理flipCode = random.choice([-1, 0, 1, 2])if flipCode != 2:image = self.augment(image, flipCode)label = self.augment(label, flipCode)return image, labeldef __len__(self):# 返回训练集大小return len(self.imgs_path)if __name__ == "__main__":beyond_loader = Beyond_loader("./dataset/train")print("数据个数:", len(beyond_loader))train_loader = torch.utils.data.DataLoader(dataset=beyond_loader,batch_size=1,shuffle=True)for image, label in train_loader:print(image.shape)
一共有6张图像,batch_size设为1,故train_loader有6组
数据个数: 6
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])
torch.Size([1, 1, 320, 320])