数据集
流程图
导包 → 设置 transforms(图像预处理)→ 创建 datasets.ImageFolder → 创建 torch.utils.data.DataLoader()
import os
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Ignore the noisy (red) warning messages
warnings.filterwarnings("ignore")  # no category is given, so the filter is not limited to one kind

# Select the compute hardware:
# use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

# Training-set preprocessing: resized crop, image augmentation, conversion to Tensor, normalization
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Same per-channel mean/std values as the test pipeline below.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Test-set preprocessing (RCTN): resize, center crop, conversion to Tensor, normalization
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Dataset folder path
dataset_dir = "C:/Users/Administrator/Desktop/ResNet/data/xiaobo"  # hard-coded, machine-specific location

# The training and test images are read from two sub-folders of the dataset root.
train_path = os.path.join(dataset_dir, 'train_img')
test_path = os.path.join(dataset_dir, 'test_img')
print('训练集路径', train_path)
print('测试集路径', test_path)
# Load the training set
train_dataset = datasets.ImageFolder(train_path, train_transform)
# Load the test set
test_dataset = datasets.ImageFolder(test_path, test_transform)

BATCH_SIZE = 32

# Data loader for the training set (shuffling enabled)
# NOTE(review): num_workers=4 requests worker processes; given the Windows-style
# dataset path, running this as a plain script presumably needs an
# `if __name__ == '__main__':` guard around the code that iterates the loaders —
# confirm how this file is executed.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
# Data loader for the test set (no shuffling)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

# Summary of both splits: image count, number of classes, class names.
print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

# Class names
class_names = train_dataset.classes
n_class = len(class_names)

# Mapping: class name -> index.
# This bare expression is evaluated but neither stored nor printed here.
train_dataset.class_to_idx

# Mapping: index -> class name (class_to_idx with keys and values swapped).
idx_to_labels = {
    index: label
    for label, index in train_dataset.class_to_idx.items()
}
print(idx_to_labels)

# Save both mappings as local .npy files.
# NOTE(review): both objects are dicts, so np.save presumably stores them as
# pickled object arrays and readers would need
# np.load(path, allow_pickle=True).item() — confirm against the loading code.
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)