# Saving this code here so it doesn't get lost. (代码怕忘记,现在贴上来,以防丢失)
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms


def get_mean_and_std(data_path, in_chans=3):
    """Compute the per-channel pixel mean and standard deviation of an image dataset.

    Every image under ``data_path`` is loaded through ``ImageFolder``, which
    expects one sub-directory per class. Each image is converted with
    ``transforms.ToTensor()``. The statistics are accumulated over *all pixels
    of all images*, so the result is the dataset-wide mean/std that
    ``transforms.Normalize`` expects. Simply averaging per-image standard
    deviations would leave out the variation between images and
    underestimate the dataset std.

    Args:
        data_path: Root directory of the ImageFolder-style dataset.
        in_chans: Number of leading channels to compute statistics for.

    Returns:
        A ``(mean, std)`` tuple. Each element is a list of ``in_chans``
        Python floats. ``std`` is the population standard deviation over all
        pixels of each channel.

    Raises:
        ValueError: If no pixels were accumulated from the dataset.
    """
    dataset = ImageFolder(root=data_path, transform=transforms.ToTensor())
    # batch_size=1: images may differ in size, and the default collate
    # function cannot stack differently sized tensors into one batch.
    # pin_memory is not used because the tensors are never copied to a GPU.
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Accumulate in float64 so that summing a very large number of pixel
    # values (and their squares) does not lose precision.
    channel_sum = torch.zeros(in_chans, dtype=torch.float64)
    channel_sq_sum = torch.zeros(in_chans, dtype=torch.float64)
    pixel_count = 0
    for images, _ in loader:
        images = images[:, :in_chans].to(torch.float64)
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sq_sum += images.square().sum(dim=(0, 2, 3))
        # Pixels per channel in this batch: batch * height * width.
        pixel_count += images.shape[0] * images.shape[2] * images.shape[3]

    if pixel_count == 0:
        raise ValueError(f"No image pixels found under {data_path!r}")

    mean_t = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    std_t = (channel_sq_sum / pixel_count - mean_t.square()).clamp(min=0).sqrt()

    mean = mean_t.tolist()
    std = std_t.tolist()
    print(f"Mean: {mean}")
    print(f"Standard Deviation: {std}")
    return mean, std


data_path = "G:\\04_deep-learning-for-image-processing-master\\pytorch_classification\\swin_transformer\\flower_photos"
if __name__ == "__main__":
    # Run the (slow, filesystem-dependent) statistics pass only when this file
    # is executed as a script, not when it is imported.
    mean, std = get_mean_and_std(data_path)