torchvision
datasets
torchvision.datasets
包含了许多标准数据集的加载器。例如,CIFAR10
和ImageFolder
是其中两个非常常用的类。
CIFAR10
CIFAR10 数据集是一个广泛使用的数据集,包含10类彩色图像,每类有6000张图像(5000张训练集,1000张测试集)。下面是如何加载 CIFAR10 的示例:
import torch
from torchvision import datasets, transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载训练集
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)# 加载测试集
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)# 输出类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
ImageFolder
ImageFolder
用于加载按照类别分文件夹存储的图像数据集。
import os
from torchvision import datasets, transformsdata_dir = './path/to/dataset'
transform = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor()])image_datasets = datasets.ImageFolder(data_dir, transform=transform)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=4, shuffle=True, num_workers=2)
models
torchvision.models
提供了一系列预训练模型,如 ResNet、VGG、InceptionV3 等。
ResNet模型:
SetsNet并不是torchvision
中的一个组件,而是指一类处理集合数据的神经网络。SetsNet和其他类似的网络(如DeepSets)旨在处理无序的集合输入,这些输入可以是点云、图像集合、特征向量集合等。SetsNet的设计原则是输入集合的顺序不会影响输出,即网络应该对输入的排列不变。
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()# 预处理图像数据
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载图像
img_path = './path/to/image.jpg'
img = Image.open(img_path)
img_tensor = preprocess(img)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)# 预测
out = model(batch_img_tensor)
VGG模型:
VGG
网络是一种经典的卷积神经网络架构,广泛应用于图像分类。下面是如何加载预训练的VGG
模型并在一张图像上进行预测的示例:
import torch
from torchvision import models, transforms
from PIL import Image# 加载预训练的VGG16模型
vgg16 = models.vgg16(pretrained=True)
vgg16.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载图像
img_path = './path/to/image.jpg'
img_pil = Image.open(img_path)
img_tensor = preprocess(img_pil)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)# 预测
out = vgg16(batch_img_tensor)
_, pred = torch.max(out, 1)
print("Predicted class:", pred.item())
Inception模型:
InceptionV3
是一种更复杂的卷积神经网络架构,设计用于处理高分辨率图像。以下是如何加载预训练的InceptionV3
模型并进行预测:
import torch
from torchvision import models, transforms
from PIL import Image# 加载预训练的InceptionV3模型
inceptionv3 = models.inception_v3(pretrained=True)
inceptionv3.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(299),transforms.CenterCrop(299),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载图像
img_path = './path/to/image.jpg'
img_pil = Image.open(img_path)
img_tensor = preprocess(img_pil)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)# 预测
out = inceptionv3(batch_img_tensor)
_, pred = torch.max(out, 1)
print("Predicted class:", pred.item())
utils
make_grid 网格排列
是一个用于在PyTorch中将多个图像张量组合成一个图像网格的函数。这对于可视化数据集、模型输出或者训练过程中的变化非常有用。make_grid
接受一系列图像张量,并返回一个单一的张量,该张量包含了所有输入图像按网格排列的结果
import torchvision.utils as vutils# 假设有数据加载器 dataloaders
dataiter = iter(dataloaders)
images, labels = dataiter.next()# 使用 make_grid 创建图像网格
img_grid = vutils.make_grid(images)# 显示图像网格
imshow(img_grid.numpy().transpose((1, 2, 0)))
save_image 保存图像
save_image
函数可以用来保存一个张量为图像文件。下面是一个如何保存图像的例子:
import torch
from torchvision.utils import save_image
from PIL import Image# 假设我们有一个图像张量
img_tensor = torch.randn(3, 224, 224)# 保存图像
save_image(img_tensor, 'saved_image.jpg')# 也可以从PIL Image转换为张量并保存
img_pil = Image.new('RGB', (224, 224), color='white')
img_tensor = transforms.ToTensor()(img_pil)
save_image(img_tensor, 'saved_image_from_pil.jpg')
请确保替换上述代码中的./path/to/image.jpg
为实际的图像路径,并确保在运行代码之前有正确的权限访问指定的路径。此外,如果还没有安装torchvision
和Pillow
,可能需要先安装:
pip install torchvision pillow
transforms
是PyTorch中一个重要的模块,用于进行图像预处理和数据增强。它位于torchvision.transforms模块中,主要用于处理PIL图像和Tensor图像。transforms可以帮助你在训练神经网络时对数据进行各种变换,例如随机裁剪、大小调整、正则化等,以增加数据的多样性和模型的鲁棒性。
常见的transforms包括:
数据类型转换:
ToTensor()
: 将PIL图像或NumPy数组转换为PyTorch的Tensor格式。几何变换:
Resize(size)
: 调整图像大小。CenterCrop(size)
: 中心裁剪图像。RandomCrop(size)
: 随机裁剪图像。RandomHorizontalFlip(p=0.5)
: 随机水平翻转图像。色彩变换:
ColorJitter(brightness, contrast, saturation, hue)
: 随机调整图像的亮度、对比度、饱和度和色调。正则化:
Normalize(mean, std)
: 标准化图像像素值。
使用transforms
通常需要将它们组合成一个transforms.Compose对象,以便按顺序应用到图像数据上。这样可以灵活地定义数据增强的流程,适应不同的任务需求和数据特征。
当使用transforms进行图像预处理和数据增强时,通常需要按照以下步骤进行操作:
1.导入必要的库:
from torchvision import transformsfrom PIL import Image
2.定义transforms操作:可以根据需求选择合适的transforms进行组合。
transform = transforms.Compose([transforms.Resize((256, 256)), # 调整图像大小为256x256transforms.RandomCrop(224), # 随机裁剪图像为224x224transforms.RandomHorizontalFlip(), # 随机水平翻转图像transforms.ToTensor(), # 将图像转换为Tensor,并归一化至[0, 1]transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])# 标准化图像像素值])
3.加载图像并应用transforms:
# 假设有一张名为image.jpg的图像img = Image.open('image.jpg')# 应用transformsimg_transformed = transform(img)
4.查看处理后的图像:处理后的图像会转换为Tensor,并进行了resize、crop、翻转等操作。
print(img_transformed.size()) # 输出处理后的图像大小
在上面的例子中,transforms.Compose用于将多个transforms组合起来,依次应用到图像上。这种方式能够让你根据任务需求定义灵活的图像处理流程,例如在训练神经网络时进行数据增强,提升模型的泛化能力。