数据使用方角度
首先从pytorch出发,torchvision.transforms()要求传入的图像是PIL.Image格式(通道要求是RGB格式的),另外模型处理输入要转换为[1,channel,H,W];
所以最终导入torchvision.transforms()的图像格式需要转成PIL.Image,且需要在转换后增加batch维度([channel,H,W]变成[1,channel,H,W])
# 制作dataset
class MyDataset(Dataset):def __init__(self,data_path,label_path,transforms=None):self.data_path = data_pathself.label_path = label_pathself.transform = transformsself.images = sorted(os.listdir(data_path))#这里的label_path就是csv文件self.label= pd.read_csv(self.label_path)['label']def __len__(self):return len(self.images)def __getitem__(self,idx):img_path = os.path.join(self.data_path , self.images[idx]) # 将image_dir与图像列表中的每张图的名字连接成地址#这的标签是一个数字表示的类别,所以不用像图片一样进行操作。直接使用self.label来进行读取即可# image_data = cv2.imread(img_path)image_data = Image.open(img_path)if self.transform is not None :image_data = self.transform(image_data)label_data = self.label[idx]return image_data,label_data
# values from ImageNet, recommended by PyTorch
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]trans = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=transform_mean, std=transform_std),
])
#划分数据集,将原来的训练集划分为 训练集和验证集
dataset = MyDataset(data_path,label_path,trans)
train_dataset,valid_dataset = random_split(dataset, [0.8,0.2])#创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2)
valid_loader = DataLoader(valid_dataset,batch_size=4,shuffle=True,num_workers=2)
print(len(train_dataset))
print(len(valid_dataset))
上面这个代码 image_data = cv2.imread(img_path)
会报错,原因是
-
torchvision.transforms.Resize预期接收的输入类型是PIL图像,但是它收到了一个NumPy数组。这导致了类型错误。
具体来说,在你的数据集类中,你使用了OpenCV
(cv2)库读取图像数据。OpenCV读取的图像数据是NumPy数组,而不是PIL图像。因此,直接将这些NumPy数组传递给torchvision.transforms.Resize时,就会导致类型不匹配的错误。
总结:
具体来说这两种方式读取图片没有大的区别,但是建议用image_data = Image.open(img_path)
的方式来读。