【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)

系列文章

【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一)
【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 系列文章
  • 1、加载字符文件
  • 2、加载权重文件
  • 3、推理模型搭建
  • 4、进行推理

模型部署也是很重要的一部分,这里先讲基于python的部署,后面我们还要将模型部署到移动端。
细心的小伙伴会发现前面的文章在模型保存之后进行模型推理时,我们使用的数据是在训练之前我们对数据进行处理的encoder_input_data中读取,而不是我们手动输入的,那么这一章主要来解决自定义输入推理的问题

1、加载字符文件

首先,我们根据 【如何训练一个中译英翻译器】LSTM机器翻译模型训练与保存(二)的操作,到最后
会得到这样的三个文件:input_words.txt,target_words.txt,config.json
需要逐一进行加载
进行加载

# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('input_words.txt', 'r') as f:input_words = f.readlines()input_characters = [line.rstrip('\n') for line in input_words]# 从 target_words.txt 文件中读取字符串
with open('target_words.txt', 'r', newline='') as f:target_words = [line.strip() for line in f.readlines()]target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量

读取配置文件

import json
with open('config.json', 'r') as file:loaded_data = json.load(file)# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]

2、加载权重文件

# 加载权重
from keras.models import load_model
encoder_model = load_model('encoder_model.h5')
decoder_model = load_model('decoder_model.h5')

3、推理模型搭建

def decode_sequence(input_seq):# Encode the input as state vectors.states_value = encoder_model.predict(input_seq)# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, num_decoder_tokens))# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''while not stop_condition:output_tokens, h, c = decoder_model.predict([target_seq] + states_value)# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentence

4、进行推理

import numpy as npinput_text = "Call me."
encoder_input_data = np.zeros((1,max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(input_text):print(char)# 3D vector only z-index has char its value equals 1.0encoder_input_data[0,t, input_token_index[char]] = 1.input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

运行结果:
在这里插入图片描述

以上的代码可在kaggle上运行:how-to-train-a-chinese-to-english-translator-iii

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/13425.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows Server 2012 能使用的playwright版本

由于在harkua_bot里面使用到了playwright,我的服务器又是Windows Server 2012 R2,最新版playwright不支持Windows Server 2012 R2,支持Windows Server 2016以上,所以有了这个需求 https://cdn.npmmirror.com/binaries/playwright…

Hadoop生态体系-HDFS

目录标题 1、Apache Hadoop2、HDFS2.1 设计目标:2.2 特性:2.3 架构2.4 注意点2.5 HDFS基本操作2.5.1 shell命令选项2.5.2 shell常用命令介绍 3、HDFS基本原理3.1 NameNode 概述3.2 Datanode概述 1、Apache Hadoop Hadoop:允许使用简单的编程…

Android Hook 剪切板相关方法

想起之前做过的项目有安全合规要求:主动弹窗获取用户同意了才能调用剪切板相关方法,否则属于违规调用,如果是自己项目的相关调用可以自己加一层if判断 但是一些第三方的jar包里面也有在调用的话,我们就无能为力了,而且…

JPA连接达梦数据库导致auto-ddl失效问题解决

现象: 项目使用了JPA,并且auto-ddl设置的为update,在连接达梦数据库的时候,第一次启动没有问题,但是后面重启就会报错,发现错误为重复建表,也就是说已经建好的表没有检测到,…

JVM类加载

一、类记载过程 1、通过类的全限定名获取存储该类的class文件 2、解析成运行时数据,即instanceKlass实例,存放到方法区 3、在堆区生成该类的class对象,即instanceMirrorKlass实例 二、将.class文件解析成什么?类的元信息在JVM中如何…

ceph集群中RBD的性能测试、性能调优

文章目录 rados benchrbd bench-write测试工具Fio测试ceph rbd块设备的iops性能测试ceph rbd块设备的带宽测试ceph rbd块设备的延迟 性能调优 rados bench 参考:https://blog.csdn.net/Micha_Lu/article/details/126490260 rados bench为ceph自带的基准测试工具&am…

全加器(多位)的实现

一,半加器 定义 半加器(Half Adder)是一种用于执行二进制数相加的简单逻辑电路。它可以将两个输入位的和(Sum)和进位(Carry)计算出来。 半加器有两个输入:A 和 B,分别代表…

MySQL基础扎实——MySQL中有那些不同的表格

表格类型 在MySQL中,常见的表格类型有以下几种: MyISAM:是MySQL默认的表格类型,具有较高的性能和较小的存储空间占用。但是,MyISAM不支持事务、崩溃恢复和数据行级锁定。 InnoDB:是MySQL提供的一个更强大…

Redis实现分布式锁

文章目录 4、分布式锁4.1 、基本原理和实现方式对比4.2 、Redis分布式锁的实现核心思路4.3 实现分布式锁版本一4.4 Redis分布式锁误删情况说明4.5 解决Redis分布式锁误删问题4.6 分布式锁的原子性问题4.7 Lua脚本解决多条命令原子性问题4.8 利用Java代码调用Lua脚本改造分布式锁…

MySQL | 常用命令示例

MySQL | 常用命令示例 一、启停MySQL数据库服务二、连接MySQL数据库三、创建和管理数据库四、创建和管理数据表五、数据备份和恢复六、查询与优化 MySQL是一款常用的关系型数据库管理系统,广泛应用于各个领域。在使用MySQL时,我们经常需要编写一些常用脚…

【数据结构】【王道408】——PPT截图与思维导图

自用视频PPT截图 视频网址王道B站链接 23考研 408新增考点: 并查集,红黑树 2023年408真题数据结构篇 408考纲解读 考纲变化 目录 第一章 绪论第二章 线性表顺序表单链表双链表循环链表静态链表差别 第三章 栈 队列 数组栈队列栈的应用数组 第四章 串第五…

软考A计划-系统集成项目管理工程师-项目质量管理-中

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例点击跳转>软考全系列 👉关于作者 专注于Android/Unity和各种游戏开发技巧&#xff…

小程序UV:衡量用户规模与活跃度的重要指标

什么是UV UV是Unique Visitor(独立访客)的缩写,指的是在特定时间段内访问某个网站、应用或平台的独立用户数量。UV是根据设备、IP地址、Cookie等来识别不同的用户,对于相同的用户多次访问,只计算为一个UV。UV是衡量网…

流数据湖平台Apache Paimon(一)概述

文章目录 第1章 概述1.1 简介1.2 核心特性1.3 基本概念1.3.1 Snapshot1.3.2 Partition1.3.3 Bucket1.3.4 Consistency Guarantees一致性保证 1.4 文件布局1.4.1 Snapshot Files1.4.2 Manifest Files1.4.3 Data Files1.4.4 LSM Trees 第1章 概述 1.1 简介 Flink 社区希望能够将…

RocketMQ重复消费的解决方案::分布式锁直击面试!

文章目录 场景分析方法的幂等分布式锁Redis实现分布式锁抢锁的设计思路 分布式锁案例 直击面试rocketmq什么时候重复消费消息丢失的问题消息在哪里丢失发送端确保发送成功并且配合失败的业务处理消费端确保消息不丢失rocketmq 主从同步刷盘 场景分析 分布式系统架构中,队列是分…

css实现有缺口的border

css实现有缺口的border 1.问题回溯2.css实现有缺口的border 1.问题回溯 通常会有那种两个div都有border重叠在一起就会有种加粗的效果。 div1,div2,div3都有个1px的border,箭头标记的地方是没有处理解决的,很明显看着是有加粗效果的。其实这种感觉把di…

【Java从入门到大牛】集合进阶上篇

🔥 本文由 程序喵正在路上 原创,CSDN首发! 💖 系列专栏:Java从入门到大牛 🌠 首发时间:2023年7月29日 🦋 欢迎关注🖱点赞👍收藏🌟留言&#x1f43…

IntelliJ IDEA流行的构建工具——Gradle

IntelliJ IDEA,是java编程语言开发的集成环境。IntelliJ在业界被公认为最好的java开发工具,尤其在智能代码助手、代码自动提示、重构、JavaEE支持、各类版本工具(git、svn等)、JUnit、CVS整合、代码分析、 创新的GUI设计等方面的功能可以说是超常的。 如…

基于java SpringBoot和HTML的博客系统

随着网络技术渗透到社会生活的各个方面,传统的交流方式也面临着变化。互联网是一个非常重要的方向。基于Web技术的网络考试系统可以在全球范围内使用互联网,可以在本地或异地进行通信,大大提高了通信和交换的灵活性。在当今高速发展的互联网时…

【达哥讲网络】第3集:数据交换的垫基石——二层交换原理

专业的网络工程师在进行网络设计时,会事先规划好不同业务数据的转发路径,一方面是为了满足用户应用需求,另一方面是为了提高数据转发效率、充分利用各设备/各链路的硬件或带宽资源。在进行网络故障排除时,理顺各路数据的转发路径也…