【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署

【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署

提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署
  • 前言
  • Windows平台搭建依赖环境
  • 模型转换--pytorch转onnx
  • ONNXRuntime推理代码
  • 总结


前言

本期将讲解深度学习图像分类网络VggNet模型的部署,对于该算法的基础知识,可以参考博主【VggNet模型算法Pytorch版本详解】博文。
读者可以通过学习 【onnx部署】部署系列学习文章目录的onnxruntime系统学习–Python篇 的内容,系统的学习OnnxRuntime部署不同任务的onnx模型。


Windows平台搭建依赖环境

在【入门基础篇】中详细的介绍了onnxruntime环境的搭建以及ONNXRuntime推理核心流程代码,不再重复赘述。


模型转换–pytorch转onnx

import torch
import torchvision as tv
def resnet2onnx():# 使用torch提供的预训练权重 1000分类model = tv.models.vgg16(pretrained=True)model.eval()model.cpu()dummy_input1 = torch.randn(1, 3, 224, 224)torch.onnx.export(model, (dummy_input1), "vgg16.onnx", verbose=True, opset_version=11)
if __name__ == "__main__":resnet2onnx()


如下图,torchvision本身提供了不少经典的网络,为了减少教学复杂度,这里博主直接使用了torchvision提供的ResNet网络,并下载和加载了它提供的训练权重。这里可以替换成自己的搭建的ResNet网络以及自己训练的训练权重。


ONNXRuntime推理代码

需要配置imagenet_classes.txt【百度云下载,提取码:rkz7 】文件存储1000类分类标签,假设是用户自定的分类任务,需要根据实际情况作出修改,并将其放置到工程目录下(推荐)。

这里需要将vgg16.onnx放置到工程目录下(推荐),并且将以下推理代码拷贝到新建的py文件中,并执行查看结果。

import onnxruntime as ort
import cv2
import numpy as np# 加载标签文件获得分类标签
def read_class_names(file_path="./imagenet_classes.txt"):class_names = []try:with open(file_path, 'r') as fp:for line in fp:name = line.strip()if name:class_names.append(name)except IOError:print("could not open file...")import syssys.exit(-1)return class_names# 主函数
def main():# 预测的目标标签数labels = read_class_names()# 测试图片image_path = "./lion.jpg"image = cv2.imread(image_path)# cv2.imshow("输入图", image)# cv2.waitKey(0)# 设置会话选项sess_options = ort.SessionOptions()# 0=VERBOSE, 1=INFO, 2=WARN, 3=ERROR, 4=FATALsess_options.log_severity_level = 3# 优化器级别:基本的图优化级别sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC# 线程数:4sess_options.intra_op_num_threads = 4# 设备使用优先使用GPU而是才是CPU,列表中的顺序决定了执行提供者的优先级providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']# onnx训练模型文件onnxpath = "./vgg16.onnx"# 加载模型并创建会话session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)input_nodes_num = len(session.get_inputs())     # 输入节点输output_nodes_num = len(session.get_outputs())   # 输出节点数input_node_names = []                           # 输入节点名称output_node_names = []                          # 输出节点名称# 获取模型输入信息for i in range(input_nodes_num):# 获得输入节点的名称并存储input_name = session.get_inputs()[i].nameinput_node_names.append(input_name)# 显示输入图像的形状input_shape = session.get_inputs()[i].shapech, input_h, input_w = input_shape[1], input_shape[2], input_shape[3]print(f"input format: {ch}x{input_h}x{input_w}")# 获取模型输出信息for i in range(output_nodes_num):# 获得输出节点的名称并存储output_name = session.get_outputs()[i].nameoutput_node_names.append(output_name)# 显示输出结果的形状output_shape = session.get_outputs()[i].shapenum, nc = output_shape[0], output_shape[1]print(f"output format: {num}x{nc}")input_shape = session.get_inputs()[0].shapeinput_h, input_w = input_shape[2], input_shape[3]print(f"input format: {input_shape[1]}x{input_h}x{input_w}")# 预处理输入数据# 默认是BGR需要转化成RGBrgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# 对图像尺寸进行缩放blob = cv2.resize(rgb, (input_w, input_h))blob = blob.astype(np.float32)# 对图像进行标准化处理blob /= 255.0   # 归一化blob -= np.array([0.485, 0.456, 0.406])  # 减去均值blob /= np.array([0.229, 0.224, 0.225])  # 除以方差#CHW-->NCHW 维度扩展timg = cv2.dnn.blobFromImage(blob)# ---blobFromImage 可以用以下替换---# blob = blob.transpose(2, 0, 1)# blob = np.expand_dims(blob, axis=0)# -------------------------------# 模型推理try:ort_outputs = session.run(output_names=output_node_names, input_feed={input_node_names[0]: timg})except Exception as e:print(e)ort_outputs = None# 后处理推理结果prob = ort_outputs[0]max_index = np.argmax(prob)     # 获得最大值的索引print(f"label id: {max_index}")# 在测试图像上加上预测的分类标签label_text = labels[max_index]cv2.putText(image, label_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)cv2.imshow("输入图像", image)cv2.waitKey(0)if __name__ == '__main__':main()

图片预测为猎豹(cheetah),没有准确预测出狮子(lion),但是这个图片难度很大,在1000分类中预测的比较接近的。

其实图像分类网络的部署代码基本是一致的,几乎不需要修改,只需要修改传入的图片数据已经训练模型权重即可。


总结

尽可能简单、详细的讲解了Python下onnxruntime环境部署VggNet模型的过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53791.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

走进低代码表单开发(一):可视化表单数据源设计

在前文,我们已对勤研低代码平台的报表功能做了详细介绍。接下来,让我们深入探究低代码开发中最为常用的表单设计功能。一个完整的应用是由众多表单组合而成的,所以高效的表单设计在开发过程中起着至关重要的作用。让我们一同了解勤研低代码开…

[网络]http/https的简单认识

文章目录 一. 什么是http二. http协议工作过程三. http协议格式1. 抓包工具fiddler2. http请求报文3. http响应报文 一. 什么是http HTTP (全称为 “超⽂本传输协议”) 是⼀种应⽤⾮常⼴泛的 应⽤层协议 HTTP 诞⽣与1991年. ⽬前已经发展为最主流使⽤的⼀种应⽤层协议 HTTP 往…

FPGA实现串口升级及MultiBoot(四)MultiBoot简介

缩略词索引: K7:Kintex 7V7:Vertex 7A7:Artix 7 我们在正常升级的过程(只使用一个位流文件),假如:(1)因为干扰通信模块收到了一个错误位;(2)或者烧写进FLASH时…

《深度学习》—— 神经网络模型中的损失函数及正则化惩罚和梯度下降

文章目录 前言一、损失函数二、正则化惩罚三、梯度下降 前言 在神经网络中,损失函数、正则化惩罚和梯度下降是三个关键的概念,它们共同作用于网络的训练过程,以提升网络的性能和泛化能力。神经网络模型结构如下图所示: 在构建好一…

LCSS—最长回文子序列

思路分析 关于”回文串“的问题,是面试中常见的,本文提升难度,讲一讲”最长回文子序列“问题,题目很好理解: 输入一个字符串 s,请找出 s 中的最长回文子序列长度。 比如输入 s"aecda"&#xff0c…

【AI-18】Adam和SGD优化算法比较

Adam(Adaptive Moment Estimation)和 SGD(Stochastic Gradient Descent,随机梯度下降)是两种常见的优化算法,它们在不同方面有各自的特点。 一、算法原理 SGD: 通过计算损失函数关于每个样本的…

S7-1500T分布式同步功能

1. 功能描述工控人加入PLC工业自动化精英社群 在一些实际应用中,会需要很多轴进行同步运行,如印刷机、纸尿裤生产线等。由于一个 PLC 的运动控制资源有限,控制轴的数量也是有限的,就会需要多个 PLC 间协调实现轴工艺对象的跨CPU的…

k8s以及prometheus

#生成控制器文件并建立控制器 [rootk8s-master ~]# kubectl create deployment bwmis --image timinglee/myapp:v1 --replicas 2 --dry-runclient -o yaml > bwmis.yaml [rootk8s-master ~]# kubectl expose deployment bwmis --port 80 --target-port 80 --dry-runclient…

专题三_二分查找算法_算法详细总结

目录 二分查找 1.⼆分查找(easy) 1)朴素二分查找,就是设mid(leftright)/2,xnums[mid],t就是我们要找的值 2)二分查找就是要求保证数组有序的前提下才能进行。 3)细节问题: 总结&#xff1a…

基于SpringBoot+Vue+MySQL的招聘管理系统

系统展示 用户前台界面 管理员后台界面 企业后台界面 系统背景 在当今数字化转型的大潮中,企业对于高效、智能化的人力资源管理系统的需求日益增长。招聘作为人力资源管理的首要环节,其效率与效果直接影响到企业的人才储备与竞争力。因此,构建…

详解Diffusion扩散模型:理论、架构与实现

本文深入探讨了Diffusion扩散模型的概念、架构设计与算法实现,详细解析了模型的前向与逆向过程、编码器与解码器的设计、网络结构与训练过程,结合PyTorch代码示例,提供全面的技术指导。 关注TechLead,复旦AI博士,分享A…

宠物毛发对人体有什么危害?宠物空气净化器小米、希喂、352对比实测

作为一个呼吸科医生,我自己也养猫。软软糯糯的小猫咪谁不爱啊,在养猫的过程中除了欢乐外,也面临着一系列的麻烦,比如要忍耐猫猫拉粑粑臭、掉毛、容易带来细菌等等的问题。然而我发现,现在许多年轻人光顾着养猫快乐了&a…

Linux命令:用于应用补丁文件来更新源代码的工具patch详解

目录 一、概述 二、基本概念 1. 补丁文件 2. diff 工具 三、基本用法 1、基本语法 2、常用选项 3、获取帮助 四、patch 工具的主要功能 1. 应用补丁 2. 逆向应用补丁 3. 查看补丁内容 4. 交互模式 5. 非交互模式 6. 备份文件 五、patch基本用法举例 1、应用补…

动态规划:汉诺塔问题|循环汉诺塔

目录 1. 汉诺塔游戏简介 2.算法原理 3.循环汉诺塔 1. 汉诺塔游戏简介 汉诺塔游戏是一个经典的数学智力游戏,其目标是将塔上不同大小的圆盘全部移动到另一个塔上,且在移动过程中必须遵守以下规则: 每次只能移动一个圆盘较大的圆盘不能放在…

css百分比布局中height:100%不起作用

百分比布局时,我们有时候会遇到给高度 height 设置百分比后无效的情况,而宽度设置百分比却是正常的。 当为一个元素的高度设定为百分比高度时,是相对于父元素的高度来计算的。当没有给父元素设置高度(height)时或设置…

杂七杂八-系统环境安装

杂七杂八-系统&环境安装 1. 系统安装2. 环境安装 仅个人笔记使用,后续会根据自己遇到问题记录,感谢点赞关注 1. 系统安装 Windows安装linux子系统WSL2:使用windows系统跑linux程序(大模型)WSL VSCode:VSCode连接WSL实现高效…

就服务器而言,ARM架构与X86架构有什么区别?各自的优势在哪里?

一、服务器架构概述 在数字化时代,服务器架构至关重要。服务器是网络核心节点,存储、处理和提供数据与服务,是企业和组织信息化、数字化的关键基础设施。ARM 和 x86 架构为服务器领域两大主要架构,x86 架构服务器在市场占主导&…

学习之git的团队协作

git团队协作 一 团队内协作 生成SSH公钥私钥 一(跨团队协作)

jmeter之仅一次控制器

仅一次控制器作用: 不管线程组设置多少次循环,它下面的组件都只会执行一次 Tips:很多情况下需要登录才能访问其他接口,比如:商品列表、添加商品到购物车、购物车列表等,在多场景下,登录只需要…

【GBase 8c V5_3.0.0 分布式数据库常用维护命令】

一、查看数据库状态/检查(gbase用户) 1.gha_ctl monitor 使用gha_ctl monitor查看节点运行情况(跟dcs的地址和端口) gha_ctl monitor -c gbase -l http://172.20.10.8:2379 -Hall |coordinator | datanode | gtm | server|dcs:必选字段。指定查看哪类集…