rknn转换后精度差异很大,失真算子自纠

下面是添加了详细注释的优化代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inferencedef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): ONNX 模型。返回:list: 包含所有节点名称的列表。"""return [node.name for node in model.graph.node]def remove_node_and_following(model, node_name):"""删除指定节点及其后续节点,并返回新的模型。参数:model (onnx.ModelProto): 原始 ONNX 模型。node_name (str): 要删除的节点名称。返回:onnx.ModelProto: 修改后的 ONNX 模型。"""nodes_to_keep = []  # 要保留的节点nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点start_removal = False  # 是否开始删除节点output = []  # 输出节点列表for node in model.graph.node:if node.name == node_name:start_removal = Trueif start_removal:nodes_to_remove.add(node.name)else:nodes_to_keep.append(node)output.extend(node.output)for node in model.graph.value_info:if node.name in output:shape = [dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else Nonefor dim in node.type.tensor_type.shape.dim]output_tensor = helper.make_tensor_value_info(node.name,onnx.TensorProto.FLOAT,shape)model.graph.output.append(output_tensor)new_graph = helper.make_graph(nodes_to_keep,model.graph.name,model.graph.input,[output for output in model.graph.output if output.name not in nodes_to_remove],model.graph.initializer,)new_model = helper.make_model(new_graph, producer_name=model.producer_name)new_model = shape_inference.infer_shapes(new_model)return new_modeldef preprocess_image(image_path, target_shape):"""加载并预处理图像。参数:image_path (str): 图像文件路径。target_shape (tuple): 目标形状 (宽, 高)。返回:np.ndarray: 预处理后的图像数组。"""im = cv2.imread(image_path)im = cv2.resize(im, target_shape)im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGBreturn np.ascontiguousarray(im)def main():model_path = 'yolov5s.onnx'model = onnx.load(model_path)dtype_map = {'tensor(float)': np.float32,'tensor(int32)': np.int32,'tensor(int64)': np.int64,}all_node_names = get_all_node_names(model)im = preprocess_image('./sample.jpg', (640, 640))for node_name in all_node_names:if node_name:new_model = remove_node_and_following(model, node_name)new_model_path = 'test.onnx'onnx.save(new_model, new_model_path)sess = rt.InferenceSession(new_model_path)output_names = [output.name for output in sess.get_outputs()]input_data = {input.name: im.reshape(input.shape).astype(dtype_map[input.type])for input in sess.get_inputs()}outputs = sess.run(output_names, input_data)# TODO: 补充 rknn 转换、加载和获取输出的代码# TODO: 补充评估对比两个输出 onnx 的 target 和 rknn 的 predict# TODO: 补充当评估准确率低于 0.8 时终止循环,此时的模型最后一个算子就是 rknn 不正确的算子if __name__ == "__main__":main()

在这里插入图片描述

解释

在补充 TODO 部分的代码之前,我们需要确保你有正确的 RKNN API 环境,并且你知道如何使用它来转换、加载和获取输出。以下是一个示例代码,展示了如何使用 RKNN API 完成这些任务。

1. 安装 RKNN Toolkit

首先,确保你已经安装了 RKNN Toolkit。你可以通过以下命令安装:

pip install rknn-toolkit

2. 补充代码

以下是补充了 TODO 部分的代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inference
from rknn.api import RKNNdef get_all_node_names(model):"""获取模型中所有节点的名称。参数:model (onnx.ModelProto): ONNX 模型。返回:list: 包含所有节点名称的列表。"""return [node.name for node in model.graph.node]def remove_node_and_following(model, node_name):"""删除指定节点及其后续节点,并返回新的模型。参数:model (onnx.ModelProto): 原始 ONNX 模型。node_name (str): 要删除的节点名称。返回:onnx.ModelProto: 修改后的 ONNX 模型。"""nodes_to_keep = []  # 要保留的节点nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点start_removal = False  # 是否开始删除节点output = []  # 输出节点列表for node in model.graph.node:if node.name == node_name:start_removal = Trueif start_removal:nodes_to_remove.add(node.name)else:nodes_to_keep.append(node)output.extend(node.output)for node in model.graph.value_info:if node.name in output:shape = [dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else Nonefor dim in node.type.tensor_type.shape.dim]output_tensor = helper.make_tensor_value_info(node.name,onnx.TensorProto.FLOAT,shape)model.graph.output.append(output_tensor)new_graph = helper.make_graph(nodes_to_keep,model.graph.name,model.graph.input,[output for output in model.graph.output if output.name not in nodes_to_remove],model.graph.initializer,)new_model = helper.make_model(new_graph, producer_name=model.producer_name)new_model = shape_inference.infer_shapes(new_model)return new_modeldef preprocess_image(image_path, target_shape):"""加载并预处理图像。参数:image_path (str): 图像文件路径。target_shape (tuple): 目标形状 (宽, 高)。返回:np.ndarray: 预处理后的图像数组。"""im = cv2.imread(image_path)im = cv2.resize(im, target_shape)im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGBreturn np.ascontiguousarray(im)def convert_onnx_to_rknn(onnx_model_path, rknn_model_path):"""将 ONNX 模型转换为 RKNN 模型。参数:onnx_model_path (str): ONNX 模型路径。rknn_model_path (str): 转换后的 RKNN 模型路径。"""rknn = RKNN()# 加载 ONNX 模型print('--> Loading model')ret = rknn.load_onnx(model=onnx_model_path)if ret != 0:print('Load ONNX model failed!')returnprint('done')# 配置模型print('--> Building model')ret = rknn.build(do_quantization=False)if ret != 0:print('Build RKNN model failed!')returnprint('done')# 导出 RKNN 模型print('--> Export RKNN model')ret = rknn.export_rknn(rknn_model_path)if ret != 0:print('Export RKNN model failed!')returnprint('done')def load_and_run_rknn_model(rknn_model_path, input_data):"""加载 RKNN 模型并运行推理。参数:rknn_model_path (str): RKNN 模型路径。input_data (np.ndarray): 输入数据。返回:list: RKNN 模型的输出结果。"""rknn = RKNN()# 加载 RKNN 模型print('--> Loading RKNN model')ret = rknn.load_rknn(rknn_model_path)if ret != 0:print('Load RKNN model failed!')return []print('done')# 初始化 RKNN 模型print('--> Init runtime environment')ret = rknn.init_runtime()if ret != 0:print('Init runtime environment failed!')return []print('done')# 运行推理print('--> Running model')outputs = rknn.inference(inputs=[input_data])print('done')rknn.release()return outputsdef compare_outputs(onnx_outputs, rknn_outputs, threshold=0.8):"""比较 ONNX 和 RKNN 模型的输出结果。参数:onnx_outputs (list): ONNX 模型的输出结果。rknn_outputs (list): RKNN 模型的输出结果。threshold (float): 准确率阈值。返回:bool: 如果准确率低于阈值,则返回 False,否则返回 True。"""# 计算准确率 (这里假设是简单的相对误差)accuracy = np.mean([np.allclose(onnx_out, rknn_out, rtol=threshold) for onnx_out, rknn_out in zip(onnx_outputs, rknn_outputs)])return accuracy >= thresholddef main():model_path = 'yolov5s.onnx'model = onnx.load(model_path)dtype_map = {'tensor(float)': np.float32,'tensor(int32)': np.int32,'tensor(int64)': np.int64,}all_node_names = get_all_node_names(model)im = preprocess_image('./sample.jpg', (640, 640))for node_name in all_node_names:if node_name:new_model = remove_node_and_following(model, node_name)new_model_path = 'test.onnx'onnx.save(new_model, new_model_path)sess = rt.InferenceSession(new_model_path)output_names = [output.name for output in sess.get_outputs()]input_data = {input.name: im.reshape(input.shape).astype(dtype_map[input.type])for input in sess.get_inputs()}onnx_outputs = sess.run(output_names, input_data)# 转换 ONNX 模型为 RKNN 模型rknn_model_path = 'test.rknn'convert_onnx_to_rknn(new_model_path, rknn_model_path)# 加载并运行 RKNN 模型rknn_outputs = load_and_run_rknn_model(rknn_model_path, im)# 比较 ONNX 和 RKNN 模型的输出结果if not compare_outputs(onnx_outputs, rknn_outputs):print(f'Node {node_name} is the incorrect operator in RKNN model.')breakif __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/31659.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

现代易货交易:重塑物品交换的新纪元

在数字时代的浪潮中,交易模式正在经历一场革命。其中,现代易货交易模式以其独特的魅力,逐渐在市场中崭露头角。这种交易模式不仅是对古老“以物换物”的复兴,更是对物品价值和交换方式的全新定义。 现代易货:物品交换的…

openppp2 命令行接口详解

openppp2 是一个工作在 OSI/3 Layer 网络通信层的虚拟以太网工具链的开源软件,在查阅本文之前,人们可以查阅以下资料。 开源仓库: liulilittle/openppp2: PPP PRIVATE NETWORK™ 2 VPN Next Generation Reliable and Secure Virtual Etherne…

LeetCode 19.删除链表的倒数第N个结点

链接 https://leetcode.cn/problems/remove-nth-node-from-end-of-list/description/ 题目: 给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。 示例1: 输入:head [1,2,3,4,5], n 2 输出:[1,2,3,5…

电动汽车厂商Rivian将全新设计元素融入由虚幻引擎驱动的车机界面

Rivian Automotive(简称:“Rivian”),是美国一家电动汽车厂商,该品牌创办于2009年,总部位于加州埃尔文,专注于生产电动皮卡车Rivian R1T和电动SUV Rivian R1S。 Rivian的车主们正追寻这样一条道…

Qt坐标系统

目录 概述 渲染 逻辑表示 锯齿绘制 坐标转换 模拟时钟示例 Window-Viewport转换 概述 坐标系统由QPainter类控制。与QPaintDevice和QPaintEngine类一起,QPainter构成了Qt绘画系统的基础。QPainter用于执行绘制操作,QPaintDevice是一个二维空间的抽…

番外篇 | YOLOv8算法解析和实战应用:车辆检测 + 车辆追踪 + 行驶速度计算

前言:Hello大家好,我是小哥谈。YOLOv8是ultralytics公司在2023年1月10号开源的,是YOLOv5的下一个重大更新版本,目前支持图像分类、物体检测和实例分割任务,在还没有开源时就收到了用户的广泛关注。它是一个SOTA模型,建立在以前YOLO版本的成功基础上,并引入了新的功能和改…

开发中遇到的错误 - @SpringBootTest 注解爆红

我在使用 SpringBootTest 注解的时候爆红了&#xff0c;ait 回车也导不了包&#xff0c;后面发现是因为没有加依赖&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId>…

【shell脚本速成】函数

文章目录 一、函数1.1、函数介绍1.2、函数定义1.3、函数调用 &#x1f308;你好呀&#xff01;我是 山顶风景独好 &#x1f388;欢迎踏入我的博客世界&#xff0c;能与您在此邂逅&#xff0c;真是缘分使然&#xff01;&#x1f60a; &#x1f338;愿您在此停留的每一刻&#xf…

网络虚拟化考题

vrrp讲过吗&#xff1f;&#xff1f;&#xff1f; d 每一层都是什么设备啊 abcd 为啥流量不可控不可视 c是啥意思 讲过吗 abc aNET网络虚拟化是啥啊 为啥&#xff1f;&#xff1f; 啥是CDN&#xff1f;&#xff1f;&#xff1f;&#xff1f;&#xff1f;

Java数据类型与运算符

1. 变量和类型 变量指的是程序运行时可变的量&#xff0c;相当于开辟一块空间来保存一些数据。 类型则是对变量的种类进行了划分&#xff0c;不同类型的变量具有不同的特性。 1.1 整型变量&#xff08;重点&#xff09; 基本语法格式&#xff1a; int 变量名 初始值;代码示…

舔狗日记Puls微信小程序源码

源码介绍&#xff1a; 这是一款舔狗日记Puls微信小程序源码&#xff0c;提供每日一舔的功能&#xff0c;让你舔到最后&#xff0c;什么都有&#xff01; 源码通过API获取一些舔狗日记&#xff0c;内置了100多句舔狗日记&#xff0c;让你摆脱上班摸鱼的无聊时光&#xff0c; …

TIM: A Time Interval Machine for Audio-Visual Action Recognition

标题&#xff1a;TIM&#xff1a;一种用于视听动作识别的时间间隔机器 源文链接&#xff1a;openaccess.thecvf.com/content/CVPR2024/papers/Chalk_TIM_A_Time_Interval_Machine_for_Audio-Visual_Action_Recognition_CVPR_2024_paper.pdfhttps://openaccess.thecvf.com/cont…

社区项目-项目介绍环境搭建

文章目录 1.技术选型2.原型设计1.安装AxureRP2.进行汉化3.载入元件库4.基本设计 3.元数建模1.安装元数建模软件2.新建项目3.新增一个刷题模块主题域4.新增数据表 subject_category5.新增关系图&#xff0c;将表拖过来6.新增题目标签表7.新增题目信息表8.新增单选表、多选表、判…

​Claude 3.5 最新体验:助力硕博生与科研人员高效完成论文,超越ChatGPT4o !

我是娜姐 迪娜学姐 &#xff0c;一个SCI医学期刊编辑&#xff0c;探索用AI工具提效论文写作和发表。 要不说AI领域的进展真的是日新月异&#xff0c;发展速度已经大大超过预期进度。娜姐本来在准备AI降重工具的测评文章&#xff08;最近好多小伙伴需要&#xff09;。 昨天晚上…

ECharts 词云图案例二:创意蒙版应用

ECharts 词云图案例二&#xff1a;创意蒙版应用 引言 在数据可视化领域&#xff0c;ECharts 以其强大的功能性和灵活性&#xff0c;成为开发者和设计师的首选工具之一。继上一篇关于 ECharts 词云图的详细介绍后&#xff0c;本文将探索词云图的进阶应用——使用蒙版来创造更具…

【C#上位机应用开发实战】—— UI界面设计与实践代码

在C#上位机应用开发中&#xff0c;UI界面设计是至关重要的一环。一个好的UI设计不仅可以提升应用的用户体验&#xff0c;还可以提高应用的易用性和效率。本文将介绍一些UI界面设计的实战经验和技巧。 在这个示例中&#xff0c;我们创建了一个名为MainForm的窗体类。该窗体包含了…

AI在线免费视频工具2:视频配声音;图片说话hedra

1、视频配声音 https://deepmind.google/discover/blog/generating-audio-for-video/ https://www.videotosoundeffects.com/ &#xff08;免费在线使用&#xff09; 2、图片说话在线图片生成播报hedra hedra 上传音频与图片即可合成 https://www.hedra.com/ https://www.…

如何使用Windows备份轻松将数据转移到新电脑?这里有详细步骤

序言 我们都知道那种买了一台新电脑,就想直接上手的感觉。我记得在过去的日子里,要花几个小时传输我的文件,并试图复制我的设置。在当今传输数据的众多方法中,Windows备份提供了一个简单可靠的解决方案。 登录到你的Microsoft帐户 Microsoft在传输过程中使用其云存储来保…

英文字母表

目录 一 设计原型 二 后台源码 一 设计原型 二 后台源码 namespace 英文字母表 {public partial class Form1 : Form{public Form1(){InitializeComponent();}private void Form1_Load(object sender, EventArgs e){foreach (var item in panel1.Controls){if (item ! null)…

A股3000点失守是出局还是机会?

今天的大A失守300点&#xff0c;那么A股3000点失守是出局还是机会&#xff1f; 1、今天两市低开&#xff0c;盘中一度跌破3000点&#xff0c;最低回踩到了2985点&#xff0c;盘面出现了两个罕见现象&#xff0c;意味着即将探底回升。 2、盘面出现两个罕见现象&#xff1a; 一是…