PyTorch是使用GPU和CPU优化的深度学习张量库——torchvision

torchvision

datasets

torchvision.datasets 包含了许多标准数据集的加载器。例如,CIFAR10ImageFolder 是其中两个非常常用的类。

CIFAR10

CIFAR10 数据集是一个广泛使用的数据集,包含10类彩色图像,每类有6000张图像(5000张训练集,1000张测试集)。下面是如何加载 CIFAR10 的示例:

import torch
from torchvision import datasets, transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载训练集
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)# 加载测试集
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)# 输出类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

ImageFolder

ImageFolder 用于加载按照类别分文件夹存储的图像数据集。

import os
from torchvision import datasets, transformsdata_dir = './path/to/dataset'
transform = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor()])image_datasets = datasets.ImageFolder(data_dir, transform=transform)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=4, shuffle=True, num_workers=2)

models

torchvision.models 提供了一系列预训练模型,如 ResNet、VGG、InceptionV3 等。

ResNet模型:

SetsNet并不是torchvision中的一个组件,而是指一类处理集合数据的神经网络。SetsNet和其他类似的网络(如DeepSets)旨在处理无序的集合输入,这些输入可以是点云、图像集合、特征向量集合等。SetsNet的设计原则是输入集合的顺序不会影响输出,即网络应该对输入的排列不变。

import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()# 预处理图像数据
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载图像
img_path = './path/to/image.jpg'
img = Image.open(img_path)
img_tensor = preprocess(img)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)# 预测
out = model(batch_img_tensor)

VGG模型:

VGG网络是一种经典的卷积神经网络架构,广泛应用于图像分类。下面是如何加载预训练的VGG模型并在一张图像上进行预测的示例:

import torch
from torchvision import models, transforms
from PIL import Image# 加载预训练的VGG16模型
vgg16 = models.vgg16(pretrained=True)
vgg16.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载图像
img_path = './path/to/image.jpg'
img_pil = Image.open(img_path)
img_tensor = preprocess(img_pil)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)# 预测
out = vgg16(batch_img_tensor)
_, pred = torch.max(out, 1)
print("Predicted class:", pred.item())

Inception模型:

InceptionV3是一种更复杂的卷积神经网络架构,设计用于处理高分辨率图像。以下是如何加载预训练的InceptionV3模型并进行预测:

import torch
from torchvision import models, transforms
from PIL import Image# 加载预训练的InceptionV3模型
inceptionv3 = models.inception_v3(pretrained=True)
inceptionv3.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(299),transforms.CenterCrop(299),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载图像
img_path = './path/to/image.jpg'
img_pil = Image.open(img_path)
img_tensor = preprocess(img_pil)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)# 预测
out = inceptionv3(batch_img_tensor)
_, pred = torch.max(out, 1)
print("Predicted class:", pred.item())

utils

make_grid 网格排列

是一个用于在PyTorch中将多个图像张量组合成一个图像网格的函数。这对于可视化数据集、模型输出或者训练过程中的变化非常有用。make_grid接受一系列图像张量,并返回一个单一的张量,该张量包含了所有输入图像按网格排列的结果

import torchvision.utils as vutils# 假设有数据加载器 dataloaders
dataiter = iter(dataloaders)
images, labels = dataiter.next()# 使用 make_grid 创建图像网格
img_grid = vutils.make_grid(images)# 显示图像网格
imshow(img_grid.numpy().transpose((1, 2, 0)))

save_image 保存图像

save_image函数可以用来保存一个张量为图像文件。下面是一个如何保存图像的例子:

import torch
from torchvision.utils import save_image
from PIL import Image# 假设我们有一个图像张量
img_tensor = torch.randn(3, 224, 224)# 保存图像
save_image(img_tensor, 'saved_image.jpg')# 也可以从PIL Image转换为张量并保存
img_pil = Image.new('RGB', (224, 224), color='white')
img_tensor = transforms.ToTensor()(img_pil)
save_image(img_tensor, 'saved_image_from_pil.jpg')

请确保替换上述代码中的./path/to/image.jpg为实际的图像路径,并确保在运行代码之前有正确的权限访问指定的路径。此外,如果还没有安装torchvisionPillow,可能需要先安装:

pip install torchvision pillow

transforms

是PyTorch中一个重要的模块,用于进行图像预处理和数据增强。它位于torchvision.transforms模块中,主要用于处理PIL图像和Tensor图像。transforms可以帮助你在训练神经网络时对数据进行各种变换,例如随机裁剪、大小调整、正则化等,以增加数据的多样性和模型的鲁棒性。

常见的transforms包括:

  1. 数据类型转换

    • ToTensor(): 将PIL图像或NumPy数组转换为PyTorch的Tensor格式。
  2. 几何变换

    Resize(size): 调整图像大小。                                                              CenterCrop(size): 中心裁剪图像。                                                              RandomCrop(size): 随机裁剪图像。                                          RandomHorizontalFlip(p=0.5): 随机水平翻转图像。
  3. 色彩变换

    ColorJitter(brightness, contrast, saturation, hue): 随机调整图像的亮度、对比度、饱和度和色调。
  4. 正则化

    Normalize(mean, std): 标准化图像像素值。

使用transforms

通常需要将它们组合成一个transforms.Compose对象,以便按顺序应用到图像数据上。这样可以灵活地定义数据增强的流程,适应不同的任务需求和数据特征。

当使用transforms进行图像预处理数据增强时,通常需要按照以下步骤进行操作:

1.导入必要的库:

 from torchvision import transformsfrom PIL import Image


2.定义transforms操作:可以根据需求选择合适的transforms进行组合。

 transform = transforms.Compose([transforms.Resize((256, 256)),     # 调整图像大小为256x256transforms.RandomCrop(224),        # 随机裁剪图像为224x224transforms.RandomHorizontalFlip(), # 随机水平翻转图像transforms.ToTensor(),             # 将图像转换为Tensor,并归一化至[0, 1]transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])# 标准化图像像素值])


3.加载图像并应用transforms: 

 # 假设有一张名为image.jpg的图像img = Image.open('image.jpg')# 应用transformsimg_transformed = transform(img)


4.查看处理后的图像:处理后的图像会转换为Tensor,并进行了resize、crop、翻转等操作。

  print(img_transformed.size())  # 输出处理后的图像大小

在上面的例子中,transforms.Compose用于将多个transforms组合起来,依次应用到图像上。这种方式能够让你根据任务需求定义灵活的图像处理流程,例如在训练神经网络时进行数据增强,提升模型的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/46157.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++ 类和对象(上)

再C中,我们使用类定义自己的数据类型。通过定义新的类型来反映待解决的问题中的各种概念,可以使我们更容易编写,调试和修改程序。 类定义格式 首先类的定义格式和结构体差不多,而结构体的那一套语法也可以在C中使用。class是定义…

RC-u3 跑团机器人

这道题目要求我们模拟一个跑团机器人,解析玩家输入的包含骰子掷出和加减法运算的指令,计算出不同种类的骰子分别需要掷出几个,并根据输入指令得到可能的最小结果和最大结果。 题目分析 我们需要解析输入的表达式,处理其中的加法…

oracle数据库的plsql免安装版安装

这个是连接oracle数据库的,注意安装不能有中文路径。以下只是示例。 1、打开D:\ruanjian\plsql\plsql\plsql,发送plsqldev.exe快捷方式到桌面。 2、新弹出的页面填写cancel,什么也不写。 3、将instanceclient解压,并复制文件路径。 修改tool…

[Python学习篇] Python搭建静态web服务器

Python内置的web静态服务器 Python内置的http.server模块可以快速启动一个简单的HTTP服务器。 在Python 3中,打开命令行或终端,在你想要作为静态服务器根目录的文件夹下,运行以下命令: python -m http.server 8000 这将会在…

纯vue+js实现数字0到增加到指定数字动画效果功能

关于数字增加动画效果网上基本上都是借助第三方插件实现的,就会导致有的项目安装插件总会出问题,所有最好使用原生vue+js实现,比较稳妥 纯vue+js实现数字0到增加到指定数字动画效果功能 vue+js 实现数字增加动画功能 效果图 其中,关于数字变化的间隔时间,延时效果都可…

数据结构与算法 —— DFS的定义与原理

DFS(Distributed File System,分布式文件系统)是一种允许网络连接的多个计算机之间共享信息的系统架构。与传统的文件系统(如NTFS、HFS等)不同,DFS分布在多个文件服务器或多个位置,通过计算机网…

澳门建筑插画:成都亚恒丰创教育科技有限公司

澳门建筑插画:绘就东方之珠的斑斓画卷 在浩瀚的中华大地上,澳门以其独特的地理位置和丰富的历史文化,如同一颗璀璨的明珠镶嵌在南国海疆。这座城市,不仅是东西方文化交融的典范,更是建筑艺术的宝库。当画笔轻触纸面&a…

一个spring boot项目的启动过程分析

1、web.xml 定义入口类 <context-param><param-name>contextConfigLocation</param-name><param-value>com.baosight.ApplicationBoot</param-value> </context-param> 2、主入口类: ApplicationBoot,SpringBoot项目的mian函数 SpringBo…

时间序列学习篇

今天看了一些时间序列算法相关的文档和帖子。很惭愧&#xff0c;也是搞了很长时间预测算法的人了&#xff0c;但是都没能详细学习一下时间序列的理论。 首先&#xff0c;要预测一个时序问题&#xff0c;可以从什么路径解决呢&#xff1f;一种是认为过去序列状态影响将来的状态…

视频调整帧率、分辨率+音画同步

# python data_utils/pre_video/multi_fps_crop_sync.pyimport cv2 import os from tqdm import tqdm import subprocess# 加载人脸检测模型 face_cascade cv2.CascadeClassifier(cv2.data.haarcascades haarcascade_frontalface_default.xml)def contains_face(frame):gray …

淘宝/天猫店铺商品搜索利器:taobao.item_search_shop API返回值详解

taobao.item_search_shop 这个API名称听起来像是针对淘宝或天猫平台的一个商品搜索接口&#xff0c;但实际上&#xff0c;淘宝和天猫的官方API体系中并没有直接命名为taobao.item_search_shop的公开API。不过&#xff0c;为了解答关于类似功能的API返回值详解&#xff0c;我们可…

(三)Redis持久化,你真的懂了吗?万字分析AOF和RDB的优劣 AOF的刷盘、重写策略 什么叫混合重写 MP-AOF方案是什么

引言 —— Redis基础概念 Redis概念&#xff1a;Redis (REmote DIctionary Server) 是用 C 语言开发的一个开源的高性能键值对&#xff08;key-value&#xff09;数据库。 为什么会出现Redis呢&#xff1f;它的到来是为了解决什么样的问题&#xff1f; Redis 是一个NOSQL类型…

27 设备流转使用心得 三

前两部分参考心得 25 26 分布式文件传输 1 源端 1 获取分布式文件路径 读取文件 写入分布式文件 2 对端 1 通过应用沙箱获取分布式文件路径 读取文件路径 与状态数据绑定 2 绑定之后UI渲染 Index Row({space:8}){//用户当前选中的所有图片ForEach(this.photos, (p:str…

操作系统真象还原:创建文件系统

14.2 创建文件系统 14.2.1 创建超级块、i结点、目录项 超级块 /** Author: Adward-DYX 1654783946qq.com* Date: 2024-05-07 10:18:02* LastEditors: Adward-DYX 1654783946qq.com* LastEditTime: 2024-05-07 11:24:50* FilePath: /OS/chapter14/14.2/fs/super_block.h* Des…

构造、析构、拷贝(Semantics of Construction,Destruction,and Copy)

1、继承体系下的对象构造 当我定义一个object如下 T object;如果T有一个默认构造函数&#xff0c;它会被调用。 比较不明显的是构造函数内部有大量的隐藏代码&#xff0c;因为编译器会扩充构造函数&#xff0c;一般而言编译器所做的扩充如下&#xff1a; 记录在成员初始化列…

WPF学习(6) -- WPF命令和通知

一 、WPF命令 1.ICommand代码 创建一个文件夹和文件 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Input;namespace 学习.Command {public class MyCommand : ICommand{Acti…

CCSI: 数据无关类别增量学习的持续类特定印象| 文献速递-基于深度学习的多模态数据分析与生存分析

Title 题目 CCSI: Continual Class-Specific Impression for data-free class incremental learning CCSI: 数据无关类别增量学习的持续类特定印象 01 文献速递介绍 当前用于医学影像分类任务的深度学习模型表现出令人鼓舞的性能。这些模型大多数需要在训练之前收集所有的…

设计模式使用场景实现示例及优缺点(行为型模式——迭代子模式)

迭代子模式&#xff08;Iterator Pattern&#xff09; 迭代子模式&#xff08;Iterator Pattern&#xff09;是一种常用的设计模式&#xff0c;属于行为型模式。它提供一种方法顺序访问一个聚合对象中的各个元素&#xff0c;而又无需暴露该对象的内部表示。 核心组件 Iterat…

中间件——Kafka

两个系统各自都有各自要去做的事&#xff0c;所以只能将消息放到一个中间平台&#xff08;中间件&#xff09; Kafka 分布式流媒体平台 程序发消息&#xff0c;程序接收消息 Producer&#xff1a;Producer即生产者&#xff0c;消息的产生者&#xff0c;是消息的入口。 Brok…

[Vulnhub] Sedna BuilderEngine-CMS+Kernel权限提升

信息收集 IP AddressOpening Ports192.168.8.104TCP:22, 53, 80, 110, 111, 139, 143, 445, 993, 995, 8080, 55679 $ nmap -p- 192.168.8.104 --min-rate 1000 -sC -sV PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 6.6.1p1 Ubuntu 2ubuntu2 …