# Multi-class classification with custom sampling ratios (多分类自定义采样比例)
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder# 假设你有一个自定义的数据集类
class CustomDataset(Dataset):
    """Wrap an ``ImageFolder`` and precompute inverse-frequency class weights.

    The weights can drive class-balanced sampling, e.g. by passing
    :meth:`sample_weights` to ``torch.utils.data.WeightedRandomSampler``.

    Args:
        data_dir: Root directory handed to ``ImageFolder``.
        transform: Optional transform forwarded to ``ImageFolder``.
    """

    def __init__(self, data_dir, transform=None):
        self.dataset = ImageFolder(data_dir, transform=transform)
        # One weight per class, indexed by the class index ImageFolder assigns.
        self.class_weights = self.calculate_class_weights()

    def calculate_class_weights(self):
        """Return a 1-D float tensor holding ``1 / count`` for every class.

        Change this method to implement a different weighting strategy.
        A class with no samples gets weight 0 rather than the ``inf`` that a
        plain ``1.0 / count`` would produce.
        """
        num_classes = len(self.dataset.classes)
        targets = torch.as_tensor(self.dataset.targets, dtype=torch.long)
        # A single O(n) pass replaces one list.count() scan per class.
        # Slicing keeps exactly one entry per class, matching range(num_classes).
        counts = torch.bincount(targets, minlength=num_classes)[:num_classes]
        weights = torch.zeros(num_classes, dtype=torch.get_default_dtype())
        present = counts > 0
        weights[present] = 1.0 / counts[present]
        return weights

    def sample_weights(self):
        """Return one weight per sample: the weight of that sample's class.

        ``WeightedRandomSampler`` expects a weight for every dataset element
        rather than one per class, so this expands ``class_weights`` by target.
        """
        targets = torch.as_tensor(self.dataset.targets, dtype=torch.long)
        return self.class_weights[targets]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


# Dataset directory
data_dir: str = "path/to/your/dataset"  # placeholder; point this at the real dataset root

# Define the image transforms
transform = transforms.Compose([transforms.Resize((224, 224)),tra