PyTorch中torchvision库的详细介绍

torchvision 是 PyTorch 生态系统中的一个关键库,专门为计算机视觉任务设计和优化。它提供了以下几个核心功能:

  1. 数据集:内置了多种广泛使用的图像和视频数据集,如 MNIST、CIFAR10/100、Fashion-MNIST、ImageNet、COCO 等,并且它们以 torch.utils.data.Dataset 的形式实现,方便与 PyTorch 数据加载器(DataLoader)集成。

  2. 数据预处理工具:通过 torchvision.transforms 模块提供了丰富的数据增强和预处理操作,包括但不限于裁剪、旋转、翻转、归一化、调整大小、颜色转换等,这些操作对于训练稳健的深度学习模型至关重要。

  3. 深度学习模型架构:在 torchvision.models 中封装了大量经典的预训练模型结构,例如 AlexNet、VGG、ResNet、Inception 系列、DenseNet、SqueezeNet 以及一些用于目标检测和语义分割任务的模型,用户可以直接加载这些模型进行迁移学习或者作为基础网络结构进行扩展。

  4. 实用工具:包含了一系列实用方法,比如将张量保存为图像文件、创建图像网格以便可视化多个样本等。

总之,torchvision 为基于 PyTorch 构建计算机视觉项目提供了极大的便利性,涵盖了从数据获取到模型构建及实验结果可视化等各个环节所需的功能。

1. 数据集

torchvision 是 PyTorch 的一个官方库,主要用于计算机视觉任务,它为开发者提供了一系列常用的数据集、模型架构以及图像转换工具。在 torchvision.datasets 子模块中,它包含了多个内置数据集,这些数据集可以直接用于训练和评估图像分类、对象检测、语义分割等多种视觉模型。以下是几个 torchvision 库中包含的常见数据集:

  1. MNIST

    手写数字识别数据集,包含60,000个训练样本和10,000个测试样本,每个样本都是大小为28x28像素的单通道灰度图像,对应的标签是0-9的数字类别。
  2. CIFAR-10/100

    • CIFAR-10 包含了60,000张32x32像素的彩色图像,分为10个类别,每类各有6000个样本(50,000用于训练,10,000用于测试)。
    • CIFAR-100 与 CIFAR-10 类似,但具有100个类别,每个类别有600张图片,因此对于细粒度分类更具挑战性。
  3. Fashion-MNIST

    作为 MNIST 数据集的替代品,同样包含60,000训练样本和10,000测试样本,但是每个样本是一张28x28像素的时尚物品(如衬衫、裤子等)的灰度图像。
  4. ImageNet

    虽然 torchvision 自身不直接提供 ImageNet 数据集的下载功能,但它提供了接口来加载已经下载好的 ILSVRC 2012 分类数据集(即通常所说的 ImageNet),该数据集包含超过1000类的物体类别,每类有数千张不同大小的RGB彩色图像。
  5. STL10

    STL-10是一个小规模版本的ImageNet,有10个类别的100,000张未标记图像、5000张带标签的训练图像和8000张带标签的测试图像。
  6. COCO (Common Objects in Context)

    COCO 数据集用于目标检测、分割和图像字幕等任务,包含大量标注的日常场景图片,每张图片可以包含多个目标及其边界框和分割掩模。

使用时,可以通过以下方式加载这些数据集:

Python
1import torch
2import torchvision
3from torchvision import datasets
4
5# 加载CIFAR-10数据集并进行基本处理
6transform = torchvision.transforms.Compose([...])  # 定义数据预处理操作
7dataset_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
8dataset_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
9
10# 使用DataLoader进一步将数据集转化为适合训练的批次
11dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=..., shuffle=True)
12dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=...)

其中,root 参数指定了数据集存储的位置;train 参数确定是否加载训练集或测试集;download 参数设置为 True 则会自动从网上下载数据集;transform 参数允许对原始图像数据进行必要的预处理操作,例如归一化、裁剪、旋转等。

2. 数据预处理工具

torchvision 库中的数据预处理工具主要体现在 torchvision.transforms 模块,它提供了丰富的函数和类来对图像数据进行各种形式的转换和预处理。这些预处理操作在深度学习中是至关重要的,因为它们可以增强模型的泛化能力,并且将不同大小和格式的原始图像数据转化为神经网络能够接受的标准输入。

以下是一些常用的数据预处理方法:

  1. Resize

    transforms.Resize(size, interpolation):调整图像大小到指定尺寸。
  2. CenterCrop

    transforms.CenterCrop(size):从图像中心裁剪出一个给定大小的正方形区域。
  3. RandomCrop

    transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):随机裁剪图像的一块区域。
  4. Normalize

    transforms.Normalize(mean, std):对图像像素值进行标准化处理,通常用于归一化RGB通道的均值和标准差。
  5. ToTensor

    transforms.ToTensor():将 PIL Image 或 numpy.ndarray 转换为 PyTorch 张量(从 0-255 的整数范围转换到 0-1 的浮点数范围)。
  6. ConvertImageDtype

    transforms.ConvertImageDtype(dtype):将图像转换为指定的数据类型。
  7. RandomHorizontalFlip

    transforms.RandomHorizontalFlip(p=0.5):以一定概率水平翻转图像。
  8. RandomVerticalFlip

    transforms.RandomVerticalFlip(p=0.5):以一定概率垂直翻转图像。
  9. ColorJitter

    transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):随机改变图像的颜色属性,如亮度、对比度、饱和度和色调。
  10. Grayscale

    transforms.Grayscale(num_output_channels=1):将彩色图像转换为灰度图像。
  11. RandomRotation

    transforms.RandomRotation(degrees, resample=False, expand=False, center=None, fill=None):随机旋转图像。

为了方便使用,用户通常会组合多个预处理步骤,利用 transforms.Compose 类将其封装成一个预处理流水线:

Python
1from torchvision import transforms
2
3preprocess = transforms.Compose([
4    transforms.Resize(256),
5    transforms.CenterCrop(224),
6    transforms.ToTensor(),
7    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
8])
9
10# 使用预处理流水线
11transformed_image = preprocess(image)

以上是对 torchvision.transforms 中一些关键数据预处理功能的概述,实际上该模块包含了更多丰富的方法,可以满足各种计算机视觉任务的需求。

3. 深度学习模型架构

torchvision.models 模块提供了大量预训练的深度学习模型,这些模型主要针对图像分类、对象检测和语义分割等计算机视觉任务。以下是一些常见的模型架构:

  1. 图像分类模型

    • resnet18resnet34resnet50resnet101resnet152:基于残差网络(ResNet)架构,是目前最常用的深度神经网络之一,用于ImageNet数据集上的图像分类任务。
    • vgg16vgg19:基于VGG(Visual Geometry Group)架构,特征提取能力强,但计算复杂度相对较高。
    • densenet121densenet169densenet201densenet161:密集连接网络(DenseNet),通过密集块之间的稠密连接减少信息丢失并提升模型性能。
    • alexnetsqueezenet1_0squeezenet1_1:AlexNet和SqueezeNet是较早的深度学习模型,前者在ILSVRC 2012竞赛中取得了突破性成果,后者以其轻量级结构著称。
    • googlenetshufflenet_v2_x1_0mobilenet_v2 等:为移动设备或资源受限环境设计的小型化网络。
  2. 对象检测模型

    • torchvision不直接提供完整的预训练对象检测模型,但它包含了如 ssdfaster_rcnn 等检测模型的基本组件,用户可以利用 torchvision.ops 和 torchvision.models.detection 中的模块来构建自己的检测模型。
  3. 语义分割模型

    • fcn_resnet50deeplabv3_resnet50lraspp_mobilenet_v3_large 等:全卷积网络(Fully Convolutional Networks, FCN)、DeepLabV3 和 MobileNetV3 架构为基础的语义分割模型,可用于像素级别的图像分类任务。

所有这些模型都支持加载预训练权重,并且能够作为基础结构进行迁移学习或微调以适应新的任务。例如,加载预训练的ResNet50模型进行图像分类任务可以通过以下方式实现:

Python
1import torchvision.models as models
2
3# 加载预训练模型
4model = models.resnet50(pretrained=True)
5
6# 将模型的最后一层替换为与新任务类别数匹配的线性层
7num_classes = len(new_dataset.classes)
8model.fc = nn.Linear(model.fc.in_features, num_classes)
9
10# 设定优化器并开始训练
11optimizer = torch.optim.Adam(model.parameters())

请注意,具体的模型列表可能会随着 torchvision 版本的更新而有所变化,因此建议查阅最新的官方文档获取详细信息。

4. 实用工具

torchvision 库除了提供数据集和预训练模型之外,还包含一些实用工具函数和类,这些工具在处理计算机视觉任务时非常有用。以下是一些关键的实用工具:

  1. 图像保存与读取

    • torchvision.utils.save_image(tensor, filename, format=None):将一个张量(通常是经过处理后的图像)保存为指定格式(如PNG、JPEG等)的图像文件。
    • 通过 PIL 或其他图像库读取图像后,可以使用 transforms.ToTensor() 将其转换为 PyTorch 张量。
  2. 图像显示

    虽然 torchvision 自身不直接提供图像显示功能,但可以通过与外部库(如 matplotlib)结合来展示图像。例如,plt.imshow(torchvision.utils.make_grid(images)) 可以用来创建并显示一张由多个图像组成的网格。
  3. 图像拼接

    torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):将一组张量按照行列方式排列成一个大图像,常用于可视化多幅小图的结果。
  4. 视频处理

    • torchvision.io.read_video(filename, start_pts=0, end_pts=float('inf'), pts_unit='sec', decoder_backend=None):从视频文件中读取帧,并返回一个包含所有帧的 Tensor 列表。
    • torchvision.ops.video_reader.VideoReader(file_path, mode='video', backend=None):提供了一个视频读取器对象,可用于逐帧读取视频。
  5. 图像元数据获取

    对于某些加载的数据集(如COCO),torchvision 提供了方法来访问图像尺寸、标签以及其他元数据信息。
  6. 模型可视化工具

    虽然不是严格意义上的“实用工具”,但 torchvision.models.utils 模块提供了用于生成模型结构图的方法,如 plot_model(model, show_shapes=False, to_file=None) 可以生成模型结构图(需要安装额外依赖如graphviz)。

以上列举的功能可以帮助开发者在进行计算机视觉任务时更好地管理和可视化数据及模型。同时,随着 torchvision 的不断更新和发展,可能会有更多实用工具加入其中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/673415.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux中的numactl命令指南

假设我们想控制线程如何被分配到处理器核心,或者选择我们想分配数据的位置,那么numactl命令就适合此类任务。在这篇文章中,我们讨论了如何使用numactl命令执行此类操作。 目录: 介绍语法命令总结参考文献 简介 现代处理器采用…

QGIS编译(跨平台编译)之五十一:Shapelib编译(Windows、Linux、MacOS环境下编译)

文章目录 一、Shapelib介绍二、Shapelib下载三、Windows下编译四、Linux下编译五、MacOS下编译一、Shapelib介绍 Shapelib是一个开源的C/C++库,用于读取、写入和处理ESRI Shapefile格式的空间数据。Shapefile是一种常用的GIS数据格式,包含矢量数据,如点、线、面等。Shapeli…

rkmedia使用记录

1.函数 1) RK_MPI_VI_SetChnAttr _CAPI RK_S32 RK_MPI_VI_SetChnAttr(VI_PIPE ViPipe, VI_CHN ViChn,const VI_CHN_ATTR_S *pstChnAttr); /*VI通道属性结构体指针1)pcVideoNode:video节点路径2)u32BufCnt:VI捕获视频…

ROS学习笔记13:导航相关消息

前言 本人ROS小白,利用寒假时间学习ROS,在此以笔记的方式记录自己每天的学习过程。争取写满20篇(13/20)。 环境:Ubuntu20.04、ROS1:noetic 环境配置:严格按照下方学习链接的教程配置,基本一次成功。 学习链…

VBA技术资料MF116:测试操作系统是否为64位

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。我的教程一共九套,分为初级、中级、高级三大部分。是对VBA的系统讲解,从简单的入门,到…

洛谷P1039 [NOIP2003提高组]侦探推理

题目描述 明明同学最近迷上了侦探漫画《柯南》并沉醉于推理游戏之中,于是他召集了一群同学玩推理游戏。游戏的内容是这样的,明明的同学们先商量好由其中的一个人充当罪犯(在明明不知情的情况下),明明的任务就是找出这…

Android Studio 2022.3.1版本 引入包、maven等需要注意的问题

普通包 以前: // okhttp3 implementation com.squareup.okhttp3:okhttp:3.10.0 新版本: implementation("com.github.bumptech.glide:glide:3.7.0") libs文件夹中的包 以前: android {******sourceSets.main{jniLibs.srcDir…

使用SM4国密加密算法对Spring Boot项目数据库连接信息以及yaml文件配置属性进行加密配置(读取时自动解密)

一、前言 在业务系统开发过程中,我们必不可少的会使用数据库,在应用开发过程中,数据库连接信息往往都是以明文的方式配置到yaml配置文件中的,这样有密码泄露的风险,那么有没有什么方式可以避免呢?方案当然是有的,就是对数据库密码配置的时候进行加密,然后读取的时候再…

人工智能|推荐系统——基于tensorflow的个性化电影推荐系统实战(有前端)

代码下载: 基于tensorflow的个性化电影推荐系统实战(有前端).zip资源-CSDN文库 项目简介: dl_re_web : Web 项目的文件夹re_sys: Web app model:百度云下载之后,把model放到该文件夹下recommend: 网络模型相…

Python在小型无人机

Python在小型无人机的发展和研发中具有重要性。以下是几个原因: 简单易学:Python是一种简单易学的编程语言,具有简洁的语法和易于理解的语言结构。这使得开发人员可以更快速地理解和编写代码,从而加快了研发的进程。 多用途性&am…

Android 自定义BaseActivity

直接上代码: BaseActivity代码: package com.example.custom.activity;import android.annotation.SuppressLint; import android.app.Activity; import android.content.pm.ActivityInfo; import android.os.Bundle; import android.os.Looper; impor…

寒假作业-day5

1>现有无序序列数组为23,24,12,5,33,5347&#xff0c;请使用以下排序实现编程 函数1:请使用冒泡排序实现升序排序 函数2:请使用简单选择排序实现升序排序 函数3:请使用直接插入排序实现升序排序 函数4:请使用插入排序实现升序排序 代码&#xff1a; #include<stdio.h&g…

macbook电脑如何永久删除app软件?

在使用MacBook的过程中&#xff0c;我们经常会下载各种App来满足日常的工作和娱乐需求。然而&#xff0c;随着时间的积累&#xff0c;这些App不仅占据了宝贵的硬盘空间&#xff0c;还可能拖慢电脑的运行速度。那么&#xff0c;如何有效地管理和删除这些不再需要的App呢&#xf…

如何使用websocket

如何使用websocket 之前看到过一个面试题&#xff1a;吃饭点餐的小程序里&#xff0c;同一桌的用户点餐菜单如何做到的实时同步&#xff1f; 答案就是&#xff1a;使用websocket使数据变动时服务端实时推送消息给其他用户。 最近在我们自己的项目中我也遇到了类似问题&#xf…

使用CMSIS-DSP库进行嵌入式音频信号处理

在嵌入式环境下&#xff0c;使用CMSIS-DSP库进行音频信号处理是一种常见的应用场景。通过CMSIS-DSP库&#xff0c;开发人员可以利用嵌入式系统的处理能力来实现各种数字信号处理&#xff08;DSP&#xff09;功能&#xff0c;例如音频滤波、均衡器、噪音消除等。本文将介绍如何在…

问题 | IT行业有哪些证书含金量高?

IT行业有哪些证书含金量高? Cisco认证&#xff08;CCNA&#xff0c;CCNP&#xff0c;CCIE&#xff09;&#xff1a;思科是全球最大的网络设备供应商之一&#xff0c;它的认证证书在网络和通信领域被广泛认可。CCNA是初级认证&#xff0c;CCNP是高级认证&#xff0c;而CCIE是专…

NLP_Seq2Seq编码器-解码器架构

文章目录 Seq2Seq架构构建简单Seq2Seq架构1.构建实验语料库和词汇表2.生成Seq2Seq训练数据3. 定义编码器和解码器类4.定义Seq2Seq架构5. 训练Seq2Seq架构6.测试Seq2Seq架构 归纳Seq2Seq编码器-解码器架构小结 Seq2Seq架构 起初&#xff0c;人们尝试使用一个独立的RNN来解决这种…

CentOS7搭建Hadoop集群

准备工作 1、准备三台虚拟机&#xff0c;参考&#xff1a;CentOS7集群环境搭建&#xff08;3台&#xff09;-CSDN博客 2、配置虚拟机之间免密登录&#xff0c;参考&#xff1a;CentOS7集群配置免密登录-CSDN博客 3、虚拟机分别安装jdk&#xff0c;参考&#xff1a;CentOS7集…

【51单片机】实现一个动静态数码管显示项目(前置知识铺垫,代码&图演示)(5)

前言 大家好吖&#xff0c;欢迎来到 YY 滴单片机 系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过单片机的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; YY的《C》专栏YY的《C11》专栏YY…

vue electron应用调exe程序

描述 用Python写了一个本地服务编译成exe程序&#xff0c;在electron程序启动后&#xff0c;自动执行exe程序 实现 1. 使用node的child_process模块可以执行windows执行&#xff0c;通过指令调exe程序 // electron/index.js var cp require("child_process"); /…