数据集结构
话不多说,直接上核心代码
myDataset.py(注意:下面 train.py 中使用的是 `from dataset import MyDataset`,请保证文件名与导入语句中的模块名一致)
from collections import Counter
from torch.utils.data import Dataset
import os
class MyDataset(Dataset):
    """Dataset of images paired with per-image label text files.

    Images are listed from ``<image_dir>/<name>`` and label files from
    ``<label_dir>/<name>``. Sample ``i`` is the i-th image and the i-th label
    file, both taken in sorted file-name order, so the two directories are
    expected to contain files with matching stems.

    Args:
        image_dir: Root directory of the images.
        label_dir: Root directory of the label files.
        name: Name of the dataset split; used as the sub-directory name
            under both roots.
        transform: Optional callable applied to every loaded image.
    """

    def __init__(self, image_dir: str, label_dir: str, name: str, transform=None):
        super().__init__()
        self.img_dir = os.path.join(image_dir, name)
        self.label_dir = os.path.join(label_dir, name)
        self.name = name
        # os.listdir() returns entries in arbitrary order. Sort both listings
        # so that index i refers to the same sample in both directories.
        self.image_path = sorted(os.listdir(self.img_dir))
        self.label_path = sorted(os.listdir(self.label_dir))
        self.transform = transform

    def __getitem__(self, index: int) -> tuple:
        """Load one sample.

        Args:
            index: Position of the sample in the sorted file listings.

        Returns:
            ``(image, label)`` where ``image`` is the opened image (after
            ``transform`` if one was given) and ``label`` is the class id
            returned by :meth:`parseTxt`.
        """
        # Pillow is only needed when an image is actually opened, so it is
        # imported here; listing files and parsing labels work without it.
        from PIL import Image

        image_file = os.path.join(self.img_dir, self.image_path[index])
        # NOTE: the image keeps its native mode. If the data set mixes
        # grayscale and colour images, add `image = image.convert('RGB')`
        # after opening so every sample has the same number of channels.
        image = Image.open(image_file)

        label_file = os.path.join(self.label_dir, self.label_path[index])
        label = self.parseTxt(label_file)

        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def parseTxt(self, label: str) -> int:
        """Parse a label file into a single class id.

        Each non-blank line is expected to start with an integer class id.
        When the file holds several lines, the most frequent class id is
        returned (the first one encountered wins a tie).

        Args:
            label: Path of the label text file.

        Returns:
            The most common class id in the file.

        Raises:
            IndexError: If the file contains no non-blank lines.
            ValueError: If the first field of a line is not an integer.
        """
        with open(label, 'r') as f:
            # Blank lines (e.g. a trailing empty line) split into an empty
            # list; skip them instead of failing on `[0]`.
            class_ids = [int(fields[0]) for fields in map(str.split, f) if fields]
        return Counter(class_ids).most_common(1)[0][0]

    def __len__(self) -> int:
        """Return the number of image files in the split directory."""
        return len(self.image_path)
demo
train.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
# 导入加载数据集的类
from dataset import MyDataset
import os

# Data set root: <current working directory>/courseHomework/datasets
root = os.path.join(os.getcwd(), 'courseHomework', 'datasets')

# Pre-processing applied to every image before it is batched.
transform = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize((0.5), (0.5,))
    ]
)

train_dataset = MyDataset(root + '/images', root + '/labels', 'train', transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False)

# Sanity check: print the shape of the first image of the first batch,
# display it, then stop.
for imgs, labels in train_loader:
    first_image = imgs[0]
    print(first_image.shape)
    transforms.ToPILImage()(first_image).show()
    break
如果大家的数据集目录结构和我的不一样,可以根据自己的情况自由调整代码。