【如何训练一个中译英翻译器】LSTM机器翻译模型部署之ncnn(python)(五)

系列文章
【如何训练一个中译英翻译器】LSTM机器翻译seq2seq字符编码(一)
【如何训练一个中译英翻译器】LSTM机器翻译模型训练与保存(二)
【如何训练一个中译英翻译器】LSTM机器翻译模型部署(三)
【如何训练一个中译英翻译器】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 一、事情准备
  • 二、模型转换
  • 三、ncnn模型加载与推理(python版)

一、事情准备

这篇是在【如何训练一个中译英翻译器】LSTM机器翻译模型部署之onnx(python)(四)的基础上进行的,要用到文件为:

input_words.txt
target_words.txt
config.json
encoder_model-sim.onnx
decoder_model-sim.onnx

其中的onnx就是用来转为ncnn模型的,这里借助了onnx这个中间商,所以前面我们需要先通过onnxsim对模型进行simplify,要不然在模型转换时会出现op不支持的情况(模型转换不仅有中间商这个例子,目前还可以通过pnnx直接将pytorch模型转为ncnn,感兴趣的小伙伴可以去折腾下)
老规矩,先给出工具:

onnx2ncnn:https://github.com/Tencent/ncnn
netron:https://netron.app

二、模型转换

这里进行onnx转ncnn,通过命令进行转换

onnx2ncnn onnxModel/encoder_model-sim.onnx ncnnModel/encoder_model.param ncnnModel/encoder_model.bin
onnx2ncnn onnxModel/decoder_model-sim.onnx ncnnModel/decoder_model.param ncnnModel/decoder_model.bin

转换成功可以看到:
在这里插入图片描述
转换之后可以对模型进行优化,但是奇怪的是,这里优化了不起作用,去不了MemoryData这些没用的op

ncnnoptimize ncnnModel/encoder_model.param ncnnModel/encoder_model.bin ncnnModel/encoder_model.param ncnnModel/encoder_model.bin 1
ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

三、ncnn模型加载与推理(python版)

跟onnx的推理比较类似,就是函数的调用方法有点不同,这里先用python实现,验证下是否没问题,方面后面部署到其它端,比如android。
主要包括:模型加载、推理模型搭建跟模型推理,但要注意的是这里的输入输出名称需要在param这个文件里面获取。

采用netron分别查看encoder与decoder的网络结构,获取输入输出名称

有点问题,先把调试代码贴在下面了

import numpy as np
import ncnn# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:input_words = f.readlines()input_characters = [line.rstrip('\n') for line in input_words]# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:target_words = [line.strip() for line in f.readlines()]target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量import json
with open('config/config.json', 'r') as file:loaded_data = json.load(file)# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]encoder_model = ncnn.Net()encoder_model.load_param("ncnnModel/encoder_model.param")
encoder_model.load_model("ncnnModel/encoder_model.bin")decoder_model = ncnn.Net()
decoder_model.load_param("ncnnModel/decoder_model.param")
decoder_model.load_model("ncnnModel/decoder_model.bin")def decode_sequence(input_seq):# Encode the input as state vectors.ex_encoder = encoder_model.create_extractor()ex_encoder.input("input_1", ncnn.Mat(input_seq))_, LSTM_1 = ex_encoder.extract("LSTM__31:1")_, LSTM_2 = ex_encoder.extract("LSTM__31:2")print(LSTM_1)print(LSTM_2)# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, 849))# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''while not stop_condition:ex_decoder = decoder_model.create_extractor()print(ncnn.Mat(target_seq))print("---------")ex_decoder.input("input_2", ncnn.Mat(target_seq))ex_decoder.input("input_3", LSTM_1)ex_decoder.input("input_4", LSTM_2)_, output_tokens = ex_decoder.extract("dense")_, h = ex_decoder.extract("lstm_1")_, c = ex_decoder.extract("lstm_1_1")print(output_tokens)print(h)print(c)print(fdsf)output_tokens = np.array(output_tokens)h = np.array(h)c = np.array(c)print(output_tokens.shape)print(output_tokens.shape)print(h.shape)print(c.shape)#print(gfdgd)#output_tokens, h, c = decoder_model.predict([target_seq] + states_value)# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentenceimport numpy as npinput_text = "Call me."
encoder_input_data = np.zeros((1,max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(input_text):print(char)# 3D vector only z-index has char its value equals 1.0encoder_input_data[0,t, input_token_index[char]] = 1.input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/7991.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【优选算法题练习】day8

文章目录 一、974. 和可被 K 整除的子数组1.题目简介2.解题思路3.代码4.运行结果 二、525. 连续数组1.题目简介2.解题思路3.代码4.运行结果 三、560. 和为 K 的子数组1.题目简介2.解题思路3.代码4.运行结果 总结 一、974. 和可被 K 整除的子数组 1.题目简介 974. 和可被 K 整…

【设计模式】单例设计模式详解(包含并发、JVM)

文章目录 1、背景2、单例模式3、代码实现1、第一种实现(饿汉式)为什么属性都是static的?2、第二种实现(懒汉式,线程不安全)3、第三种实现(懒汉式,线程安全)4、第四种实现…

day38-Mobile Tab Navigation(手机tab栏导航切换)

50 天学习 50 个项目 - HTMLCSS and JavaScript day38-Mobile Tab Navigation&#xff08;手机tab栏导航切换&#xff09; 效果 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"…

3ds MAX 洗菜池

在家居中我们显然离不开这个对吧 首先绘制一个长方体作为基础 注意设置长宽高的网格大小&#xff0c;方便后续调整 俯视图网格线如下&#xff1a; 长方形变换为可编辑网络&#xff0c;并在【多边形】界面选择底面的所有多边形&#xff0c;按delete删除&#xff0c;形成一个壳体…

Github上方导航栏介绍

Code Watch&#xff1a;相当于关注&#xff0c;到时候这个项目又有什么操作&#xff0c;就会以通知的形式提醒你。 Fork&#xff1a;也就是把这个项目拉到你的仓库里&#xff0c;之后你可以对该代码进行修改&#xff0c;之后你可以发起Pull Request&#xff0c;简称PR&#xf…

vulnhub靶场之CengBox3

1.信息收集 输入命令&#xff1a;netdiscover -i eth0 -r 192.168.239.0 &#xff0c;发现181机器存活 输入命令nmap -p- -sV -O -Pn -A 192.168.239.181 &#xff0c;进行端口探测&#xff0c;发现存在22、80、443端口&#xff0c;还发现存在域名ceng-company.vm。 将域名c…

了解Unity编辑器之组件篇Tilemap(五)

Tilemap&#xff1a;用于创建和编辑2D网格地图的工具。Tilemap的主要作用是简化2D游戏中地图的创建、编辑和渲染过程。以下是一些Tilemap的主要用途&#xff1a; 2D地图绘制&#xff1a;Tilemap提供了一个可视化的编辑器界面&#xff0c;可以快速绘制2D地图&#xff0c;例如迷…

python实现逻辑回归-清风数学建模-二分类水果数据

所用数据 &#x1f449;&#x1f449;&#x1f449;二分类水果数据 1.数据预处理 可以看到有4个特征&#xff0c;2种分类结果&#xff0c;最后4个没有分类结果的数据是拿来预测的 # 1. 数据预处理 import pandas as pd df pd.read_excel(oridata/二分类水果数据.xlsx,use…

开源大模型LLaMA 2会扮演类似Android的角色么?

在AI大模型没有商业模式&#xff1f;等文章中&#xff0c;我多次表达过这样一个观点&#xff1a;不要把大模型的未来应用方式比喻成公有云&#xff0c;大模型最终会是云端操作系统的核心&#xff08;新通用计算平台&#xff09;&#xff0c;而它的落地形式会很像过去的沃森&…

【C++】开源:Linux端ALSA音频处理库

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍Linux端ALSA音频处理库。 无专精则不能成&#xff0c;无涉猎则不能通。。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff0c…

Python增删改查小练习

目录 1. List操作-增加 2. List操作-查询 3. List操作-修改 4. List操作-删除 资料获取方法 1. List操作-增加 List Append(“xx”) 插入到列表尾部 Insert(x,xx) 在指定的位置插入 Extend 将列表的元素分开,插入到之前列表的尾部 小练习: 把一个字符串”abcdefg…

ssh打开远程vscode

如果想要远程打开其他终端的vscode&#xff0c;首先要知道远程终端的ip地址和用户名称以及用户密码 1、打开本地vscode 2、点击左下角蓝色区域 3、页面上部出现如下图&#xff0c;点击ssh&#xff0c;我这里已经连接&#xff0c;所以是connect to host 4、选择Add New SSH Host…

appium中toast识别

目录 一、什么是Toast&#xff1f; 二、环境前提 三、修改配置 四、安装驱动 五、常见报错及解决方案 1、cnpm 不识别&#xff0c;提示不是内部或外部命令 2、npm 也不识别 3、报错 六、代码节选 一、什么是Toast&#xff1f; Android中的Toast是一种简易的消息提示框…

比selenium体验更好的ui自动化测试工具: cypress介绍

话说 Cypress is a next generation front end testing tool built for the modern web. And Cypress can test anything that runs in a browser.Cypress consists of a free, open source, locally installed Test Runner and a Dashboard Service for recording your tests.…

day44-Spring_AOP

0目录 1.2.3 1.Spring_AOP 实体类&#xff1a; Mapper接口&#xff1a; Service和实现类&#xff1a; 测试1&#xff1a; 运行后&#xff1a; 测试2&#xff1a;无此型号时 测试3&#xff1a;库存不足时 解决方案1&#xff1a;事务声明管理器 测试&#xff1a…

RocketMQ主从集群broker无法启动,日志报错

使用vmWare安装的centOS7.9虚拟机&#xff0c;RocketMQ5.1.3 在rocketMQ的bin目录里使用相对路径的方式启动broker&#xff0c;jps查询显示没有启动&#xff0c;日志报错如下 排查配置文件没有问题&#xff0c;nameServer也已经正常启动 更换绝对路径&#xff0c;启动broker&…

[ELK使用篇]:SpringCloud整合ELK服务

文章目录 一&#xff1a;前置准备-(参考之前博客)&#xff1a;1.1&#xff1a;准备Elasticsearch和Kibana环境&#xff1a;1.1.1&#xff1a;地址&#xff1a;[https://blog.csdn.net/Abraxs/article/details/128517777](https://blog.csdn.net/Abraxs/article/details/1285177…

MySQL explain详解

文章目录 0 环境准备1 explain 之 id2 explain 之 select_type3 explain 之 table4 explain 之 type5 explain 之 key6 explain 之 rows7 explain 之 extra MySQL 的 EXPLAIN 是一个用于查询优化的关键字。它用于分析和评估查询语句的执行计划&#xff0c;帮助开发者理解查询语…

尚硅谷大数据项目《在线教育之采集系统》笔记001

视频地址&#xff1a;尚硅谷大数据项目《在线教育之采集系统》_哔哩哔哩_bilibili 目录 P004 P006 P007 P009 P010 P017 P025 P026 P027 P028 P030 P004 将数据以图形图表的方式展示出来&#xff01; P006 数据埋点 所谓埋点就是在应用中特定的流程收集一些信息&…

(css)自定义登录弹窗页面

(css)自定义登录弹窗页面 效果&#xff1a; 代码&#xff1a; <!-- 登录弹窗 --> <el-dialog:visible.sync"dialogVisible"title"用户登录"width"25%"centerclass"custom-dialog":show-close"false":close-on-cli…