学习 PyTorch（2）
- 2. Dataset 实战
  - 代码
  - 数据集
2. Dataset 实战
参考：B 站 UP 主“小土堆”的 PyTorch 入门教程视频。
代码
from torch.utils.data import Dataset
from PIL import Image
#import cv2
import os


class MyData(Dataset):
    """Image dataset in which every file under ``root_dir/label_dir`` has one label.

    The label returned for each sample is the sub-directory name
    (``label_dir``) itself, e.g. ``"ants"`` or ``"bees"``.

    Args:
        root_dir: Parent directory that contains the class sub-directory.
        label_dir: Name of the class sub-directory; also used as the label.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir() returns entries in arbitrary, platform-dependent order;
        # sort so that dataset[idx] maps to the same file on every run/machine.
        self.img_list = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file in ``img_list``.

        Returns:
            The opened PIL image and the label string (``label_dir``).

        Raises:
            IndexError: If ``idx`` is out of range for ``img_list``.
        """
        img_name = self.img_list[idx]
        img_path = os.path.join(self.path, img_name)
        img = Image.open(img_path)
        # Image.open() is lazy and keeps the file handle open until pixel data
        # is read; load() reads it now and (for single-frame images) closes the
        # file, so iterating the dataset does not pile up open descriptors.
        img.load()
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in ``root_dir/label_dir``."""
        return len(self.img_list)


if __name__ == '__main__':
    root_dir = "hymenoptera_data/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)
    # Adding two Datasets yields a ConcatDataset: indices [0, len(ants_dataset))
    # come from ants_dataset, the remaining ones from bees_dataset.
    train_dataset = ants_dataset + bees_dataset
    img, label = train_dataset[1]
    img.show()
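数据集
本文使用蚂蚁/蜜蜂二分类数据集 hymenoptera_data。下载后解压到运行脚本时的当前工作目录，这样代码中的 root_dir = "hymenoptera_data/train" 才能找到 ants/ 和 bees/ 两个子目录：
https://download.pytorch.org/tutorial/hymenoptera_data.zip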