【如何训练一个中译英翻译器】LSTM机器翻译模型部署之ncnn(python)(四)

ncnn:https://github.com/Tencent/ncnn

1、.h5模型保存为TFSaveModel格式

import tensorflow as tf
from keras.models import load_model# 加载Keras模型
model = load_model('encoder_model.h5')# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')

2、TFSaveModel格式模型保存为onnx模型

python3 -m tf2onnx.convert --saved-model TFSaveModel --output onnxModel/encoder_model.onnx

打开https://netron.app/来看下网络结构,主要是先看输入部分的维度(网络结构后面会细讲哈)
可以看到输入维度[unk__64、unk__65、62],我们需要将unk__64、unk__65这两个改为具体数值,否则在导出ncnn模型时会报一些op不支持的错误,那么问题来了,要怎么改呢,我也不知道啊!!!
哈哈哈,开完笑的,都写出来了,怎么会不知道,请听我慢慢说来。
在这里插入图片描述
其实数据第一个是batch,第二个是输入句子的最大长度,第三个是字符总数量,我们在推理时,batch size一般为1,所以这个input_1的shape就是[1,max_encoder_seq_length, num_encoder_tokens]
max_encoder_seq_length, num_encoder_tokens 这两个参数可以在训练的时候获取到了,拿到这个input shape 之后,对onnx模型进行simplify,我训练出来的模型时得到的shape是[1,16,62],因此执行以下命令:

python3 -m onnxsim onnxModel/encoder_model.onnx onnxModel/encoder_model-sim.onnx --overwrite-input-shape 1,16,62

可得到简化后的onnx模型啦
在这里插入图片描述
这个时候,我们再用https://netron.app打开encoder_model-sim.onnx,可以看到encoder模型的输出了,有两个输出,均为[1,256]的维度
在这里插入图片描述
然后我们需要对decoder_model.h5也进行转换,

import tensorflow as tf
from keras.models import load_model# 加载Keras模型
model = load_model('decoder_model.h5')# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')
python3 -m tf2onnx.convert --saved-model TFSaveModel --output onnxModel/decoder_model.onnx

同样打开模型来看,能看到一共有三个输入,其中的input_3:[unk__57,256],input_4:[unk__58,256],为encoder的输出,因此可以得到这两个输入维度均为[1,256],那 input_2:[unk__55,unk__56,849]的维度是多少呢,我们接着往下看。
在这里插入图片描述
我们想一想,解码器除了接受编码器的数据,还有什么数据没给它呢,没有错,就是target_characters的特征,对于英翻中而言就是中文的字符,要解码器解出中文,肯定要把中文数据给它,要不然你让解码器去解空气嘛,实际上这个 input_2的维度就是

target_seq = np.zeros((1, 1, num_decoder_tokens))

num_decoder_tokens同样可以在训练的时候获取到(至于不知道怎么来的,可以看这个系列文章的第一、二篇),我这边得到的num_decoder_tokens是849,当然实际上这个模型的 input_2:[unk__55,unk__56,849]已经给了num_decoder_tokens,我们只需要把unk__55,unk__56都改为1就可以了,即[1,1,849],那么对onnx进行simplify

python3 -m onnxsim onnxModel/decoder_model.onnx onnxModel/decoder_model-sim.onnx --overwrite-input-shape input_2:1,1,849 input_3:1,256 input_4:1,256

成功完成simplify可得到:
在这里插入图片描述完成了onnx模型的转换之后,我们要做的就是将模型转换为ncnn模型

3、onnx模型转换为ncnn

onnx2ncnn onnxModel/encoder_model-sim.onnx ncnnModel/encoder_model.param ncnnModel/encoder_model.bin
onnx2ncnn onnxModel/decoder_model-sim.onnx ncnnModel/decoder_model.param ncnnModel/decoder_model.bin

转换成功可以看到:
在这里插入图片描述

ncnnoptimize ncnnModel/encoder_model.param ncnnModel/encoder_model.bin ncnnModel/encoder_model.param ncnnModel/encoder_model.bin 1
ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

4、ncnn模型加载与推理(python版)
有点问题,先把调试代码贴在下面吧

import numpy as np
import ncnn# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:input_words = f.readlines()input_characters = [line.rstrip('\n') for line in input_words]# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:target_words = [line.strip() for line in f.readlines()]target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量import json
with open('config/config.json', 'r') as file:loaded_data = json.load(file)# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]encoder_model = ncnn.Net()encoder_model.load_param("ncnnModel/encoder_model.param")
encoder_model.load_model("ncnnModel/encoder_model.bin")decoder_model = ncnn.Net()
decoder_model.load_param("ncnnModel/decoder_model.param")
decoder_model.load_model("ncnnModel/decoder_model.bin")def decode_sequence(input_seq):# Encode the input as state vectors.ex_encoder = encoder_model.create_extractor()ex_encoder.input("input_1", ncnn.Mat(input_seq))_, LSTM_1 = ex_encoder.extract("LSTM__31:1")_, LSTM_2 = ex_encoder.extract("LSTM__31:2")print(LSTM_1)print(LSTM_2)# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, 849))# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''while not stop_condition:ex_decoder = decoder_model.create_extractor()print(ncnn.Mat(target_seq))print("---------")ex_decoder.input("input_2", ncnn.Mat(target_seq))ex_decoder.input("input_3", LSTM_1)ex_decoder.input("input_4", LSTM_2)_, output_tokens = ex_decoder.extract("dense")_, h = ex_decoder.extract("lstm_1")_, c = ex_decoder.extract("lstm_1_1")print(output_tokens)print(h)print(c)print(fdsf)output_tokens = np.array(output_tokens)h = np.array(h)c = np.array(c)print(output_tokens.shape)print(output_tokens.shape)print(h.shape)print(c.shape)#print(gfdgd)#output_tokens, h, c = decoder_model.predict([target_seq] + states_value)# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentenceimport numpy as npinput_text = "Call me."
encoder_input_data = np.zeros((1,max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(input_text):print(char)# 3D vector only z-index has char its value equals 1.0encoder_input_data[0,t, input_token_index[char]] = 1.input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/6990.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux--assert函数在debug和release下的区别

在debug模式下断言才有效,而在release版本下,断言就无效了 在debug模式下,assert函数用于检查条件是否为真,如果条件为假,则会输出相应的错误信息,并停止程序执行。而在release模式下,assert函数…

redis的常用命令和数据结构

目录 redis的基本特征 Redis操作命令行 redis的数据结构 Redis的基本特征 键值型,value支持多种不同的数据结构,功能丰富 单线程,每个命令具备原子性 低延迟,速快(基于内存,IO多路复用,良好…

无涯教程-jQuery - html( val )方法函数

html(val)方法设置每个匹配元素的html内容。此属性在XML文档上不可用。 html( val ) - 语法 selector.html( val ) 这是此方法使用的所有参数的描述- val - 这是要设置的html内容。 html( val ) - 示例 以下是一个简单的示例&#xff0c;简单说明了此方法的用法- <…

【OpenCV】windows环境下,java OpenCV环境搭建,java 也可以实现opencv的功能了!opencv自由了

目录 1. 下载opencv 2. 安装opencv 目录 1. 下载opencv 2. 安装opencv 3. dll文件的导入配置 dll文件的导入&#xff1a; &#xff08;C的类库文件&#xff09;&#xff0c;opencv是c开发的类库&#xff0c;java语言要调用其中的方法&#xff0c;所以依赖了dll文件 3.1…

前端框架学习-Vue(一)

Vue简介 百度百科上关于vue的词条&#xff0c;说vue时一款渐进式JavaScript框架&#xff0c; 简单来说,渐进式是一种设计理念,即在不失去当前功能的前提下,逐步添加新的特性。 说明它时一直在进行维护的。 Vue3&#xff0c;中使用*.vue作为文件后缀&#xff0c;html&#xff0c…

解决安装依赖时报错:npm ERR! code ERESOLVE

系列文章目录 文章目录 系列文章目录前言一、错误原因二、解决方法三、注意事项总结 前言 在使用 npm 安装项目依赖时&#xff0c;有时会遇到错误信息 “npm ERR! code ERESOLVE”&#xff0c;该错误通常发生在依赖版本冲突或者依赖解析问题时。本文将详细介绍出现这个错误的原…

git指定tag只拉取某个release版本代码,节约贷款

采用-b指定tag&#xff0c;--depth1指定只拉取最后一个版本的代码&#xff0c;日志如下 yeqiangyeqiang-MS-7B23:~/Downloads/src$ git clone --depth1 -b 7cbf1a2 https://github.com/llvm/llvm-project 正克隆到 llvm-project... warning: 不能发现要克隆的远程分支 7cbf1a2…

R语言机器学习之影像组学分析的原理详解

概要 影像组学从常规医学图像中高通量提取大量的放射学定量数据&#xff0c;并以非侵入性方式探索它们与临床结果的相关性&#xff0c;在医学研究中得到广泛的应用。 01 影像组学&#xff08;Radiomics&#xff09;的概念&#xff1a; 影像组学&#xff08;Radiomics&#xff…

React Context(上下文)

1 Context Context 通过组件树提供了一个数据传递的方法&#xff0c;从而避免了在每一个层级手动传递props属性。 有部分小伙伴应该使用props属性进行组件上下传值的操作。当多个组件嵌套的时候&#xff0c;就需要慢慢向上寻找最初的值是什么。 2 API React.createContext:…

Jmeter之Beanshell解析并提取json响应

1&#xff1a;前置条件 将fastjson-1.2.49.jar包置于jmeter的lib目录下&#xff0c;并将该jar包添加到测试计划的Library中&#xff1b;否则会报&#xff1a;Typed variable declaration : Class: JSONObject not found in namespace的错误 2&#xff1a;解析思路 利用beansh…

了解Unity编辑器之组件篇UI(一)

UI组件&#xff1a;提供了用户交互&#xff0c;信息展示&#xff0c;用户导航等功能 一、Button&#xff1a;用于响应用户的点击事件 1.Interactable&#xff08;可交互&#xff09;&#xff1a;该属性控制按钮是否可以与用户交互&#xff0c;如果禁用则按钮无法被点击。可以通…

Ubuntu18.04配置PX4开发环境

源文件下载 读者可以参考PX4中文维基百科&#xff0c;或者使用下面命令↓ git clone https://github.com/PX4/PX4-Autopilot.git --recursive 下载完成之后&#xff0c;执行脚本安装命令&#xff0c;PX4给我们提供了脚本安装模式 bash ./PX4-Autopilot/Tools/setup/ubuntu.sh …

Spring Boot-3

学习笔记&#xff08;今天又读了好多篇的博客&#xff0c;做个今天的总结&#xff0c;加油&#xff01;&#xff01;&#xff01;&#xff09; PS&#xff1a;快到中伏了&#xff0c;今天还是好热 使用阿里巴巴 FastJson 的设置 1、jackson 和 fastJson 的对比 有很多人已经…

Linux 网络收包流程

哈喽大家好&#xff0c;我是咸鱼 我们在跟别人网上聊天的时候&#xff0c;有没有想过你发送的信息是怎么传到对方的电脑上的 又或者我们在上网冲浪的时候&#xff0c;有没有想过 HTML 页面是怎么显示在我们的电脑屏幕上的 无论是我们跟别人聊天还是上网冲浪&#xff0c;其实…

Python绘制多条y轴范围不同的曲线并在一张图上显示

如何使用Python绘制多条y轴范围不同的曲线&#xff0c;然后把它们合并在一张图上显示 import matplotlib.pyplot as plt import numpy as npdef multilines(target, x, ys, types, colors, x_label, labels):"""用来绘制多条y轴范围不同的线&#xff0c;并在一…

代码随想录 DAY28 93.复原IP地址 78.子集 90.子集II

93.复原IP地址 切割字符串&#xff0c;并且在每一个切割过的字符串后面加上 ‘ .’ 返回条件&#xff1a;逗点个数3 如果最后一小节符合要求&#xff0c;就将该字符串添加到结果集中 循环中&#xff1a;从start到i 符合要求&#xff0c;就继续添加逗点和字符 不符合下面就不用…

数学建模的六个步骤

一、模型准备 了解问题的实际背景&#xff0c;明确其实际意义&#xff0c;掌握对象的各种信息&#xff0c;以数学思路来解释问题的精髓&#xff0c;数学思路贯彻问题的全过程&#xff0c;进而用数学语言来描述问题。要求符合数学理论&#xff0c;符合数学习惯&#xff0c;清晰…

苹果iOS 16.6 RC发布:或为iPhone X/8系列养老版本

今天苹果向iPhone用户推送了iOS 16.6 RC更新(内部版本号&#xff1a;20G75)&#xff0c;这是时隔两个月的首次更新。 按照惯例RC版基本不会有什么问题&#xff0c;会在最近一段时间内直接变成正式版&#xff0c;向所有用户推送。 需要注意的是&#xff0c;鉴于iOS 17正式版即将…

第1题 好的序列(seq)

一个长为k的序列b1, b2, ..., bk (1 ≤ b1 ≤ b2 ≤ ... ≤ bk ≤ n)&#xff0c;如果对所有的 i (1 ≤ i ≤ k - 1)&#xff0c;满足bi | bi1&#xff0c;那么它就是好的序列。这里a | b表示a是b的因子&#xff0c;或者说a能整除b。 给出n和k&#xff0c;求…

git取消文件或文件夹追踪

1. 创建仓库时&#xff0c;在本地仓库根目录&#xff0c;创建.gitignore文件&#xff0c;写入忽略规则。规则可以是文件名&#xff0c;或者正则表达式。git 对于 .gitignore配置文件是按行从上到下进行规则匹配的。对于.gitignore文件本身的修改也会被提交到远程端。 2. 删除已…