【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

系列文章
【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一)

【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)

【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)

【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 一、事前准备
  • 二、.h5模型保存为TFSaveModel格式样例
  • 三、模型转换
    • 1、encoder_model的转换
      • 1).h5模型保存为TFSaveModel
      • 2)TFSaveModel格式模型保存为onnx模型
      • 3)onnx模型简化
    • 2、decoder_model的转换
      • 1).h5模型保存为TFSaveModel
      • 2)TFSaveModel格式模型保存为onnx模型
      • 3)onnx模型简化
  • 4、onnx模型推理
      • 1)加载模型数据
      • 2)查看模型输入输出信息
      • 3)模型推理搭建
      • 4)模型推理
      • 5)完整代码

一、事前准备

先把要用到的几个工具说一下:

ncnn:https://github.com/Tencent/ncnn
tf2onnx:https://github.com/onnx/tensorflow-onnx
netron:https://netron.app
onnxsim:https://github.com/daquexian/onnx-simplifier
onnxruntime:https://github.com/microsoft/onnxruntime
以上工具的安装与使用后面会抽空补充一下,在这里先记录下,以免忘记了

有了工具之后,我们还需要以下几个文件:
在这里插入图片描述
这几个文件可以在前面的文章【如何训练一个中译英翻译器】LSTM机器翻译模型训练与保存(二)训练一个模型并保存模型得到,最快的方式就是运行文章最后的kaggle notebook,直接得到文件,然后下载下来即可

二、.h5模型保存为TFSaveModel格式样例

要将tf模型转为onnx模型,我们需要先将格式为.h5的tf模型保存为saved_model的格式,先给出样例:

import tensorflow as tf
from keras.models import load_model# 加载Keras模型
model = load_model('encoder_model.h5')# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')

三、模型转换

1、encoder_model的转换

1).h5模型保存为TFSaveModel

import tensorflow as tf
from keras.models import load_model# 加载Keras模型
model = load_model('encoder_model.h5')# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')

2)TFSaveModel格式模型保存为onnx模型

python3 -m tf2onnx.convert --saved-model TFSaveModel --output onnxModel/encoder_model.onnx

3)onnx模型简化

打开https://netron.app/来看下网络结构,主要是先看输入部分的维度(网络结构后面会细讲)
可以看到输入维度:
input_1:[unk__64、unk__65、62]
我们需要将 unk__64、unk__65 这两个改为具体数值,否则在导出ncnn模型时会报一些op不支持的错误,那么问题来了,要怎么改,我也不知道啊!!!
哈哈哈,开完笑的,都写出来了,怎么会不知道,请听我慢慢说来。
在这里插入图片描述[unk__64、unk__65、62]
其实数据第一个unk__64是batch,第二个unk__65是输入句子的最大长度,第三个62是字符总数量,我们在推理时,batch size一般为1,所以这个input_1的shape就是[1,max_encoder_seq_length, num_encoder_tokens](num_encoder_tokens模型已经帮我们填好了)
max_encoder_seq_length, num_encoder_tokens 这两个参数可以在训练的时候获取到了,拿到这个input shape 之后,对onnx模型进行simplify,我训练出来的模型时得到的shape是[1,16,62],因此执行以下命令:

python3 -m onnxsim onnxModel/encoder_model.onnx onnxModel/encoder_model-sim.onnx --overwrite-input-shape 1,16,62

可得到简化后的onnx模型
在这里插入图片描述
这个时候,我们再用https://netron.app打开encoder_model-sim.onnx,可以看到encoder模型的输出了,有两个输出,均为[1,256]的维度
在这里插入图片描述

2、decoder_model的转换

然后我们需要对decoder_model.h5也进行转换,

1).h5模型保存为TFSaveModel

import tensorflow as tf
from keras.models import load_model# 加载Keras模型
model = load_model('decoder_model.h5')# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')

2)TFSaveModel格式模型保存为onnx模型

python3 -m tf2onnx.convert --saved-model TFSaveModel --output onnxModel/decoder_model.onnx

3)onnx模型简化

同样打开模型来看,能看到一共有三个输入:
input_2:[unk__55,unk__56,849]
input_3:[unk__57,256]
input_4:[unk__58,256]
其中,input_3、input_4为encoder的输出,因此可以得到这两个输入维度均为[1,256]
那么,input_2的维度是多少,我们接着往下看。
在这里插入图片描述
我们想一想,解码器除了接受编码器的数据,还有什么数据没给它,没有错,就是target_characters的特征,对于英译中而言就是中文的字符,要解码器解出中文,肯定要把中文数据给它,要不然你让解码器去解空气啊,实际上这个 input_2的维度就是

target_seq = np.zeros((1, 1, num_decoder_tokens))

num_decoder_tokens同样可以在训练的时候获取到(至于不知道怎么来的,可以看这个系列文章的第一、二篇),我这边得到的num_decoder_tokens是849,当然实际上这个模型的 input_2:[unk__55,unk__56,849]已经给了num_decoder_tokens,我们只需要把unk__55,unk__56都改为1就可以了,即[1,1,849],那么对onnx进行simplify

python3 -m onnxsim onnxModel/decoder_model.onnx onnxModel/decoder_model-sim.onnx --overwrite-input-shape input_2:1,1,849 input_3:1,256 input_4:1,256

成功完成simplify可得到:
在这里插入图片描述

4、onnx模型推理

到最后一步了,导出onnx模型后,要试试这个模型怎么样,所以拿过来推理一波,推理代码是从前面文章【如何训练一个中译英翻译器】LSTM机器翻译模型训练与保存(二)的第小6节模型加载与推理里面的代码改过来的,感兴趣的小伙伴可以去看看两者的差异

1)加载模型数据

模型数据的加载主要是加载input_words.txt、target_words.txt、config.json、encoder_model-sim.onnx、decoder_model-sim.onnx 这几个文件

input_words.txt、target_words.txt:为输入输出字符表
config.json:为最长输入长度与最长输出长度
encoder_model-sim.onnx、decoder_model-sim.onnx :为导出的onnx模型

import onnxruntime
import numpy as np
# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:input_words = f.readlines()input_characters = [line.rstrip('\n') for line in input_words]# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:target_words = [line.strip() for line in f.readlines()]target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量import json
with open('config/config.json', 'r') as file:loaded_data = json.load(file)# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]encoderSess = onnxruntime.InferenceSession('onnxModel/encoder_model-sim.onnx')
decoderSess = onnxruntime.InferenceSession('onnxModel/decoder_model-sim.onnx')

2)查看模型输入输出信息

查看输入输出信息主要是为了获取输入名称,在进行模型输入的时候,要先知道模型有哪些输入,维度是多少,才能输入正确的数据


print("----------------- 输入部分 -----------------")
input_tensors = encoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表input_info = {"name" : input_tensor.name,"type" : input_tensor.type,"shape": input_tensor.shape,}print(input_info)print("----------------- 输出部分 -----------------")
output_tensors = encoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表output_info = {"name" : output_tensor.name,"type" : output_tensor.type,"shape": output_tensor.shape,}print(output_info)print("----------------- 输入部分 -----------------")
input_tensors = decoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表input_info = {"name" : input_tensor.name,"type" : input_tensor.type,"shape": input_tensor.shape,}print(input_info)print("----------------- 输出部分 -----------------")
output_tensors = decoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表output_info = {"name" : output_tensor.name,"type" : output_tensor.type,"shape": output_tensor.shape,}print(output_info)

3)模型推理搭建


def decode_sequence(input_seq):# Encode the input as state vectors.states_value = encoderSess.run(None, {'input_1': input_seq})# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''while not stop_condition:output_tokens, h, c = decoderSess.run(None, {'input_2': target_seq, 'input_3': states_value[0], 'input_4': states_value[1]})# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)target_seq[0, 0, sampled_token_index] = 1.# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentenceinput_text = "Call me."
encoder_input_data = np.zeros((1,max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(input_text):# 3D vector only z-index has char its value equals 1.0encoder_input_data[0,t, input_token_index[char]] = 1.

4)模型推理

input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

5)完整代码

import onnxruntime
import numpy as np
# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:input_words = f.readlines()input_characters = [line.rstrip('\n') for line in input_words]# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:target_words = [line.strip() for line in f.readlines()]target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量import json
with open('config/config.json', 'r') as file:loaded_data = json.load(file)# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]encoderSess = onnxruntime.InferenceSession('onnxModel/encoder_model-sim.onnx')
decoderSess = onnxruntime.InferenceSession('onnxModel/decoder_model-sim.onnx')print("----------------- 输入部分 -----------------")
input_tensors = encoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表input_info = {"name" : input_tensor.name,"type" : input_tensor.type,"shape": input_tensor.shape,}print(input_info)print("----------------- 输出部分 -----------------")
output_tensors = encoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表output_info = {"name" : output_tensor.name,"type" : output_tensor.type,"shape": output_tensor.shape,}print(output_info)print("----------------- 输入部分 -----------------")
input_tensors = decoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表input_info = {"name" : input_tensor.name,"type" : input_tensor.type,"shape": input_tensor.shape,}print(input_info)print("----------------- 输出部分 -----------------")
output_tensors = decoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表output_info = {"name" : output_tensor.name,"type" : output_tensor.type,"shape": output_tensor.shape,}print(output_info)def decode_sequence(input_seq):# Encode the input as state vectors.states_value = encoderSess.run(None, {'input_1': input_seq})# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''while not stop_condition:output_tokens, h, c = decoderSess.run(None, {'input_2': target_seq, 'input_3': states_value[0], 'input_4': states_value[1]})# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)target_seq[0, 0, sampled_token_index] = 1.# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentenceinput_text = "Call me."
encoder_input_data = np.zeros((1,max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(input_text):# 3D vector only z-index has char its value equals 1.0encoder_input_data[0,t, input_token_index[char]] = 1.input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

可以看到运行结果:
在这里插入图片描述
代码比较简单,然后也有加一些注释,就不再细讲了,要不然就显得有点啰嗦,有疑问的可以留言,欢迎交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/12730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云原生架构

1. 何为云原生? 很多IT业内小伙伴会经常听到这个名词,那么什么是云原生呢?云原生是在云计算环境中构建、部署和管理现代应用程序的软件方法。 当今时代,众多企业希望构建高度可扩展、灵活且有弹性的应用程序,以便能够快…

13 亿美金买个寂寞?No!AI 时代的数据行业蓄势待发

6月底,全球数据分析领域彻底炸锅了。 两大数据分析企业Databricks和Snowflake纷纷将目光瞄准了AI大模型。要知道,这两位对手平时没少对台戏,为性能、产品和技术经常开撕。但在今年的自家大会上,两家企业却出奇的一致,…

云安全攻防(二)之 云原生安全

云原生安全 什么是云原生安全?云原生安全包含两层含义:面向云原生环境的安全和具有云原生特征的安全 面向云原生环境的安全 面向云原生环境的安全的目标是防护云原生环境中的基础设施、编排系统和微服务系统的安全。这类安全机制不一定会具有云原生的…

Java 设计模式 - 简单工厂模式 - 创建对象的简便之道

简单工厂模式是一种创建型设计模式,它提供了一种简单的方式来创建对象,而无需暴露对象创建的逻辑。在本篇博客中,我们将深入了解简单工厂模式的概念、实现方式以及如何在Java中使用它来创建对象。 为什么使用简单工厂模式? 在软…

【无标题】深圳卫视专访行云创新马洪喜:拥抱AI与云原生,深耕云智一体化创新

人工智能(AI)是引领新一轮科技革命和产业变革的重要驱动力。因此,深圳出台相关行动方案,统筹设立规模1,000亿元的人工智能基金群,引导产业集聚培育企业梯队,积极打造国家新一代人工智能创新发展试验区和国家…

【高压架构】AP5199S LED平均电流型恒流驱动IC 0.01调光 景观舞台汽车灯驱动照明

说明 AP5199S 是一款外围电路简单的多功能平均电流型 LED 恒流驱动器,适用于宽电压范围的非隔离式大功率恒流 LED 驱动领域。芯片 PWM 端口支持超小占空比的 PWM 调光,可响应 60ns 脉宽。为客户提供解决方案,限度地发挥灯具优势,…

shell中按照特定字符分割字符串,并且在切分后的每段内容后加上特定字符(串),然后再用特定字符拼接起来

文件中的内容&#xff0c;可以这么写&#xff1a; awk -F, -v OFS, {for(i1;i<‌NF;i){$i$i"_suffix"}}1 input.txt-F,&#xff1a;设置输入字段分隔符为逗号&#xff08;,&#xff09;&#xff0c;这将使awk按照逗号分割输入文本。-v OFS‘,’&#xff1a;设置输…

【Golang】Golang进阶系列教程--为什么 Go 不支持 []T 转换为 []interface

文章目录 前言官方解释内存布局程序运行中的内存布局通用方法 前言 在 Go 中&#xff0c;如果 interface{} 作为函数参数的话&#xff0c;是可以传任意参数的&#xff0c;然后通过类型断言来转换。 举个例子&#xff1a; package mainimport "fmt"func foo(v inter…

python 面向对象编程的特点 - 封装 - 继承(经典类、新式类) - 多态 - 静态方法、类方法 - 下划线的使用 - 回合制攻击游戏实验

目录 面向对象编程的特点&#xff1a; 封装&#xff1a;封装是将数据和操作&#xff08;方法&#xff09;封装在一个对象中的能力 继承&#xff1a;继承是指一个类&#xff08;子类&#xff09;可以继承另一个类&#xff08;父类&#xff09;的属性和方法。 我们为什么需要继…

HashMap中hash方法的作用(详解)

首先&#xff0c;hash方法用来干什么&#xff1f; 在搞清楚原理之前&#xff0c;我们先站在巨人的肩膀浅浅了解一下hash方法的本质作用。 实质上&#xff0c;它的作用很朴素&#xff0c;就是用key值通过某种方式计算出一个hash码 而且这个hash码我们后面要用来计算key存在底…

golangd\pycharm-ai免费代码助手安装使用gpt4-免费使用--[推荐]

golangd-ai免费代码助手安装使用,pycharm可以使用&#xff0c;估计只要是xx的ide都是可以使用这个插件 目前GPT4以及gpt的大规模使用&#xff0c;如何快速掌握以及在ide中快速使用的办法&#xff0c;今天安装一款golangd编辑器的插件已经使用 一、安装以及使用 1.在golangd中…

贼全! 一举通关的 Spring+SpringBoot+SpringCloud 全攻略, 是真香啊

前几天&#xff0c;有幸从朋友那里得到了一份 Alibaba 内部的墙裂推荐的“玩转 Spring 全家桶的 PDF”&#xff0c;我也不是个吝啬的人&#xff0c;好的东西当然要一起分享。那今天我就秀一把&#xff0c;带你一站通关 Spring、Spring Boot 与 Spring Cloud,让你轻松斩获大厂 O…

Statefulset部署应用

上一部分我们分享到了使用 RS 没有办法让自己管理的多个 pod 都有一个独立的持久化声明&#xff0c;RS 没有办法在指定模板中对不同的 pod 做差异化处理 使用多个 RS 来分别管理自己的的一个 pod&#xff0c;当我们扩缩容的时候&#xff0c;也会出现问题&#xff0c;老的 pod …

C# 关于使用newlife包将webapi接口寄宿于一个控制台程序、winform程序、wpf程序运行

C# 关于使用newlife包将webapi接口寄宿于一个控制台程序、winform程序、wpf程序运行 安装newlife包 Program的Main()函数源码 using ConsoleApp3; using NewLife.Log;var server new NewLife.Http.HttpServer {Port 8080,Log XTrace.Log,SessionLog XTrace.Log }; serv…

【微服务架构设计】微服务不是魔术:处理超时

微服务很重要。它们可以为我们的架构和团队带来一些相当大的胜利&#xff0c;但微服务也有很多成本。随着微服务、无服务器和其他分布式系统架构在行业中变得更加普遍&#xff0c;我们将它们的问题和解决它们的策略内化是至关重要的。在本文中&#xff0c;我们将研究网络边界可…

使用贝叶斯算法完成文档分类问题

贝叶斯原理 贝叶斯原理&#xff08;Bayes theorem&#xff09;是一种用于计算条件概率的数学公式。它是以18世纪英国数学家托马斯贝叶斯&#xff08;Thomas Bayes&#xff09;的名字命名的。贝叶斯原理表达了在已知某个事件发生的情况下&#xff0c;另一个事件发生的概率。具体…

二十三种设计模式第十七篇--迭代子模式

迭代子模式是一种行为型设计模式&#xff0c;它允许你按照特定方式访问一个集合对象的元素&#xff0c;而又不暴露该对象的内部结构。迭代子模式提供了一种统一的方式来遍历容器中的元素&#xff0c;而不需要关心容器的底层实现。 该模式包含以下几个关键角色&#xff1a; 迭…

计算机网络(1) --- 网络介绍

目录 1.介绍协议 基础知识 协议 协议分层 OSI七层模型 2.TCP/IP五层模型 3.网络传输的基本流程 1.基本知识 协议报头 2.局域网通信的基本流程 3.网络传输流程 局域网分类 跨路由器传输 数据包封装和分用 4.网络中的地址管理 1.IP地址 2.MAC地址 3.区别 1.介绍…

训练自己的行文本检测EAST模型

训练自己的行文本检测EAST模型 训练数据格式 训练数据格式

模糊神经网络机械故障诊断(MATLAB代码)

效果 用训练好的模糊神经网络对机械故障进行诊断,根据网络的预测值得到机械的技术状态。预测值小于 1.5 时为正常状态,预测值在 1.5~2.5 之间时为曲轴轴承轻微异响,预测值在 2.5~3.5 之间时为曲轴轴承严重异响预测值在 3.5~4.5 之间时为连杆轴承轻微异响,预测值大于 4.5 时为连…