【如何训练一个中英翻译模型】LSTM机器翻译模型部署之ncnn(python)(五)

系列文章
【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一)
【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 一、事情准备
  • 二、模型转换
  • 三、ncnn模型加载与推理(python版)

一、事情准备

这篇是在【如何训练一个中译英翻译器】LSTM机器翻译模型部署之onnx(python)(四)的基础上进行的,要用到文件为:

input_words.txt
target_words.txt
config.json
encoder_model-sim.onnx
decoder_model-sim.onnx

其中的onnx就是用来转为ncnn模型的,这里借助了onnx这个中间商,所以前面我们需要先通过onnxsim对模型进行simplify,要不然在模型转换时会出现op不支持的情况(模型转换不仅有中间商这个例子,目前还可以通过pnnx直接将pytorch模型转为ncnn,感兴趣的小伙伴可以去折腾下)
老规矩,先给出工具:

onnx2ncnn:https://github.com/Tencent/ncnn
netron:https://netron.app

二、模型转换

这里进行onnx转ncnn,通过命令进行转换

onnx2ncnn onnxModel/encoder_model-sim.onnx ncnnModel/encoder_model.param ncnnModel/encoder_model.bin
onnx2ncnn onnxModel/decoder_model-sim.onnx ncnnModel/decoder_model.param ncnnModel/decoder_model.bin

转换成功可以看到:
在这里插入图片描述
转换之后可以对模型进行优化,但是奇怪的是,这里优化了不起作用,去不了MemoryData这些没用的op

ncnnoptimize ncnnModel/encoder_model.param ncnnModel/encoder_model.bin ncnnModel/encoder_model.param ncnnModel/encoder_model.bin 1
ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

三、ncnn模型加载与推理(python版)

跟onnx的推理比较类似,就是函数的调用方法有点不同,这里先用python实现,验证下是否没问题,方面后面部署到其它端,比如android。
主要包括:模型加载、推理模型搭建跟模型推理,但要注意的是这里的输入输出名称需要在param这个文件里面获取。

采用netron分别查看encoder与decoder的网络结构,获取输入输出名称:

encoder:
输入输出分别如图
在这里插入图片描述
decoder:

输入
在这里插入图片描述
输出:
在这里插入图片描述

推理代码如下,推理过程感觉没问题,但是推理输出结果相差很大(对比过第一层ncnn与onnx的推理结果了),可能问题出在模型转换环节的精度损失上,而且第二层模型转换后网络输出结果不一致了,很迷,还没找出原因,但是以下的推理是能运行通过,只不过输出结果有问题

import numpy as np
import ncnn# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:input_words = f.readlines()input_characters = [line.rstrip('\n') for line in input_words]# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:target_words = [line.strip() for line in f.readlines()]target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量import json
with open('config/config.json', 'r') as file:loaded_data = json.load(file)# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]# Load the ncnn models for the encoder and decoder
encoderNet = ncnn.Net()
encoderNet.load_param("ncnnModel/encoder_model.param")
encoderNet.load_model("ncnnModel/encoder_model.bin")decoderNet = ncnn.Net()
decoderNet.load_param("ncnnModel/decoder_model.param")
decoderNet.load_model("ncnnModel/decoder_model.bin")def decode_sequence(input_seq):# Encode the input as state vectors.# print(input_seq)ex_encoder = encoderNet.create_extractor()ex_encoder.input("input_1", ncnn.Mat(input_seq))states_value = []_, LSTM_1 = ex_encoder.extract("lstm")_, LSTM_2 = ex_encoder.extract("lstm_1")states_value.append(LSTM_1)states_value.append(LSTM_2)# print(ncnn.Mat(input_seq))# print(vgdgd)# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, 849))# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''ex_decoder = decoderNet.create_extractor()while not stop_condition:#print(ncnn.Mat(target_seq))print("---------")ex_decoder.input("input_2", ncnn.Mat(target_seq))ex_decoder.input("input_3", states_value[0])ex_decoder.input("input_4", states_value[1])_, output_tokens = ex_decoder.extract("dense")_, h = ex_decoder.extract("lstm_1")_, c = ex_decoder.extract("lstm_1_1")print(output_tokens)tk = []for i in range(849):tk.append(output_tokens[849*i])tk = np.array(tk)output_tokens = tk.reshape(1,1,849)print(output_tokens)# print(fdgd)print(h)print(c)# output_tokens = np.array(output_tokens)# output_tokens = output_tokens.reshape(1, 1, -1)# # h = np.array(h)# # c = np.array(c)# print(output_tokens.shape)# print(h.shape)# print(c.shape)#output_tokens, h, c = decoder_model.predict([target_seq] + states_value)# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.print(sampled_token_index)# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentenceimport numpy as npinput_text = "Call me."
encoder_input_data = np.zeros((1,max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(input_text):print(char)# 3D vector only z-index has char its value equals 1.0encoder_input_data[0,t, input_token_index[char]] = 1.input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

decoder的模型输出为849*849,感觉怪怪的,然后我们把模型的输入固定下来看看是不是模型的问题。
打开decoder_model.param,把输入层固定下来,0=w 1=h 2=c,那么:
input_2:0=849 1=1 2=1
input_3:0=256 1=1
input_4:0=256 1=1

运行以下命令进行优化

ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

结果如下:
在这里插入图片描述
打开网络来看一下:
可以看到输出确实是849849(红色框),那就是模型转换有问题了
在这里插入图片描述
仔细看,能够看到有两个shape(蓝色框)分别为849跟849
1,这两个不同维度的网络进行BinaryOP之后,就变成849849了,那么,我们把Reshape这个网络去掉试试(不把前面InnerProduct的输入维度有849reshape为8491),下面来看手术刀怎么操作。

我们需要在没经过固定维度并ncnnoptimize的模型上操作(也就是没经过上面0=w 1=h 2=c修改的模型上操作)
根据名字我们找到Reshape那一层:
在这里插入图片描述
然后找到与reshape那一层相连接的上一层(红色框)与下一层(蓝色框)
在这里插入图片描述
通过红色框与蓝色框里面的名字我们找到了上层与下层分别为InnerProduct与BinaryOp
在这里插入图片描述
这时候,把InnerProduct与BinaryOp接上,把Reshape删掉
在这里插入图片描述
再改一下最上面的层数,把19改为18,因为我们删掉了一层
在这里插入图片描述保存之后再次执行

ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

执行后可以看到网络层数跟blob数都更新了
在这里插入图片描述

这时候改一下固定一下输入层数,并运行ncnnoptimize,再打开netron看一下网络结构,可以看到输出维度正常了
在这里插入图片描述
但是通过推理结果还是不对,没找到原因,推理代码如下:

import numpy as np
import ncnn# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:input_words = f.readlines()input_characters = [line.rstrip('\n') for line in input_words]# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:target_words = [line.strip() for line in f.readlines()]target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量import json
with open('config/config.json', 'r') as file:loaded_data = json.load(file)# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]# Load the ncnn models for the encoder and decoder
encoderNet = ncnn.Net()
encoderNet.load_param("ncnnModel/encoder_model.param")
encoderNet.load_model("ncnnModel/encoder_model.bin")decoderNet = ncnn.Net()
decoderNet.load_param("ncnnModel/decoder_model.param")
decoderNet.load_model("ncnnModel/decoder_model.bin")def decode_sequence(input_seq):# Encode the input as state vectors.# print(input_seq)ex_encoder = encoderNet.create_extractor()ex_encoder.input("input_1", ncnn.Mat(input_seq))states_value = []_, LSTM_1 = ex_encoder.extract("lstm")_, LSTM_2 = ex_encoder.extract("lstm_1")states_value.append(LSTM_1)states_value.append(LSTM_2)# print(ncnn.Mat(input_seq))# print(vgdgd)# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, 849))# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''ex_decoder = decoderNet.create_extractor()while not stop_condition:#print(ncnn.Mat(target_seq))print("---------")ex_decoder.input("input_2", ncnn.Mat(target_seq))ex_decoder.input("input_3", states_value[0])ex_decoder.input("input_4", states_value[1])_, output_tokens = ex_decoder.extract("dense")_, h = ex_decoder.extract("lstm_1")_, c = ex_decoder.extract("lstm_1_1")print(output_tokens)# print(ghfhf)# tk = []# for i in range(849):#     tk.append(output_tokens[849*i])# tk = np.array(tk)# output_tokens = tk.reshape(1,1,849)# print(output_tokens)# print(fdgd)print(h)print(c)output_tokens = np.array(output_tokens)output_tokens = output_tokens.reshape(1, 1, -1)# # h = np.array(h)# # c = np.array(c)# print(output_tokens.shape)# print(h.shape)# print(c.shape)#output_tokens, h, c = decoder_model.predict([target_seq] + states_value)# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.print(sampled_token_index)# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentenceimport numpy as npinput_text = "Call me."
encoder_input_data = np.zeros((1,max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(input_text):print(char)# 3D vector only z-index has char its value equals 1.0encoder_input_data[0,t, input_token_index[char]] = 1.input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

参考文献:https://github.com/Tencent/ncnn/issues/2586

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/13969.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

flask的配置项

flask的配置项 为了使 Flask 应用程序正常运行,有多种配置选项需要考虑。下面是一些基本的 Flask 配置选项: DEBUG: 这个配置项决定 Flask 是否应该在调试模式下运行。如果这个值被设为 True,Flask 将会提供更详细的错误信息,并…

go 查询采购单设备事项V3

一、版本说明 本版本在整合上两次的功能基础上,引进ini配置文件的读取事项,快速读取本地配置文件,完成读取设置 第一版:实现了严格匹配模式的查找 https://blog.csdn.net/wtt234/article/details/131979385 第二版:实…

作为程序员,你很有必要了解一下IVX

一、IVX是什么 iVX 是一个“零代码”的可视化编程平台,拥有方便的在线集成开发环境,不需要下载开发环境,打开浏览器即可随时随地进行项目编辑。iVX 还拥有“一站式”的云资源,通过这一套一站式服务,iVX 可以实现一站式…

详解STM32的GPIO八种输入输出模式,GPIO各种输入输出的区别、初始化的步骤详解,看这文章就行了(超详细)

在STM32微控制器中,常见的输入输出(GPIO)模式有八种,分别是推挽输出、开漏输出、复用推挽输出、复用开漏输出、浮空输入、上拉输入、下拉输入和模拟输入。下面我将为你解释每种模式的特点和区别,并提供相应的示例代码。 文章目录 介绍区别初…

MySql002——关系型数据库基础知识

前言:因为本专栏学习的是MySQL,而MySQL是关系型数据库,所以这篇文章就来介绍一些关系型数据库基础知识,至于其他知识小伙伴们可以自行学习,同时不足之处也欢迎批评指正,谢谢~ 一、MySQL关系型数据库(RDBMS)…

从实践彻底掌握MySQL的主从复制

目录 一、本次所用结构如图---一主多从级联: 二、IP。 三、配置M1: 四、从库M1S1: 五、从库M2配置: 六、 从库M2S1: 一、本次所用结构如图--- 一主多从级联: 二、IP。这里M1S1和M1S2一样的&#xff0…

图技术在 LLM 下的应用:知识图谱驱动的大语言模型 Llama Index

LLM 如火如荼地发展了大半年,各类大模型和相关框架也逐步成型,可被大家应用到业务实际中。在这个过程中,我们可能会遇到一类问题是:现有的哪些数据,如何更好地与 LLM 对接上。像是大家都在用的知识图谱,现在…

查看maven发布时间的方法

查看maven发布时间的方法如下【 打开maven官网 选中Release Notes 即可查看对应版本的发布时间 】

【计算机网络】第 4 课 - 物理层

欢迎来到博主 Apeiron 的博客,祝您旅程愉快 ! 时止则止,时行则行。动静不失其时,其道光明。 目录 1、物理层的基本概念 2、物理层协议的主要任务 3、物理层任务 4、总结 1、物理层的基本概念 在计算机网络中,用来…

基于多场景的考虑虑热网网损的太阳能消纳能力评估研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

windows切换php版本以及composer

前提 安装php8.2 安装Php7.4 下载 nts是非线程安全的,这里选择线程安全的,选择64位 解压缩 修改系统环境变量 修改为php-7的 cmd中输入php -v查看 找到composer存放路径C:\ProgramData\ComposerSetup\bin 将三个文件复制到php目录下 重启电脑…

2023年深圳杯数学建模赛题浅析

由于今明两日由于一些不可避免的事情,这里仅仅先给大家简单写一个赛题浅析,详细过程步骤思路以及讲解视频预计后天发布 A题 影响城市居民身体健康的因素分析 A题以慢性病为命题背景,给出数据以及题目初步来看来看为一个数据处理数据分析的综…

oracle12c静默安装

目录 前言 安装配置步骤 关闭防火墙,禁止防火墙开机自启 关闭selinux yum安装必要安装包 内网环境下载依赖包的方式 创建用户和组 创建oinstall和dba组 创建oracle用户 设置oracle密码 查看创建结果 修改内核参数 使配置生效 修改用户及文件限制 改文件限制 修改用…

PHP使用Redis实战实录1:宝塔环境搭建、6379端口配置、Redis服务启动失败解决方案

宝塔环境搭建、6379端口配置、Redis服务启动失败解决方案 前言一、Redis安装部署1.安装Redis2.php安装Redis扩展3.启动Redis 二、避坑指南1.6379端口配置2.Redis服务启动(1)Redis服务启动失败(2)Redis启动日志排查(3&a…

《向量数据库指南》:向量数据库Pinecone如何集成LangChain (一)

目录 LangChain中的检索增强 建立知识库 欢迎使用Pinecone和LangChain的集成指南。本文档涵盖了将高性能向量数据库Pinecone与基于大型语言模型(LLMs)构建应用程序的框架LangChain集成的步骤。 Pinecone使开发人员能够基于向量相似性搜索构建可扩展的实时推荐和搜索系统…

CentOS 项目发出一篇奇怪的博文

导读最近,在红帽限制其 RHEL 源代码的访问之后,整个社区围绕这件事发生了很多事情。 CentOS 项目发出一篇奇怪的博文 周五,CentOS 项目董事会发出了一篇模糊不清的简短博文,文中称,“发展社区并让人们更容易做出贡献…

Vue学习Day3——生命周期\组件化

一、Vue生命周期 Vue生命周期:就是一个Vue实例从创建 到 销毁 的整个过程。 生命周期四个阶段:① 创建 ② 挂载 ③ 更新 ④ 销毁 1.创建阶段:创建响应式数据 2.挂载阶段:渲染模板 3.更新阶段:修改数据,更…

【计算机网络】11、网桥(bridge)、集线器(hub)、交换机(switch)、路由器(router)、网关(gateway)

文章目录 一、网桥(bridge)二、集线器(hub)三、交换机(switch)四、路由器(router)五、网关(gateway) 对于hub,一个包过来后,直接将包转发到其他口。 对于桥&…

区块链 2.0笔记

区块链 2.0 以太坊概述 相对于比特币的几点改进 缩短出块时间至10多秒ghost共识机制mining puzzle BTC:计算密集型ETH:memory-hard(限制ASIC) proof of work->proof of stake对智能合约的支持 BTC:decentralized currencyETH:decentral…

怎么查看gcc的安装路径

2023年7月29日 很简单,通过在命令行输入如下命令就可以了: gcc -print-search-dirs在Windows中 在Linux中 ​​​