Onnx使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上

目录

一、整体功能概述

二、函数分析

2.1 resnet() 函数:

2.2 pre_process(img_path) 函数:

2.3 loadOnnx(img_path) 函数:

三、代码执行流程


一、整体功能概述


这段代码实现了一个图像分类系统,使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上。它包括以下主要步骤:
读取一个包含类别名称和对应编号的文本文件,并将其存储在字典中。
定义了几个函数,包括模型导出函数 resnet()、图像预处理函数 pre_process() 和加载 ONNX 模型进行分类的函数 loadOnnx()。
在主程序中,指定输入图像路径,调用 loadOnnx() 函数对图像进行分类并显示结果。


二、函数分析


2.1 resnet() 函数:


使用 torchvision 中的预训练 ResNet18 模型,并设置为评估模式。
生成一个随机输入张量 x,并将模型导出为 ONNX 格式,保存为 models/resnet18.onnx 文件。

def resnet():model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.eval()x=torch.randn(1,3,224,224)torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])


2.2 pre_process(img_path) 函数:


读取输入图像 img_path。
调整图像大小为 224x224。
将图像颜色通道从 BGR 转换为 RGB。
对图像像素值进行归一化处理。
交换图像维度顺序,并增加一个维度。
返回预处理后的图像张量。

def pre_process(img_path):#h w c--->224,224,3#归一化#换轴#增加维度img=cv2.imread(img_path)scale_image=cv2.resize(img,dsize=(224,224))rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)rgb_img=rgb_img/255rgb_img=np.transpose(rgb_img,(2,0,1))rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)return rgb_img


2.3 loadOnnx(img_path) 函数:


创建一个 ONNX 推理会话,加载预导出的 ResNet18 ONNX 模型。

调用 pre_process() 函数对输入图像进行预处理。
准备输入数据并进行推理。
获取推理结果中概率最大的类别编号。
根据类别编号从字典中获取对应的类别名称,并进行翻译。
在输入图像上显示分类结果,并展示图像。

def loadOnnx(img_path):session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])img=pre_process(img_path)img_back=cv2.imread(img_path)intput_feed={'input':img}session_out=session.run(None,intput_feed)[0]out=np.argmax(session_out,axis=1)[0]res=str(out)# print(dict[res])ans=dict[res].split(',')[1].split(']')[0].strip()ans = translator.translate(ans)cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)cv2.imshow('win',img_back)cv2.waitKey(0)cv2.destroyAllWindows()print(ans)

完整代码如下

import cv2
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights
import onnxruntime as ort
from translate import Translator
translator=Translator(to_lang='Chinese')#翻译成中文
dict={}
with open('类别.txt','r',encoding='utf-8') as f:lines=f.readlines()for line in lines:name=line.split('\t')[0]value=line.split('\t')[1]dict[name]=value
# print(dict)
def resnet():model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.eval()x=torch.randn(1,3,224,224)torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])
def pre_process(img_path):#h w c--->224,224,3#归一化#换轴#增加维度img=cv2.imread(img_path)scale_image=cv2.resize(img,dsize=(224,224))rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)rgb_img=rgb_img/255rgb_img=np.transpose(rgb_img,(2,0,1))rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)return rgb_img#RGB
def loadOnnx(img_path):session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])img=pre_process(img_path)img_back=cv2.imread(img_path)intput_feed={'input':img}session_out=session.run(None,intput_feed)[0]out=np.argmax(session_out,axis=1)[0]res=str(out)# print(dict[res])ans=dict[res].split(',')[1].split(']')[0].strip()ans = translator.translate(ans)cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)cv2.imshow('win',img_back)cv2.waitKey(0)cv2.destroyAllWindows()print(ans)pass
if __name__ == '__main__':img_path='dog.png'# resnet()#导出模型loadOnnx(img_path)


三、代码执行流程


在 if __name__ == '__main__': 部分:
定义输入图像路径 img_path。
可以选择调用 resnet() 函数导出模型(注释状态,通常只在第一次运行或模型更新时使用)。
调用 loadOnnx(img_path) 函数对输入图像进行分类和显示结果。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/877802.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机毕业设计hadoop+spark+hive漫画推荐系统 动漫视频推荐系统 漫画分析可视化大屏 漫画爬虫 漫画推荐系统 漫画爬虫 知识图谱 大数据

流程:1.DrissionPageSelenium自动爬虫工具采集漫画视频、详情、标签等约200万条漫画数据存入mysql数据库; 2.Mapreduce对采集的动漫数据进行数据清洗、拆分数据项等,转为.csv文件上传hadoop的hdfs集群; 3.hive建库建表导入.csv动漫…

驱动开发系列11 - Linux Graphics 图形栈概述(二)

目录 一:GPU 和 硬件 现代 GPU 功能概览: 硬件结构: 屏幕驱动: 屏幕连接器: 屏幕 CRT 控制器: CPU与GPU通信: 现代主机通信总线介绍: 通信方法: GPU 编程:通过 MMIO 访问寄存器 CPU 和 GPU 内存请求路由: GPU 可访问的内存区域: GTT/GART 是 CPU 与 GPU 共享的通信缓…

R语言VAR模型的多行业关联与溢出效应可视化分析

全文链接:https://tecdat.cn/?p37397 摘要:本文对医疗卫生、通信、金融、房地产和零售等行业的数据展开深入研究。通过读取数据、计算收益率、构建 VAR 模型并进行估计,帮助客户进一步分析各行业变量的影响及残差的协方差与相关矩阵&#xf…

框架——MyBatis查询(单表查询,多表联查)

目录 1.单表查询 2.多表查询 想查询student并且查询student所选择的专业major ①通过id查一个 ②不传入值直接查所有的学生列表 ③嵌套查询 想查询专业major并且查询该专业被哪些学生student选择 ①通过id查一个 ②不传入值直接查所有的专业列表 ③嵌套查询 3. 设置自动…

IOS半越狱工具nathanlr越狱教程

简介 nathanlr 是一款半越狱工具,不是完整越狱。 半越狱只能使用一些系统范围的插件。 无法做到完整越狱 Dopamine 越狱一样插件兼容性。 nathanlr支持 iOS 16.5.1 – 16.6.1 系统。 支持 A12 及以上设备。 肯定有人问,为什么仅仅支持这些系统&#xff…

嵌入式学习day33

tcp的特点 面向字节流特点,会造成可能数据与数据发送到一块,成为粘包,数据之间不区分 拆包 因为缓冲区的大小,一次性发送的数据会进行拆分(大小不符合的时候) 就和水一样一次拆一次沾到一块&#xff0c…

测试用例的设计

*涉及概念来源于《软件测试的艺术》 目录 一、为什么要设计测试用例? 二、黑盒测试与白盒测试介绍 三、测试用例常见设计方法 1.黑盒测试(功能测试) 2.白盒测试(结构测试) 四、测试策略 五、测试用例怎么写 一、为什么要设计测试用例? 由于时间…

Git 的基本使用

1.创建 Git 本地仓库 仓库是进⾏版本控制的⼀个⽂件⽬录。我们要想对⽂件进⾏版本控制,就必须先创建⼀个仓库出来,例如下面代码创建了gitcode_linux的文件夹,之后再对其进行初始化。创建⼀个 Git 本地仓库对应的命令为 git init &#xff0c…

【注解】反序列化时匹配多个 JSON 属性名 @JsonAlias 详解

JsonAlias 注解是 Jackson 提供的一个功能强大的注解,允许一个字段在反序列化时匹配多个 JSON 属性名。它适用于在处理多种输入数据格式时,或当 JSON 数据的键名可能变化时。 一、JsonAlias 的作用 多种别名:JsonAlias 允许你为一个字段定义…

ZNS SSD是不是持久缓存的理想选择?

随着数据量的增加和技术的进步,对于高效、可靠的存储解决方案的需求日益增长。传统的基于块的SSD虽然具有成本效益和持久性的优点,但在处理写密集型和更新密集型工作负载时存在局限性。 NAND闪存的特点是数据只能按页(例如4KiB)写…

2024年最新最全的【大模型学习路线规划】从零基础入门到精通!

2024年最新最全的大模型学习路线规划,对于零基础入门到精通的学习者来说,可以遵循以下阶段进行: 文章目录 一、基础准备阶段数学基础:编程语言:深度学习基础: 二、核心技术学习阶段Transformer模型&#xf…

[Linux#41][线程] 线程的特性 | 分离线程 | 并发的问题

1.线程的特性 进程和线程的关系如下图: 关于进程线程的问题 • 如何看待之前学习的单进程?具有一个线程执行流的进程 线程 ID 及进程地址空间布局 pthread_ create 函数会产生一个线程 ID,存放在第一个参数指向的地址中。 该线程 ID 和前面说的线程 ID …

动手实现基于Reactor模型的高并发Web服务器(一):epoll+多线程版本

系统流程概览 main函数 对于一个服务器程序来说,因为要为外部的客户端程序提供网络服务,也就是进行数据的读写,这就必然需要一个 socket 文件描述符,只有拥有了文件描述符 C/S 两端才能通过 socket 套接字进行网络通信&#xff0…

4.Redis单线程和多线程

1.Redis的单线程 Redis的单线程主要是指Redis的网络IO和键值对读写是由一个线程完成的,Redis在处理客户端的请求时包括获取(Socket读)、解析、执行、内容返回(Socket写)等都由一个顺序串行的主线程处理,这…

ProxySQL 读写分离配置

ProxySQL 是一个高性能、高可用的 MySQL 代理软件,旨在提升 MySQL 数据库的可扩展性和性能。它可以在应用程序和 MySQL 服务器之间充当中间层,提供强大的路由、负载均衡和查询优化功能。 ProxySQL 的主要功能: 查询路由: ProxySQ…

市盈率的概念

写篇有关市盈率的【不务正业】的内容。 重要公式 市盈率 官方的定义 平均市盈率=∑(收盘价发行数量)/∑(每股收益发行数量),统计时剔除亏损及暂停上市的上市公司。 静态市盈率 滚动市盈率(TTM) 股票市盈率的意义 如果某股票有较…

培训第三十四天(初步了解Docker与套接字的应用)

上午 回顾 1、主从复制(高可用) 2、传统的主从复制 3、gtids事务型的主从复制 4、注意 1)server_id唯一 2)8.X版本需要get_ssl_pub_key 3)5.X不需要 4)change master to 5)stop | sta…

拍抖音在哪里去水印,三招教你快速掌握去水印技巧

在抖音上,我们经常会看到一些精彩的内容,想要保存下来,但往往视频上会有水印。本文将分享五个免费且高效的去除抖音视频水印的技巧,帮助你轻松保存无水印的视频。 技巧一:奈斯水印助手(小程序) 奈斯水印助手是一款专…

JavaScript(30)——解构

数组解构 数组解构是将数组的单元值快速批量赋值给一系列变量的简洁语法 基本语法: 赋值运算符左侧的[]用于批量声明变量,右侧数组的单元值将被赋值给左侧变量变量的顺序对应数组单元值的位置依次进行赋值操作 const arr [1, 2, 3, 4, 5]const [a, b…

云渲染的三个条件是指什么!哪三点最重要!

云渲染技术以其灵活性和效率,让创意人士和企业无论身处何地,都能通过网络接入强大的远程服务器,轻松完成复杂的图形渲染任务,但要发挥其魔力,我们得满足一些关键条件。 一、网络连接:云渲染的桥梁 首先&am…