问题描述
`torchvision.datasets.ImageFolder` 假定:子文件夹名 = 该子文件夹中所有图像的标签。
但在 KDEF 数据集中,每个子文件夹下都包含所有类别的图像(类别标签编码在文件名里),因此不宜直接用 ImageFolder 按路径来构建 dataset。
我的实现:
# How to build the dataset?
import os

from PIL import Image
from torch.utils import data
from torchvision import transforms, utils

# KDEF expression code (characters 4-5 of each image file name) -> class index.
to_exp = {'SU': 0, 'AF': 1, 'DI': 2, 'HA': 3, 'SA': 4, 'AN': 5, 'NE': 6}

# Every image is resized to 224x224 and converted to a tensor
# (ToTensor: C x H x W, values scaled to [0, 1]).
mytransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


class PoseDataset(data.Dataset):
    """In-memory dataset of KDEF images whose labels come from the file names.

    ``ImageFolder`` is unsuitable here because each sub-folder mixes images of
    every class, so labels are decoded from each image's file name instead.
    All images are loaded and transformed eagerly in ``__init__``.

    Each item is a tuple ``(image, id_label1, id_label2)``:

    * ``image`` -- the RGB image after ``mytransform``;
    * ``id_label1`` -- expression class, ``to_exp[img_name[4:6]]``;
    * ``id_label2`` -- ``int(img_name[2:4])`` (presumably the KDEF subject
      number -- confirm against the KDEF naming scheme).

    Args:
        path: Directory whose sub-directories hold the images. A relative
            path is resolved against the current working directory.

    Raises:
        KeyError: if an image name has an expression code not in ``to_exp``.
        ValueError: if characters 2-3 of an image name are not digits.
    """

    def __init__(self, path):
        super().__init__()
        root_dir = os.path.join(os.getcwd(), path)
        self.data = []
        # Sort listings so the item order is deterministic across runs and
        # platforms (os.listdir order is arbitrary).
        for dir_name in sorted(os.listdir(root_dir)):
            sub_dir = os.path.join(root_dir, dir_name)
            if not os.path.isdir(sub_dir):
                # Skip stray files (e.g. a README) next to the sub-folders,
                # which would otherwise make os.listdir() fail below.
                continue
            for img_name in sorted(os.listdir(sub_dir)):
                # `with` closes the underlying file; convert() forces the
                # pixel data to be read before the handle is released.
                with Image.open(os.path.join(sub_dir, img_name)) as img:
                    t_img = mytransform(img.convert("RGB"))
                # Progress indicator: '\r' + end='' overwrite the same line.
                print('\r{},{},{}'.format(img_name, type(t_img), t_img.shape), end='')
                id_label1 = to_exp[img_name[4:6]]
                id_label2 = int(img_name[2:4])
                self.data.append((t_img, id_label1, id_label2))
        # Terminate the in-place progress line so later output starts fresh.
        print()

    def __getitem__(self, index):
        """Return the ``(image, id_label1, id_label2)`` tuple at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data)


# os.path.join builds the path with the platform's separator, instead of the
# Windows-only 'data\\KDEF\\AF'.
mydataset = PoseDataset(os.path.join('data', 'KDEF', 'AF'))
- 其中
print('\r{},{},{}'.format(img_name,type(t_img),t_img.shape),end='')  # '\r' moves the cursor back to the start of the line and end='' suppresses the newline, so each call overwrites the previous progress line
会在同一行上不断刷新输出(`\r` 让光标回到行首,`end=''` 不换行),用来显示加载进度;
- 用 `PIL.Image.open()` 打开图像文件,并用 `.convert("RGB")` 统一转为 3 通道 RGB 格式(避免灰度或带 alpha 通道的图像导致通道数不一致)。
之后就可以用 `torch.utils.data.DataLoader` 成批读取图像 tensor 和 label,用于训练了:
# Batches of 16 shuffled samples; drop_last=True discards a final incomplete
# batch so every batch has exactly 16 items.
train_loader = data.DataLoader(mydataset, batch_size=16, drop_last=True, shuffle=True)
for image, label1, label2 in train_loader:
    # One statement per line: the fused original
    # `...:print(a)print(b)print(c)` is a syntax error.
    print(image.shape)
    print(label1)
    print(label2)
- `torch.utils.data.DataLoader` 会自动把一个 batch 中的 int label 整理(collate)成 label tensor~🙂😊
用 `data.random_split` 把 dataset 划分为训练集和验证集:
# Example split: 80% train, the remainder validation. valid_size is computed
# as the remainder so the two lengths always sum to len(mydataset).
train_size = int(0.8 * len(mydataset))
valid_size = len(mydataset) - train_size
# `data` is the already-imported torch.utils.data; `torch` itself is not
# imported in this script, and `mydataset` is the dataset built above.
trainset, validset = data.random_split(mydataset, [train_size, valid_size])