【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)

系列文章

【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一)
【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 系列文章
  • 1、加载训练集
  • 2、训练集数据处理
  • 3、网络搭建
  • 4、启动训练
  • 5、模型保存
  • 6、模型加载与推理

基于LSTM训练一个翻译器,要怎么做呢?其实很简单,也没那么复杂。
主要包括以下几个步骤:
1、加载训练集
2、训练集数据处理(字符编码),可参见:
【如何训练一个中译英翻译器】LSTM机器翻译seq2seq字符编码(一)
3、网络搭建
4、启动训练
5、模型保存
6、模型加载与推理
下面来看下整个过程:

1、加载训练集

from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as npbatch_size = 64  # Batch size for training.
epochs = 1000  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 12000  # Number of samples to train on.
data_path = 'cmn.txt'

2、训练集数据处理

数据集预处理

# Vectorize the data.
input_texts = [] # 保存英文数据集
target_texts = [] # 保存中文数据集
input_characters = set() # 保存英文字符,比如a,b,c
target_characters = set() # 保存中文字符,比如,你,我,她
with open(data_path, 'r', encoding='utf-8') as f:lines = f.read().split('\n')# 一行一行读取数据
for line in lines[: min(num_samples, len(lines) - 1)]: # 遍历每一行数据集(用min来防止越出)input_text, target_text = line.split('\t') # 分割中英文# We use "tab" as the "start sequence" character# for the targets, and "\n" as "end sequence" character.target_text = '\t' + target_text + '\n'input_texts.append(input_text)target_texts.append(target_text)for char in input_text: # 提取字符if char not in input_characters:input_characters.add(char)for char in target_text:if char not in target_characters:target_characters.add(char)input_characters = sorted(list(input_characters)) # 排序一下
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量
max_encoder_seq_length = max([len(txt) for txt in input_texts]) # 输入的最长句子长度
max_decoder_seq_length = max([len(txt) for txt in target_texts])# 输出的最长句子长度print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

数据集编码

# mapping token to index, easily to vectors
# 处理方便进行编码为向量
# {
#   'a': 0,
#   'b': 1,
#   'c': 2,
#   ...
#   'z': 25
# }
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])# np.zeros(shape, dtype, order)
# shape is an tuple, in here 3D
encoder_input_data = np.zeros( # (12000, 32, 73) (数据集长度、句子长度、字符数量)(len(input_texts), max_encoder_seq_length, num_encoder_tokens),dtype='float32')
decoder_input_data = np.zeros( # (12000, 22, 2751)(len(input_texts), max_decoder_seq_length, num_decoder_tokens),dtype='float32')
decoder_target_data = np.zeros( # (12000, 22, 2751)(len(input_texts), max_decoder_seq_length, num_decoder_tokens),dtype='float32')# 遍历输入文本(input_texts)和目标文本(target_texts)中的每个字符,
# 并将它们转换为数值张量以供深度学习模型使用。
#编码如下
#我,你,他,这,国,是,家,人,中
#1  0  0   0  1   1  0   1  1,我是中国人
#1  0   1  0  0   1  1   1  0,他是我家人
# input_texts contain all english sentences
# output_texts contain all chinese sentences
# zip('ABC','xyz') ==> Ax By Cz, looks like that
# the aim is: vectorilize text, 3D
# zip(input_texts, target_texts)成对取出输入输出,比如input_text = 你好,target_text = you goodfor i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):for t, char in enumerate(input_text):# 3D vector only z-index has char its value equals 1.0encoder_input_data[i, t, input_token_index[char]] = 1.for t, char in enumerate(target_text):# decoder_target_data is ahead of decoder_input_data by one timestepdecoder_input_data[i, t, target_token_index[char]] = 1.if t > 0:# decoder_target_data will be ahead by one timestep# and will not include the start character.# igone t=0 and start t=1, meansdecoder_target_data[i, t - 1, target_token_index[char]] = 1.

3、网络搭建

编码器搭建

# Define an input sequence and process it.
# input prodocts keras tensor, to fit keras model!
# 1x73 vector
# encoder_inputs is a 1x73 tensor!
encoder_inputs = Input(shape=(None, num_encoder_tokens))# units=256, return the last state in addition to the output
encoder_lstm = LSTM((latent_dim), return_state=True)# LSTM(tensor) return output, state-history, state-current
encoder_outputs, state_h, state_c = encoder_lstm(encoder_inputs)# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

解码器搭建

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM((latent_dim), return_sequences=True, return_state=True)# obtain output
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,initial_state=encoder_states)

整体网络模型

# dense 2580x1 units full connented layer
decoder_dense = Dense(num_decoder_tokens, activation='softmax')# why let decoder_outputs go through dense ?
decoder_outputs = decoder_dense(decoder_outputs)# Define the model that will turn, groups layers into an object
# with training and inference features
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
# model(input, output)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)# Run training
# compile -> configure model for training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
# model optimizsm

4、启动训练

model.fit([encoder_input_data, decoder_input_data],decoder_target_data,batch_size=batch_size,epochs=epochs,validation_split=0.2)

可在控制台看到训练的相关信息:
在这里插入图片描述

5、模型保存

首先,保存权重文件
这里将解码器与编码器分别保存下来

encoder_model = Model(encoder_inputs, encoder_states)
# tensor 73x1
decoder_state_input_h = Input(shape=(latent_dim,))
# tensor 73x1
decoder_state_input_c = Input(shape=(latent_dim,))
# tensor 146x1
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
# lstm
decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)
#
decoder_states = [state_h, state_c]
#
decoder_outputs = decoder_dense(decoder_outputs)
#
decoder_model = Model([decoder_inputs] + decoder_states_inputs,[decoder_outputs] + decoder_states)
encoder_model.save('encoder_model.h5')
decoder_model.save('decoder_model.h5')

其次,保存编码文件

# 将 input_characters保存为 input_words.txt 文件
with open('input_words.txt', 'w', newline='') as f:for char in input_characters:if char == '\t':f.write('\\t\n')elif char == '\n':f.write('\\n\n')else:f.write(char + '\n')# 将 target_characters保存为 target_words.txt 文件
with open('target_words.txt', 'w', newline='') as f:for char in target_characters:if char == '\t':f.write('\\t\n')elif char == '\n':f.write('\\n\n')else:f.write(char + '\n')

最后,保存一下配置文件:

import json# 将数据保存到JSON文件
data = {"max_encoder_seq_length": max_encoder_seq_length,"max_decoder_seq_length": max_decoder_seq_length
}with open('config.json', 'w') as file:json.dump(data, file)

所以保存下来我们可以看到一共有以下几个文件,这几个文件要保存好,我们后面的文章讲到部署的时候要用到:
在这里插入图片描述

6、模型加载与推理

加载模型权重

from keras.models import load_model
encoder_model = load_model('encoder_model.h5')
decoder_model = load_model('decoder_model.h5')

推理模型搭建

# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())def decode_sequence(input_seq):# Encode the input as state vectors.states_value = encoder_model.predict(input_seq)# Generate empty target sequence of length 1.target_seq = np.zeros((1, 1, num_decoder_tokens))# Populate the first character of target sequence with the start character.target_seq[0, 0, target_token_index['\t']] = 1.# this target_seq you can treat as initial state# Sampling loop for a batch of sequences# (to simplify, here we assume a batch of size 1).stop_condition = Falsedecoded_sentence = ''while not stop_condition:output_tokens, h, c = decoder_model.predict([target_seq] + states_value)# Sample a token# argmax: Returns the indices of the maximum values along an axis# just like find the most possible charsampled_token_index = np.argmax(output_tokens[0, -1, :])# find char using indexsampled_char = reverse_target_char_index[sampled_token_index]# and append sentencedecoded_sentence += sampled_char# Exit condition: either hit max length# or find stop character.if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):stop_condition = True# Update the target sequence (of length 1).# append then ?# creating another new target_seq# and this time assume sampled_token_index to 1.0target_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.# Update states# update states, frome the front partsstates_value = [h, c]return decoded_sentence

进行模型推理


for seq_index in range(100,200):# Take one sequence (part of the training set)# for trying out decoding.input_seq = encoder_input_data[seq_index: seq_index + 1]decoded_sentence = decode_sequence(input_seq)print('-')print('Input sentence:', input_texts[seq_index])print('Decoded sentence:', decoded_sentence)

可以看到输出:
在这里插入图片描述

以上的代码可在kaggle上运行:how-to-train-a-chinese-to-english-translator-ii

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/14076.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【框架篇】Spring Boot 日志

Spring Boot 日志 一,日志用途 尽管一个项目在没有日志记录的情况下可能能够正常运行,但是日志记录对于我们来说却是至关重要的,它存在以下功能: 1,故障排查和调试:当项目出现异常或者故障时,…

idea连接远程服务器上传war包文件

idea连接远程服务器&上传war包 文章目录 idea连接远程服务器&上传war包1. 连接服务器2.上传war包 1. 连接服务器 选择Tools -> Start SSH Session 添加配置 连接成功 2.上传war包 Tools -> Deployment -> Browse Remote Host 点击右侧标签,点击&…

gin框架内容(三)--中间件

gin框架内容(三)--中间件 Gin框架允许开发者在处理请求的过程中,加入用户自己的函数。这个函数就叫中间件,中间件适合处理一些公共的业务逻辑,比如登录认证、权限校验、数据分页、记录日志、耗时统计等 即比如&#x…

使用分布式HTTP代理爬虫实现数据抓取与分析的案例研究

在当今信息爆炸的时代,数据已经成为企业决策和发展的核心资源。然而,要获取大规模的数据并进行有效的分析是一项艰巨的任务。为了解决这一难题,我们进行了一项案例研究,通过使用分布式HTTP代理爬虫,实现数据抓取与分析…

华为eNSP:isis的配置

一、拓扑图 二、路由器的配置 配置接口IP AR1&#xff1a; <Huawei>system-view [Huawei]int g0/0/0 [Huawei-GigabitEthernet0/0/0]ip add 1.1.1.1 24 [Huawei-GigabitEthernet0/0/0]qu AR2: <Huawei>system-view [Huawei]int g0/0/0 [Huawei-GigabitEthe…

【用IDEA基于Scala2.12.18开发Spark 3.4.1 项目】

目录 使用IDEA创建Spark项目设置sbt依赖创建Spark 项目结构新建Scala代码 使用IDEA创建Spark项目 打开IDEA后选址新建项目 选址sbt选项 配置JDK debug 解决方案 相关的依赖下载出问题多的话&#xff0c;可以关闭idea&#xff0c;重启再等等即可。 设置sbt依赖 将sbt…

QTday4(鼠标事件和键盘事件/QT实现连接TCP协议)

笔记 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QDebug> #include <QTcpServer>//服务器类 #include <QTcpSocket>//客户端类 #include <QMessageBox> #include <QList>//链表容器QT_BEGIN_NAMESPACE namespace Ui …

单链表详解

&#x1f493;博主个人主页:不是笨小孩&#x1f440; ⏩专栏分类:数据结构与算法&#x1f440; &#x1f69a;代码仓库:笨小孩的代码库&#x1f440; ⏩社区&#xff1a;不是笨小孩&#x1f440; &#x1f339;欢迎大家三连关注&#xff0c;一起学习&#xff0c;一起进步&#…

Mac 安装启动RabbitMq

使用HomeBrew安装 未安装的请参照我的这篇Mac安装HomeBrew文章 安装 执行命令 brew install rabbitmq启动方式 brew services start rabbitmq端口说明 端口用处5672RabbitMQ通讯端口&#xff0c;也就是连接使用的端口15672RabbbitMQ管理界面端口&#xff0c;需要开启Manage…

web自动化测试-PageObject 设计模式

为 UI 页面写测试用例时&#xff08;比如 web 页面&#xff0c;移动端页面&#xff09;&#xff0c;测试用例会存在大量元素和操作细节。当 UI 变化时&#xff0c;测试用例也要跟着变化&#xff0c; PageObject 很好的解决了这个问题。 使用 UI 自动化测试工具时&#xff08;包…

Zebec Card 将在亚洲、拉美等地区推出,生态全球化加速

随着以Visa、特斯拉、BNY Mellon、BlackRock、Mastercard、Gucci等为代表的传统商业机构巨头&#xff0c;以及萨尔瓦多、中非共和国等为代表的国家不断的向加密货币领域布局&#xff0c;越来越多的投资者开始以新的眼光来看待加密货币&#xff0c;仅在2022年&#xff0c;加密货…

如何学好Java并调整学习过程中的心态:学习之路的秘诀

文章目录 第一步&#xff1a;建立坚实的基础实例分析&#xff1a;选择合适的学习路径 第二步&#xff1a;选择合适的学习资源实例分析&#xff1a;参与编程社区 第三步&#xff1a;动手实践实例分析&#xff1a;开发个人项目 调整学习过程中的心态1. 不怕失败2. 持续学习3. 寻求…

Unity自定义后处理——Tonemapping色调映射

大家好&#xff0c;我是阿赵。   继续介绍屏幕后处理&#xff0c;这一期介绍一下Tonemapping色调映射 一、Tone Mapping的介绍 Tone Mapping色调映射&#xff0c;是一种颜色的映射关系处理&#xff0c;简单一点说&#xff0c;一般是从原始色调&#xff08;通常是高动态范围&…

SpringBoot 如何进行 统一异常处理

在Spring Boot中&#xff0c;可以通过自定义异常处理器来实现统一异常处理。异常处理器能够捕获应用程序中抛出的各种异常&#xff0c;并提供相应的错误处理和响应。 Spring Boot提供了ControllerAdvice注解&#xff0c;它可以将一个类标记为全局异常处理器。全局异常处理器能…

【动态规划】子数组系列

文章目录 动态规划&#xff08;子数组系列&#xff09;1. 最大子数组和2. 环形子数组的最大和3. 乘积最大子数组4. 乘积为正的最长子数组的长度5. 等差数列划分6. 最长湍流子数组7. 单词拆分8. 环形字符串中的唯一的子字符串 动态规划&#xff08;子数组系列&#xff09; 1. 最…

算法与数据结构(四)--排序算法

一.冒泡排序 原理图&#xff1a; 实现代码&#xff1a; /* 冒泡排序或者是沉底排序 *//* int arr[]: 排序目标数组,这里元素类型以整型为例; int len: 元素个数 */ void bubbleSort (elemType arr[], int len) {//为什么外循环小于len-1次&#xff1f;//考虑临界情况&#xf…

Neo4j 集群和负载均衡

Neo4j 集群和负载均衡 Neo4j是当前最流行的开源图DB。刚好读到了Neo4j的集群和负载均衡策略&#xff0c;记录一下。 1 集群 Neo4j 集群使用主从复制实现高可用性和水平读扩展。 1.1 复制 集群的写入都通过主节点协调完成的&#xff0c;数据先写入主机&#xff0c;再同步到…

振弦采集仪及在线监测系统完整链条的岩土工程隧道安全监测

振弦采集仪及在线监测系统完整链条的岩土工程隧道安全监测 近年来&#xff0c;随着城市化的不断推进和基础设施建设的不断发展&#xff0c;隧道建设也日益成为城市交通发展的必需品。然而&#xff0c;隧道建设中存在着一定的安全隐患&#xff0c;如地质灾害、地下水涌流等&…

springboot第32集:redis系统-android系统-Nacos Server

Error parsing HTTP request header HTTP method names must be tokens 检查发送HTTP请求的客户端代码&#xff0c;确保方法名中不包含非法字符。通常情况下&#xff0c;HTTP请求的方法名应该是简单的标识符&#xff0c;例如"GET"、"POST"、"PUT"…

《TCP IP网络编程》第十二章

第 12 章 I/O 复用 12.1 基于 I/O 复用的服务器端 多进程服务端的缺点和解决方法&#xff1a; 为了构建并发服务器&#xff0c;只要有客户端连接请求就会创建新进程。这的确是实际操作中采用的一种方案&#xff0c;但并非十全十美&#xff0c;因为创建进程要付出很大的代价。…