tf2使用savemodel保存之后转化为onnx适合进行om模型部署

tf2使用savemodel保存之后转化为onnx适合进行om模型部署

  • tf保存为kears框架h5文件
  • 将h5转化为savemodel格式,方便部署
  • 查看模型架构
  • 将savemodel转化为onnx格式
  • 使用netron
  • onnx模型细微处理
  • 代码转化为om以及推理代码,要么使用midstudio

tf保存为kears框架h5文件

前提环境是tf2.2及其版本以上的框架,模型训练结果保存为h5(也就是kears框架)
Pasted image 20240507233042

将h5转化为savemodel格式,方便部署

之后将h5文件转化为savemodel的格式
Pasted image 20240507233120

custom是在保存模型的时候需要的自定义函数,如果没有则不需要添加

保存结果如下
Pasted image 20240507233222

这个地方记得验证一下savemodel格式是否能成功搭载测试代码

import  os  
import pandas as pd  
import numpy as np  
from sklearn.metrics import accuracy_score  
from sklearn.model_selection import train_test_split  
from tensorflow.keras.models import Sequential  
from tensorflow.keras.layers import LSTM,Dense,Dropout  
from keras.utils import to_categorical  
import tensorflow as tf  
from tensorflow.python.keras.layers import Activation  os.chdir('D:/software_project/心电信号分类/')  # 加载 SavedModel 目录  
loaded_model = tf.saved_model.load('tfmodel_save')  # 获取默认的服务签名  
infer = loaded_model.signatures['serving_default']  
print(infer.structured_input_signature)  
print(infer.structured_outputs)  # 加载CSV文件  
file_path = 'data2/shuffled_merged_data.csv'  
data = pd.read_csv(file_path)  
from sklearn.preprocessing import StandardScaler  # 创建StandardScaler实例  
scaler = StandardScaler()  
features = data.iloc[0:1, :-1]  # 获取最后一列作为标签  
labels = data.iloc[0:1, -1]  
features1 = scaler.fit_transform(features)  
# features1 = features1.astype(np.float32)  # # 转化为numpy  
# features = features.to_numpy()  
trainX3 = features1.reshape((features1.shape[0], features1.shape[1], 1))  # # 将数据转换为Tensor  
input_data = tf.convert_to_tensor(trainX3, dtype=tf.float32)  output = infer(conv1d_input=input_data)  
output4=output['dense_3']  
print(output4.numpy())  # 为了确定每个样本的预测标签,我们找到概率最高的类别的索引  
predicted_indices = np.argmax(output4.numpy(), axis=1)  
accuracy = accuracy_score(labels, predicted_indices)  
print(accuracy)  output2=output["dense_8"]  
print(output["dense_8"])  
predicted_indices2= np.argmax(output2.numpy(), axis=1)  
accuracy2 = accuracy_score(labels, predicted_indices2)  
print(accuracy2)  output2_1=output["dense_8_1"]  
print(output["dense_8_1"])  
predicted_indices2= np.argmax(output2_1.numpy(), axis=1)  
accuracy3 = accuracy_score(labels, predicted_indices2)  
print(accuracy3)  print('nihao')  
# 不可用  
# print(output["StatefulPartitionedCall:0"])

查看模型架构

可以使用这个代码查看模型架构,输入输出的名字

 saved_model_cli show --dir D:\software_project\心电信号分类\tfmodel_save --tag_set serve --sig
nature_def serving_default

结构如下
Pasted image 20240507233536

如果可以用咱们继续进行下一步

将savemodel转化为onnx格式

之后将保存的savemodel格式转化为onnx格式

这里直接上大佬博客
在Atlas 200 DK中部署深度学习模型

基本把每个步骤过一遍即可

注意安装tensorflowgpu的版本是很高的
Pasted image 20240507233803

转换指令

python -m tf2onnx.convert --saved-model tensorflow-model-path --output model.onnx

使用netron

把模型放入到netron中
Netron

导出的onnx模型如下
Pasted image 20240507234346

onnx模型细微处理

获得的onnx模型放入netron中进行查看,发现有些未知输出量需要修改
【tensorflow onnx】TensorFlow2导出ONNX及模型可视化教程_tf2onnx-CSDN博客

主要是这种未知量
Pasted image 20240507233928

代码转化为om以及推理代码,要么使用midstudio

之后即可使用代码进行模型的转化为om

转化成功之后,放到atlks200dk板子中进行模型的推理
代码

import numpy as np  
import acllite_utils as utils  
import constants as const  
from acllite_model import AclLiteModel  
from acllite_resource import AclLiteResource  
import time  
import csv  
import numpy as np  class Reasoning(object):  """  class for reasoning    """    def __init__(self, model_path):  self._model_path = model_path  self.device_id = 0  self._model = None  def init(self):  """  Initialize        """  # Load model  self._model = AclLiteModel(self._model_path)  return const.SUCCESS  def inference(self, one_dim_data):  """  model inference        """        return self._model.execute(one_dim_data)  def main():  model_path = 'model_dim_replace.om'  # 打开 CSV 文件  with open('shuffled_merged_data.csv', newline='') as csvfile:  # 创建 CSV 读取器对象  csvreader = csv.reader(csvfile, delimiter=',')  # 跳过第一行(标题行)  next(csvreader)  # 读取第二行数据  second_row = next(csvreader)  # 移除最后一个数据  second_row_without_last = second_row[:-1]  # 将数据转换为 NumPy 数组  np_array = np.array(second_row_without_last, dtype=np.float32)  print(np_array.dtype)  # 输出转换后的 NumPy 数组  acl_resource = AclLiteResource()  acl_resource.init()  reasoning = Reasoning(model_path)  # init  ret = reasoning.init()  utils.check_ret("Reasoning.init ", ret)  start_time = time.time()  # 假设你有一个名为 input_data 的 NumPy 数组,它包含模型的输入数据  input_data = np.array([np_array])  # 替换为你的输入数据  result_class = reasoning.inference(input_data)  end_time = time.time()  execution_time = end_time - start_time  print(result_class)  
if __name__ == '__main__':  main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/9064.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国M2总量是两个美国,意味着什么

中国人民银行公布数据:2月末,我国广义货币(M2)余额299.56万亿元,同比增长8.7%。 2000年末我国M2仅13万亿元,2013年3月达到100万亿元;2020年1月突破200万亿元;2024年2月接近300万亿元, 与美欧日…

CPU的星际穿越——“三维”解析“二维”之谜

文章目录 写在前面为什么三维的CPU能执行二维的指令二维指令是三维机器的抽象而已计算机所有东西都是三维的降维抽象没有软件没有指令二维到三维的总结操作系统的重塑 写在前面 以下是自己关于CPU为何能执行指令的迷惑的抽丝破茧的解答—— 困扰我的一个的问题之CPU的星际穿越…

【Leetcode】八大排序

总述 插入排序:直接插入排序;希尔排序; 选择排序:简单选择排序;堆排序; 交换排序:冒泡排序;快速排序; 归并排序; 桶排序/基数排序; 直接插入排序 …

【软件工程】期末复习超全整理!!!

软件工程期末复习整理 软件工程大纲以及阅读说明用例图用例图例题1 用例文档用例文档例题1用例文档例题2 活动图活动图例题1活动图例题2活动图例题3 类图类图中的关系类图例题1类图例题2 顺序图顺序图例题1顺序图 例题2顺序图例题3顺序图--分析类顺序图例题4顺序图例题5 状态图…

重学java 33.API 4.日期相关类

任何事,必作于细,也必成于实 —— 24.5.9 一、Date日期类 1.Date类的介绍 1.概述: 表示特定的瞬间,精确到亳秒 2.常识: a.1000毫秒 1秒 b.时间原点:1970年1月1日 0时0分0秒(UNIX系统起始时间),叫做格林威治时间,在0时区上 c.时区:北京位于东八区,一个时区…

模拟实现链表的功能

1.什么是链表? 链表是一种物理存储结构上非连续存储结构,数据元素的逻辑顺序是通过链表中的引用链接次序实现的 。 实际中链表的结构非常多样,以下情况组合起来就有8种链表结构: 单向或者双向 带头或者不带头 …

车载测试到底怎么样?真实揭秘!

什么是车载智能系统测试? 车载智能系统,是汽车智能化重要的组成部分,由旧有的车载资通讯系统结合联网汽车技术所演进而来,随着软硬件技术的不断进步, 让车载智能系统拥有强大的运算能力及多元化的应用功能。 车载智能…

苹果iPad M4:Console级别图形和AI强大功能

苹果iPad M4:Console级别图形和AI强大功能 Apple近日发布了最新的M4芯片,旨在为iPad Pro系列带来明显的性能提升和电池续航时间延长。在本篇报道中,我们将详细介绍M4芯片的特点、性能改进和为创意专业人士带来的影响。 M4芯片的强大功能 …

图解项目管理必备十大管理模型及具体应用建议

心智模型是根深蒂固存在于人们心中,影响人们如何理解这个世界(包括我们自己、他人、组织和整个世界),以及如何采取行动的诸多假设、成见、逻辑、规则,甚至图像、印象等。本图通过对心智模型的分类和描述,表…

【Linux】shell基础,shell脚本

Shell Shell是一个用C语言编写的程序,接受用户输入的命令,并将其传递给操作系统内核执行。Shell还负责解释和执行命令、管理文件系统、控制进程,是用户使用Linux的桥梁。Shell既是一种命令语言,又是一种程序设计语言 Shell脚本 Sh…

上位机图像处理和嵌入式模块部署(树莓派4b和c++新版本的问题)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 自己读书的时候是03年,学习c也是差不多04年开始,到现在基本上20年了。这20年过程当中,其实c的语言版本一直是在…

【stomp 实战】spring websocket 接收消息源码分析

后台消息的发送过程,我们通过spring websocket用户消息发送源码分析已经了解了。我们再来分析一下后端接收消息的过程。这个过程和后端发送消息过程有点类似。 前端发送消息 前端发送消息给服务端的示例如下: 发送给目的/app/echo一个消息。 //主动发…

科林算法_3 图

一、图论基础 多对多的关系 定义&#xff1a;G(V,E) Vertex顶点 Edge边 顶点的集合V{v1,v2} 边的结合E{(v1,v2)} 无向图(1,2) 有向图<1,2> 依附&#xff1a;边(v1,v2)依附于顶点v1,v2 路径&#xff1a;&#xff08;v1,v2)(v2,v3) 无权路径最短&#xff1a;边最少…

程序员不会告诉老板的那些神器

目录 1. 持续集成工具&#xff1a;CruiseControl&#xff08;简称CC&#xff09; 2. 代码风格、质量检查工具&#xff1a;StyleCop 3.AI工具 3.1 AI助力编写开发日报 3.2 AI助力编写普适性代码 3.3 AI助力生成代码注释 3.4 AI助力重构代码去掉“坏味道” 3.5 AI助力…

【小白的大模型之路】基础篇:Transformer细节

基础篇&#xff1a;Transformer 引言模型基础架构原论文架构图EmbeddingPostional EncodingMulti-Head AttentionLayerNormEncoderDecoder其他 引言 此文作者本身对transformer有一些基础的了解,此处主要用于记录一些关于transformer模型的细节部分用于进一步理解其具体的实现机…

渗透之sql注入---宽字节注入

目录 宽字节注入原理&#xff1a; 实战&#xff1a; 源码分析&#xff1a; 开始注入&#xff1a; 找注入点&#xff1a; 注入数据库名&#xff1a; 注入表名&#xff1a; 注入列明&#xff1a; 注入具体值&#xff1a;http://sqli-labs:8084/less-32/?id-1%df%27unio…

luceda ipkiss教程 66:金属线的钝角转弯

案例分享&#xff1a;金属线的135度转弯&#xff1a; 所有代码如下&#xff1a; from si_fab import all as pdk import ipkiss3.all as i3 from ipkiss.geometry.shape_modifier import __ShapeModifierAutoOpenClosed__ from numpy import sqrtclass ShapeManhattanStub(__…

《ESP8266通信指南》11-Lua开发环境配置

往期 《ESP8266通信指南》10-MQTT通信&#xff08;Arduino开发&#xff09;-CSDN博客 《ESP8266通信指南》9-TCP通信&#xff08;Arudino开发&#xff09;-CSDN博客 《ESP8266通信指南》8-连接WIFI&#xff08;Arduino开发&#xff09;&#xff08;非常简单&#xff09;-CSD…

短信公司_供应群发短信公司

短信公司——供应群发短信公司 短信公司作为一种为企业提供群发短信服务的服务商&#xff0c;正逐渐受到市场的青睐。供应群发短信公司作为其中的一种类型&#xff0c;为各行各业的企业提供高效、便捷的短信推广渠道。本文将介绍短信公司的作用以及供应群发短信公司的特点和优势…

Django之创建Model以及后台管理

一&#xff0c;创建项目App python manage.py startapp App 二&#xff0c;在App.models.py中创建类&#xff0c;以下是示例 class UserModel(models.Model):uid models.AutoField(primary_keyTrue, auto_createdTrue)name models.CharField(max_length10, uniqueTrue, db…