【Pytorch】计算机视觉项目——卷积神经网络TinyVGG模型图像分类(模型预测)

介绍

这篇文章是《【Pytorch】计算机视觉项目——卷积神经网络TinyVGG模型图像分类(如何使用自定义数据集)》的最后一部分内容:模型预测。

在本文中,我们将介绍如何测试模型的预测效果——让已训练好模型对一张新的图片进行分类;最后将整个流程打包,写成一个可以被直接调用的函数。

整个预测流程包括:

  • 图片下载
  • 图像转张量、图像数据变换
  • 使用训练好的模型进行预测
  • 预测结果输出

通过这些步骤,读者将能够进一步了解如何对已经训练好模型进行测试,以及了解模型是如何完成对图像的分类工作。


其他相关文章:

  • 深度学习入门笔记:总结了一些神经网络的基础概念。
  • TensorFlow专栏:《计算机视觉入门系列》介绍如何用TensorFlow框架实现卷积分类器。
  • 【Pytorch】整体工作流程代码详解(新手入门)
    在这里插入图片描述

图像处理和预测分步骤详解

1. 图片下载&路径设置

import requests# 设置文件路径
custom_image_path = data_path / "04-pizza-dad.jpeg"# 文件下载
if not custom_image_path.is_file():with open(custom_image_path, "wb") as f:request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/04-pizza-dad.jpeg")print(f"Downloading {custom_image_path}...")f.write(request.content)
else:print(f"{custom_image_path} already exists, skipping download.")

2. 将图像转换成张量

import torchvision# 将图像转换成张量(未指定格式)
custom_image_uint8 = torchvision.io.read_image(str(custom_image_path))# 打印结果
print(f"Custom image tensor:\n{custom_image_uint8}\n")
print(f"Custom image shape: {custom_image_uint8.shape}\n")
print(f"Custom image dtype: {custom_image_uint8.dtype}")

![[04.1 8. 模型预测-20240605181741722.webp]]

图像数据的格式是torch.uint8, 表示范围在(0,255),通常用于表示图像的像素值。

而在深度学习模型中,通常使用 torch.float32 格式的输入,因为模型训练和推理时需要更高的数值精度和更广泛的表示范围。

因此,需要把格式转成精度更高的float32,这是模型所需要的格式。

# 载入图像,并将张量值转换为float32
custom_image = torchvision.io.read_image(str(custom_image_path)).type(torch.float32)# 将torch.uint8张量转换为torch.float32,并归一化到[0, 1]
custom_image = custom_image / 255. # 检查转换后的张量的数据
print(f"Custom image tensor:\n{custom_image}\n")
print(f"Custom image shape: {custom_image.shape}\n")
print(f"Custom image dtype: {custom_image.dtype}")

![[04.1 8. 模型预测-20240606124914487.webp]]在这里插入图片描述

# 图片展示
plt.imshow(custom_image.permute(1, 2, 0))
plt.title(f"Image shape: {custom_image.shape}")
plt.axis(False);

在这里插入图片描述
![[04.1 8. 模型预测-20240605181832312.webp]]
数据形状现在是[3,4032,4032], 我们还需要对它进行进一步的处理,使其能够匹配模型训练时使用的数据形状。

3. 图像变换

# 设置图像变换过程
custom_image_transform = transforms.Compose([transforms.Resize((64, 64)),
])# 图片转换
custom_image_transformed = custom_image_transform(custom_image)# 打印图片形状
print(f"Original shape: {custom_image.shape}")
print(f"New shape: {custom_image_transformed.shape}")

![[04.1 8. 模型预测-20240605181913674.webp]]

经过Transform过程,图片形状变成[3,64,64]。

原始形状为torch.Size([3, 4032, 3024]),这表示图像的高度为4032像素,宽度为3024像素,并且有3个通道(通常表示RGB通道)。

新的形状为torch.Size([3, 64, 64]),这表示经过调整后,图像的高度和宽度都变成了64像素,依然保持3个通道。

transforms.ToTensor():

  • 输入格式:对于彩色图像(RGB),输入通常是形状为 (H, W, 3) 的 numpy 数组或 PIL 图像,其中 H 是高度,W 是宽度,3 表示颜色通道(红、绿、蓝)。
  • 输出格式:形状为 (C, H, W)的PyTorch 张量,其中 C 是颜色通道数(通常为 3),H 是高度,W 是宽度。
  • 示例:假设有一张 RGB 图像,原始大小为 256x256,转换后为形状为 (3, 256, 256) 的张量,其中 3 表示 RGB 通道。如果是灰度图像,转换为 (1, H, W),因为灰度图像只有一个通道。

4. 模型预测

model_0.eval()with torch.inference_mode():# 给图像增加一个维度:batch sizecustom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0)  # 打印结果print(f"Custom image transformed shape: {custom_image_transformed.shape}")print(f"Unsqueezed custom image shape: {custom_image_transformed_with_batch_size.shape}")# 使用模型对图像进行分类预测custom_image_pred = model_0(custom_image_transformed.unsqueeze(dim=0).to(device))

![[04.1 8. 模型预测-20240605182651951.webp]]

  • custom_image_transformed.unsqueeze(dim=0) 因为在模型训练过程中,图像张量数据是按照批次导入模型训练的,模型适应的维度/形状是(N, C, H, W), 这里的N是批次的意思。因此, torch.unsqueeze(dim=0)给图像价

5. 预测结果输出

# 打印原始预测值logits
print(f"Prediction logits: {custom_image_pred}")# 将logits转换为预测概率-->模型预测的概率
custom_image_pred_probs = torch.softmax(custom_image_pred, dim=1)
print(f"Prediction probabilities: {custom_image_pred_probs}")# 将预测概率转换为预测标签
custom_image_pred_label = torch.argmax(custom_image_pred_probs, dim=1)
print(f"Prediction label: {custom_image_pred_label}")

![[04.1 8. 模型预测-20240605182220141.webp]]

  • torch.softmax(custom_image_pred, dim=1): 使用Softmax函数将logits转换为概率。Softmax函数将logits转换为0到1之间的概率值,并且所有概率值的总和为1。dim=1表示在类别维度上进行计算。
  • torch.argmax(custom_image_pred_probs, dim=1): 在概率最大的类别索引上取最大值,这个索引对应于模型预测的类别标签。
# 找出预测标签
custom_image_pred_class = class_names[custom_image_pred_label.cpu()] # put pred label to CPU, otherwise will errorcustom_image_pred_class

![[04.1 8. 模型预测-20240606123453519.webp]]

  • .cpu() 这里代码是在GPU上运行的,所以需要把预测标签移回CPU上。

创建预测函数(打包整个预测过程)

我们复习一下上面的步骤:

  1. 设置目标图像路径,并将其转换为适合我们模型的数据类型(torch.float32)。
  2. 确保目标图像的像素值在范围 [0, 1] 之内。
  3. 如有必要,对目标图像进行变换。
  4. 确保模型在指定的设备上。
  5. 使用训练好的模型对目标图像进行预测(确保图像尺寸正确,并与模型在同一设备上)。
  6. 将模型的输出logits转换为预测概率。
  7. 将预测概率转换为预测标签。
  8. 绘制目标图像,并显示模型的预测结果和预测概率。

接下来我们需要把这些步骤都打包到一个函数中,这样就能通过函数实现模型的预测功能。

def pred_and_plot_image(model: torch.nn.Module,image_path: str,class_names: List[str] = None,transform=None,device: torch.device = device):# 1. 载入图像,并将张量值转换为float32target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)# 2. 将图像像素值除以255,使其在[0, 1]之间target_image = target_image / 255.# 3. 如有必要,进行图像变换if transform:target_image = transform(target_image)# 4. 确保模型在指定设备上model.to(device)# 5. 启用模型评估模式和推理模式model.eval()with torch.inference_mode():# 为图像添加一个维度target_image = target_image.unsqueeze(dim=0)# 对图像进行预测,并将其发送到指定设备target_image_pred = model(target_image.to(device))# 6. 将logits转换为预测概率(使用softmax进行多分类)target_image_pred_probs = torch.softmax(target_image_pred, dim=1)# 7. 将预测概率转换为预测标签target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)# 8. 绘制图像,并显示预测结果和预测概率plt.imshow(target_image.squeeze().permute(1, 2, 0))  # 调整图像以适应matplotlibif class_names:title = f"预测: {class_names[target_image_pred_label.cpu()]} | 概率: {target_image_pred_probs.max().cpu():.3f}"else:title = f"预测: {target_image_pred_label} | 概率: {target_image_pred_probs.max().cpu():.3f}"plt.title(title)plt.axis(False);
pred_and_plot_image(model=model_0,image_path=custom_image_path,class_names=class_names,transform=custom_image_transform,device=device)

![[04.1 8. 模型预测-20240606123753105.webp]]

最后结果展示了分类的标签,概率,以及经过处理后的图片。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/23801.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在 SEO 中,一个好的网页必须具备哪些 HTML 标签和属性?

搜索引擎优化 (SEO) 是涉及提高网站在搜索引擎上的可见性的过程。这是通过提高网站在搜索引擎结果页面(例如Google)上的排名来实现的。网站在这些页面上的显示位置越高,就越有可能获得更大的流量。 搜索引擎优化涉及了…

跑mask2former(自用)

1. 运行docker 基本命令: sudo docker ps -a (列出所有容器状态) sudo docker run -dit -v /hdd/lyh/mask2former:/mask --gpus "device0,1" --shm-size 16G --name mask 11.1:v6 (创建docker容器&…

Mac系统使用COLMAP

安装教程 如有出入,参照官网手册最新版 Installation — COLMAP 3.9-dev documentation 首先确保mac上安装了Homebrew 1.安装依赖项 brew install \cmake \ninja \boost \eigen \flann \freeimage \metis \glog \googletest \ceres-solver \qt5 \glew \cgal \s…

万里长城第一步——尚庭公寓【技术概述】

简略版: 项目概述主要是移动端(房源检索;预约看房,租赁管理,浏览历史)和后台管理(管理员对房源进行操作); 项目使用前后端分离的方法,主要以后端为主&#xf…

rpm安装

rpm安装 命令格式: rpm 【选项】 文件名 选项: -i:安装软件 -v:显示安装过程信息 -h:用#表示安装进度,一个#代表2% -ivh:安装软件,显示安装过程 -e:卸载软件 -q:查看软件是否安装 -ql&#xff1…

信息系统项目管理师0147:工具与技术(9项目范围管理—9.3规划范围管理—9.3.2工具与技术)

点击查看专栏目录 文章目录 9.3.2 工具与技术 9.3.2 工具与技术 专家判断 规划范围管理过程中,应征求具备如下领域相关专业知识或接受过相关培训的个人或小组 的意见,涉及的领域包括:以往类似项目;特定行业、学科和应用领域的信息…

学习anjuke的过程

一、抓包 先看看12.25.1版本的APP是不是还能使用,如果还能使用我们就先破解低版本的。打开APP后发现还能正常使用,因为低版本的难度低我们就破解这个版本。低版本和高版本的算法是一样的,算法破解之后我们后续抓包替换接口就行了。手机安装上…

SQLAlchemy 模型中数据的错误表示

1. 问题背景 在使用 SQLAlchemy 0.6.0 版本(也曾尝试使用 0.6.4 版本)的 Pylons 应用程序中遇到了一个 SQLAlchemy ORM 问题。该问题出现在使用 psycopg2 作为数据库驱动程序、连接至 Postgresql 8.2 数据库的环境中。定义了一个 User 模型对象&#xf…

FreeRTOS基础(十一):消息队列

本文将详细全方位的讲解FreeRTOS的消息队列,其实在FreeRTOS中消息队列的重要性也不言而喻,与FreeRTOS任务调度同等重要,因为后面的各种信号量基本都是基于消息队列的。 目录 一、消息队列的简介 1.1 产生的原因 1.2 消息队列的解决办法 …

【数据库】SQL零基础入门学习

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

重邮计算机网络803-(2)物理层

一.物理层 1.介绍 物理层的主要任务描述为确定与传输媒体的接口的一些特性,即: ①机械特性 指明接口所用接线器的形状和尺寸、引线数目和排列、固定和锁定装置等等。 ②电气特性 指明在接口电缆的各条线上出现的电压的范围。 ③功能特性 指明某条线上…

B=2W,奈奎斯特极限定理详解

一直没搞明白奈奎斯特极限定理的含义,网上搜了很久也没得到答案。最近深思几天后,终于有了点心得。顺便吐槽一下,csdn的提问栏目,有很多人用chatgpt秒回这个事,实在是解决不了问题,有时候人的问题大多数都是…

HDFS 之 DataNode 核心知识点

优质博文:IT-BLOG-CN 一、DataNode工作机制 DataNode工作机制,如下所示: 【1】一个数据块在 DataNode上以文件形式存储在磁盘上,包括两个文件,一个是数据本身,一个是元数据包括数据块的长度&#xff0c…

前端 JS 经典:图片裁剪上传原理

前言:图片裁剪一般都是用户选择头像时用到,现在很多插件都可以满足这个功能,但是我们不仅要会用插件,还要自己懂的裁剪原理。 1. 流程 流程分为:1. 预览本地图片 2. 选择裁剪区域 3. 上传裁剪图像 2. 如何预览图片 …

小熊家务帮day10-day12 门户管理(缓存,主页,定时任务)

门户管理 1 门户介绍1.1 介绍1.2 常用技术方案 2 缓存技术方案2.1 需求分析2.1.1 C端用户界面原型2.1.2 缓存需求2.1.3 使用的工具 2.2 项目基础使用2.2.1 项目集成SpringCache2.2.2 测试Cacheable需求Service测试 2.1.3 缓存管理器(设置过期时间)2.1.4 …

计算机毕业设计PySpark+Hadoop地震预测系统 地震数据分析可视化 地震爬虫 大数据毕业设计 Flink Hadoop 深度学习

基于Hadoop的地震预测的 分析与可视化研究 姓 名:____田伟情_________ 系 别:____信息技术学院___ 专 业:数据科学与大数据技术 学 号:__2011103094________ 指导教师:_____王双喜________ 年 月 日 …

sqli-labs 靶场 less-5、6 第五关和第六关:判断注入点、使用错误函数注入爆库名、updatexml()函数

SQLi-Labs是一个用于学习和练习SQL注入漏洞的开源应用程序。通过它,我们可以学习如何识别和利用不同类型的SQL注入漏洞,并了解如何修复和防范这些漏洞。Less 5 SQLI DUMB SERIES-5 判断注入点:1. 首先,尝试正常的回显内容&#x…

Hadoop3:MapReduce源码解读之Map阶段的TextInputFormat切片机制(3)

Job那块的断点代码截图省略,直接进入切片逻辑 参考:Hadoop3:MapReduce源码解读之Map阶段的Job任务提交流程(1) 5、TextInputFormat源码解析 类的继承关系 它的内容比较少 重写了两个父类的方法 这里关心一下泛型参数…

【Python报错】已解决Attributeerror: ‘list‘ object has no attribute ‘join‘( Solved)

解决Python报错:AttributeError: ‘list’ object has no attribute ‘join’ (Solved) 在Python中,字符串(str)对象有一个非常有用的join()方法,它允许你将序列中的元素连接(join)成一个字符串…

机器学习笔记 - 本地windows 11 + PyCharm运行stable diffusion流程简述

一、环境说明 硬件:本地电脑windows11、32.0 GB内存、2060的6G的卡。 软件:本地有一个python环境,主要是torch 2.2.2+cu118 二、准备工作 1、下载模型 https://huggingface.co/CompVishttps://huggingface.co/CompVis 进入上面的网址,我这里下载的是这个里面的 …