# Load the training split only to obtain the class-name -> integer-index mapping
# (ImageFolder.class_to_idx); it is reused below to label the validation images.
ori_train = torchvision.datasets.ImageFolder(root=args.datadir + '/tiny-imagenet-200/train/', transform=transform)
class_idx = ori_train.class_to_idx
# val_annotations.txt 中存储着每个图片对应的类别 (val_annotations.txt stores the class of every validation image)
# 获取验证集的标签 (Get the validation-set labels)
# Read val_annotations.txt: each tab-separated row begins with
# "<image filename>\t<class folder name>\t...".
# Use the same root as the ImageFolder datasets instead of a hard-coded "./data".
test_data_dir = args.datadir + "/tiny-imagenet-200/val"
with open(test_data_dir + "/val_annotations.txt", 'r', encoding="utf-8") as file:
    # Skip blank lines (e.g. a trailing newline) so content[1] always exists.
    rows = [line.strip().split("\t") for line in file if line.strip()]
# torchvision's ImageFolder orders the images of a folder by a plain string
# sort of their filenames (val_0, val_1, val_10, val_100, ...). The annotation
# file is presumably in numeric order (val_0, val_1, val_2, ...), so sort the
# rows the same way; then test_target[i] labels the i-th image of the
# ImageFolder built on val/.
rows.sort(key=lambda content: content[0])
test_target = [class_idx[content[1]] for content in rows]
# 读取图片信息 (Read the image data)
# Build an ImageFolder over the validation directory. Only the images are taken
# from this dataset; the labels it assigns (one class per sub-directory of val/)
# are discarded and replaced with test_target by Imagenet_dataset.
ori_test_o = torchvision.datasets.ImageFolder(
    root=args.datadir + '/tiny-imagenet-200/val/',
    transform=transform,
)
# 自定义Dataset (Custom Dataset)
class Imagenet_dataset(torch.utils.data.Dataset):
    """Wrap a dataset and replace its labels with an explicit list of targets.

    Item ``idx`` is ``(image, targets[idx])``, where ``image`` is the first
    element of ``dataset[idx]``; the wrapped dataset's own label is ignored.

    Args:
        dataset: indexable dataset whose items are ``(image, label)`` pairs.
        targets: sequence of labels, aligned index-by-index with ``dataset``.

    Raises:
        ValueError: if ``dataset`` and ``targets`` have different lengths,
            which would mean the labels cannot be aligned with the images.
    """

    def __init__(self, dataset, targets):
        if len(dataset) != len(targets):
            raise ValueError(
                f"dataset has {len(dataset)} samples but {len(targets)} targets were given"
            )
        self.dataset = dataset
        self.targets = targets

    def __getitem__(self, idx):
        # Keep only the (transformed) image; take the label from self.targets.
        img = self.dataset[idx][0]
        label = self.targets[idx]
        return (img, label)

    def __len__(self):
        return len(self.dataset)


# Pair each validation image with its label. The class must be defined before
# this statement runs, otherwise the name lookup fails with NameError.
ori_test = Imagenet_dataset(ori_test_o, test_target)