【Pytorch】计算机视觉项目——卷积神经网络TinyVGG模型图像分类(模型预测)

介绍

这篇文章是《【Pytorch】计算机视觉项目——卷积神经网络TinyVGG模型图像分类(如何使用自定义数据集)》的最后一部分内容:模型预测。

在本文中,我们将介绍如何测试模型的预测效果——让已训练好模型对一张新的图片进行分类;最后将整个流程打包,写成一个可以被直接调用的函数。

整个预测流程包括:

  • 图片下载
  • 图像转张量、图像数据变换
  • 使用训练好的模型进行预测
  • 预测结果输出

通过这些步骤,读者将能够进一步了解如何对已经训练好模型进行测试,以及了解模型是如何完成对图像的分类工作。


其他相关文章:

  • 深度学习入门笔记:总结了一些神经网络的基础概念。
  • TensorFlow专栏:《计算机视觉入门系列》介绍如何用TensorFlow框架实现卷积分类器。
  • 【Pytorch】整体工作流程代码详解(新手入门)
    在这里插入图片描述

图像处理和预测分步骤详解

1. 图片下载&路径设置

import requests# 设置文件路径
custom_image_path = data_path / "04-pizza-dad.jpeg"# 文件下载
if not custom_image_path.is_file():with open(custom_image_path, "wb") as f:request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/04-pizza-dad.jpeg")print(f"Downloading {custom_image_path}...")f.write(request.content)
else:print(f"{custom_image_path} already exists, skipping download.")

2. 将图像转换成张量

import torchvision# 将图像转换成张量(未指定格式)
custom_image_uint8 = torchvision.io.read_image(str(custom_image_path))# 打印结果
print(f"Custom image tensor:\n{custom_image_uint8}\n")
print(f"Custom image shape: {custom_image_uint8.shape}\n")
print(f"Custom image dtype: {custom_image_uint8.dtype}")

![[04.1 8. 模型预测-20240605181741722.webp]]

图像数据的格式是torch.uint8, 表示范围在(0,255),通常用于表示图像的像素值。

而在深度学习模型中,通常使用 torch.float32 格式的输入,因为模型训练和推理时需要更高的数值精度和更广泛的表示范围。

因此,需要把格式转成精度更高的float32,这是模型所需要的格式。

# 载入图像,并将张量值转换为float32
custom_image = torchvision.io.read_image(str(custom_image_path)).type(torch.float32)# 将torch.uint8张量转换为torch.float32,并归一化到[0, 1]
custom_image = custom_image / 255. # 检查转换后的张量的数据
print(f"Custom image tensor:\n{custom_image}\n")
print(f"Custom image shape: {custom_image.shape}\n")
print(f"Custom image dtype: {custom_image.dtype}")

![[04.1 8. 模型预测-20240606124914487.webp]]在这里插入图片描述

# 图片展示
plt.imshow(custom_image.permute(1, 2, 0))
plt.title(f"Image shape: {custom_image.shape}")
plt.axis(False);

在这里插入图片描述
![[04.1 8. 模型预测-20240605181832312.webp]]
数据形状现在是[3,4032,4032], 我们还需要对它进行进一步的处理,使其能够匹配模型训练时使用的数据形状。

3. 图像变换

# 设置图像变换过程
custom_image_transform = transforms.Compose([transforms.Resize((64, 64)),
])# 图片转换
custom_image_transformed = custom_image_transform(custom_image)# 打印图片形状
print(f"Original shape: {custom_image.shape}")
print(f"New shape: {custom_image_transformed.shape}")

![[04.1 8. 模型预测-20240605181913674.webp]]

经过Transform过程,图片形状变成[3,64,64]。

原始形状为torch.Size([3, 4032, 3024]),这表示图像的高度为4032像素,宽度为3024像素,并且有3个通道(通常表示RGB通道)。

新的形状为torch.Size([3, 64, 64]),这表示经过调整后,图像的高度和宽度都变成了64像素,依然保持3个通道。

transforms.ToTensor():

  • 输入格式:对于彩色图像(RGB),输入通常是形状为 (H, W, 3) 的 numpy 数组或 PIL 图像,其中 H 是高度,W 是宽度,3 表示颜色通道(红、绿、蓝)。
  • 输出格式:形状为 (C, H, W)的PyTorch 张量,其中 C 是颜色通道数(通常为 3),H 是高度,W 是宽度。
  • 示例:假设有一张 RGB 图像,原始大小为 256x256,转换后为形状为 (3, 256, 256) 的张量,其中 3 表示 RGB 通道。如果是灰度图像,转换为 (1, H, W),因为灰度图像只有一个通道。

4. 模型预测

model_0.eval()with torch.inference_mode():# 给图像增加一个维度:batch sizecustom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0)  # 打印结果print(f"Custom image transformed shape: {custom_image_transformed.shape}")print(f"Unsqueezed custom image shape: {custom_image_transformed_with_batch_size.shape}")# 使用模型对图像进行分类预测custom_image_pred = model_0(custom_image_transformed.unsqueeze(dim=0).to(device))

![[04.1 8. 模型预测-20240605182651951.webp]]

  • custom_image_transformed.unsqueeze(dim=0) 因为在模型训练过程中,图像张量数据是按照批次导入模型训练的,模型适应的维度/形状是(N, C, H, W), 这里的N是批次的意思。因此, torch.unsqueeze(dim=0)给图像价

5. 预测结果输出

# 打印原始预测值logits
print(f"Prediction logits: {custom_image_pred}")# 将logits转换为预测概率-->模型预测的概率
custom_image_pred_probs = torch.softmax(custom_image_pred, dim=1)
print(f"Prediction probabilities: {custom_image_pred_probs}")# 将预测概率转换为预测标签
custom_image_pred_label = torch.argmax(custom_image_pred_probs, dim=1)
print(f"Prediction label: {custom_image_pred_label}")

![[04.1 8. 模型预测-20240605182220141.webp]]

  • torch.softmax(custom_image_pred, dim=1): 使用Softmax函数将logits转换为概率。Softmax函数将logits转换为0到1之间的概率值,并且所有概率值的总和为1。dim=1表示在类别维度上进行计算。
  • torch.argmax(custom_image_pred_probs, dim=1): 在概率最大的类别索引上取最大值,这个索引对应于模型预测的类别标签。
# 找出预测标签
custom_image_pred_class = class_names[custom_image_pred_label.cpu()] # put pred label to CPU, otherwise will errorcustom_image_pred_class

![[04.1 8. 模型预测-20240606123453519.webp]]

  • .cpu() 这里代码是在GPU上运行的,所以需要把预测标签移回CPU上。

创建预测函数(打包整个预测过程)

我们复习一下上面的步骤:

  1. 设置目标图像路径,并将其转换为适合我们模型的数据类型(torch.float32)。
  2. 确保目标图像的像素值在范围 [0, 1] 之内。
  3. 如有必要,对目标图像进行变换。
  4. 确保模型在指定的设备上。
  5. 使用训练好的模型对目标图像进行预测(确保图像尺寸正确,并与模型在同一设备上)。
  6. 将模型的输出logits转换为预测概率。
  7. 将预测概率转换为预测标签。
  8. 绘制目标图像,并显示模型的预测结果和预测概率。

接下来我们需要把这些步骤都打包到一个函数中,这样就能通过函数实现模型的预测功能。

def pred_and_plot_image(model: torch.nn.Module,image_path: str,class_names: List[str] = None,transform=None,device: torch.device = device):# 1. 载入图像,并将张量值转换为float32target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)# 2. 将图像像素值除以255,使其在[0, 1]之间target_image = target_image / 255.# 3. 如有必要,进行图像变换if transform:target_image = transform(target_image)# 4. 确保模型在指定设备上model.to(device)# 5. 启用模型评估模式和推理模式model.eval()with torch.inference_mode():# 为图像添加一个维度target_image = target_image.unsqueeze(dim=0)# 对图像进行预测,并将其发送到指定设备target_image_pred = model(target_image.to(device))# 6. 将logits转换为预测概率(使用softmax进行多分类)target_image_pred_probs = torch.softmax(target_image_pred, dim=1)# 7. 将预测概率转换为预测标签target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)# 8. 绘制图像,并显示预测结果和预测概率plt.imshow(target_image.squeeze().permute(1, 2, 0))  # 调整图像以适应matplotlibif class_names:title = f"预测: {class_names[target_image_pred_label.cpu()]} | 概率: {target_image_pred_probs.max().cpu():.3f}"else:title = f"预测: {target_image_pred_label} | 概率: {target_image_pred_probs.max().cpu():.3f}"plt.title(title)plt.axis(False);
pred_and_plot_image(model=model_0,image_path=custom_image_path,class_names=class_names,transform=custom_image_transform,device=device)

![[04.1 8. 模型预测-20240606123753105.webp]]

最后结果展示了分类的标签,概率,以及经过处理后的图片。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/23801.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker面试整理-什么是Docker Hub?

Docker Hub 是一个由 Docker, Inc. 维护的公共镜像注册服务,它允许用户分享、存储和管理 Docker 镜像。Docker Hub 提供了一个中心化的资源库,用户可以从中拉取(下载)和推送(上传)镜像,这使得它成为分享和分发容器应用的重要平台。 Docker Hub 的主要功能包括: 镜像存储…

在 SEO 中,一个好的网页必须具备哪些 HTML 标签和属性?

搜索引擎优化 (SEO) 是涉及提高网站在搜索引擎上的可见性的过程。这是通过提高网站在搜索引擎结果页面(例如Google)上的排名来实现的。网站在这些页面上的显示位置越高,就越有可能获得更大的流量。 搜索引擎优化涉及了…

跑mask2former(自用)

1. 运行docker 基本命令: sudo docker ps -a (列出所有容器状态) sudo docker run -dit -v /hdd/lyh/mask2former:/mask --gpus "device0,1" --shm-size 16G --name mask 11.1:v6 (创建docker容器&…

Mac系统使用COLMAP

安装教程 如有出入,参照官网手册最新版 Installation — COLMAP 3.9-dev documentation 首先确保mac上安装了Homebrew 1.安装依赖项 brew install \cmake \ninja \boost \eigen \flann \freeimage \metis \glog \googletest \ceres-solver \qt5 \glew \cgal \s…

Python中Web表单和用户输入的处理

在现代Web应用程序中,处理用户输入和表单提交是必不可少的部分。在Python中,使用Flask框架可以非常方便地处理这些操作。本文将详细介绍如何在Flask中处理Web表单和用户输入,包括基本的表单创建、验证、提交和处理等方面。通过这些内容&#…

万里长城第一步——尚庭公寓【技术概述】

简略版: 项目概述主要是移动端(房源检索;预约看房,租赁管理,浏览历史)和后台管理(管理员对房源进行操作); 项目使用前后端分离的方法,主要以后端为主&#xf…

#05 深入Stable Diffusion的参数调整和优化技巧

文章目录 前言1. 理解关键参数2. 参数调整策略2.1 学习率调整2.2 批量大小优化2.3 迭代次数设置2.4 潜在空间维度选择 3. 优化技巧3.1 使用预训练模型3.2 数据增强3.3 模型微调 4. 实践建议结论 前言 Stable Diffusion作为一款强大的AI图像生成工具,其性能的优劣很…

centos如何压缩zip

在CentOS中,您可以使用zip命令来压缩文件或文件夹为ZIP格式。如果zip命令尚未安装,您可以通过执行以下命令来安装它: sudo yum install zip unzip压缩单个文件的基本命令格式为: zip [压缩后的文件名].zip [文件名]压缩一个文件…

rpm安装

rpm安装 命令格式: rpm 【选项】 文件名 选项: -i:安装软件 -v:显示安装过程信息 -h:用#表示安装进度,一个#代表2% -ivh:安装软件,显示安装过程 -e:卸载软件 -q:查看软件是否安装 -ql&#xff1…

什么是函数?在C语言中如何定义一个函数

函数是编程中用于执行特定任务的一组指令的集合。它有一个名称(即函数名),可以通过该名称在程序中多次调用该函数以执行相同的任务。这有助于提高代码的可重用性和可维护性。 在C语言中,函数的定义通常包括以下几个部分&#xff…

信息系统项目管理师0147:工具与技术(9项目范围管理—9.3规划范围管理—9.3.2工具与技术)

点击查看专栏目录 文章目录 9.3.2 工具与技术 9.3.2 工具与技术 专家判断 规划范围管理过程中,应征求具备如下领域相关专业知识或接受过相关培训的个人或小组 的意见,涉及的领域包括:以往类似项目;特定行业、学科和应用领域的信息…

UIScrollView的相关笔记

1. 当UIScrollview横向滚动时,如果在上面添加5个按钮,但当前scrollview 一页只能显示3个按钮,此时有一项要求,需要在点击第3个按钮时,scrollview自动向左滑动,显示后面的按钮等,需要在按钮点击方…

SpringMVC:Quartz常见问题

一、配置job的xml里<start-time>的时间格式 从源码JobSchedulingDataProcessor类中可以看出&#xff1a;格式例如&#xff08;2012-03-31T05:55:00&#xff09; /*** XML Schema dateTime datatype format.* <p>* See <a href"http://www.w3.org/TR/2001/…

学习anjuke的过程

一、抓包 先看看12.25.1版本的APP是不是还能使用&#xff0c;如果还能使用我们就先破解低版本的。打开APP后发现还能正常使用&#xff0c;因为低版本的难度低我们就破解这个版本。低版本和高版本的算法是一样的&#xff0c;算法破解之后我们后续抓包替换接口就行了。手机安装上…

SQLAlchemy 模型中数据的错误表示

1. 问题背景 在使用 SQLAlchemy 0.6.0 版本&#xff08;也曾尝试使用 0.6.4 版本&#xff09;的 Pylons 应用程序中遇到了一个 SQLAlchemy ORM 问题。该问题出现在使用 psycopg2 作为数据库驱动程序、连接至 Postgresql 8.2 数据库的环境中。定义了一个 User 模型对象&#xf…

FreeRTOS基础(十一):消息队列

本文将详细全方位的讲解FreeRTOS的消息队列&#xff0c;其实在FreeRTOS中消息队列的重要性也不言而喻&#xff0c;与FreeRTOS任务调度同等重要&#xff0c;因为后面的各种信号量基本都是基于消息队列的。 目录 一、消息队列的简介 1.1 产生的原因 1.2 消息队列的解决办法 …

【数据库】SQL零基础入门学习

人不走空 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌赋&#xff1a;斯是陋室&#xff0c;惟吾德馨 目录 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌…

重邮计算机网络803-(2)物理层

一.物理层 1.介绍 物理层的主要任务描述为确定与传输媒体的接口的一些特性&#xff0c;即&#xff1a; ①机械特性 指明接口所用接线器的形状和尺寸、引线数目和排列、固定和锁定装置等等。 ②电气特性 指明在接口电缆的各条线上出现的电压的范围。 ③功能特性 指明某条线上…

B=2W,奈奎斯特极限定理详解

一直没搞明白奈奎斯特极限定理的含义&#xff0c;网上搜了很久也没得到答案。最近深思几天后&#xff0c;终于有了点心得。顺便吐槽一下&#xff0c;csdn的提问栏目&#xff0c;有很多人用chatgpt秒回这个事&#xff0c;实在是解决不了问题&#xff0c;有时候人的问题大多数都是…

HDFS 之 DataNode 核心知识点

优质博文&#xff1a;IT-BLOG-CN 一、DataNode工作机制 DataNode工作机制&#xff0c;如下所示&#xff1a; 【1】一个数据块在 DataNode上以文件形式存储在磁盘上&#xff0c;包括两个文件&#xff0c;一个是数据本身&#xff0c;一个是元数据包括数据块的长度&#xff0c…