使用LSTM建立seq2seq模型进行语言翻译

文章目录

    • 1. 数据处理
    • 2. 编码器、解码器数据
      • 2.1 编码器
      • 2.2 解码器
      • 2.3 模型
    • 3. 训练
    • 4. 推理模型
    • 5. 采样

参考 基于深度学习的自然语言处理

1. 数据处理

  • 读取数据
with open('deu.txt', 'r', encoding='utf-8') as f:lines = f.read().split('\n')
print("文档有 {} 行。".format(len(lines)))
num_samples = 20000 # 使用的语料行数
lines_to_use = lines[ : min(num_samples, len(lines)-1)]
print(lines_to_use)

  • 替换数字
import re
print(lines_to_use[19516])
for i in range(len(lines_to_use)):lines_to_use[i] = re.sub('\d', ' _NUMBER_ ', lines_to_use[i])# 用 ' _NUMBER_ ' 替换 数字(\d)
print(lines_to_use[19516])

输出:(数字被替换了)

Turn to channel 1.	Wechsle auf Kanal eins.
Turn to channel  _NUMBER_ .	Wechsle auf Kanal eins.
  • 切分 输入,输出
input_texts = [] # 输入句子集
target_texts = [] # 输出句子集
input_words = set() # 输入词集合
target_words = set() # 输出词集合
for line in lines_to_use:x, y = line.split('\t')y = 'BEGIN_ ' + y + ' _END' # 输出加上 开始结束 标记input_texts.append(x)target_texts.append(y)for word in x.split():if word not in input_words:input_words.add(word)for word in y.split():if word not in target_words:target_words.add(word)
  • 输入输出句子的 最大长度
max_input_seq_len = max([len(seq.split()) for seq in input_texts])
# 11
max_target_seq_len = max([len(seq.split()) for seq in target_texts])
# 15
  • 输入输出 tokens 个数
input_words = sorted(list(input_words))
target_words = sorted(list(target_words))
num_encoder_tokens = len(input_words) # 5724
num_decoder_tokens = len(target_words) # 9126


  • 建立 tokens 与 id 的映射关系
inputToken_idx = {token : i for (i, token) in enumerate(input_words)}
outputToken_idx = {token : i for (i, token) in enumerate(target_words)}


idx_inputToken = {i : token for (i, token) in enumerate(input_words)}
idx_outputToken = {i : token for (i, token) in enumerate(target_words)}

2. 编码器、解码器数据

  • 注意维度的意义
import numpy as np
encoder_input_data = np.zeros((len(input_texts), max_input_seq_len),# 句子数量,         最大输入句子长度dtype=np.float32
)decoder_input_data = np.zeros((len(target_texts), max_target_seq_len),# 句子数量,          最大输出句子长度dtype=np.float32
)decoder_output_data = np.zeros((len(target_texts), max_target_seq_len, num_decoder_tokens),# 句子数量,          最大输出句子长度,      输出 tokens ids 个数dtype=np.float32
)
  • 填充矩阵
for i,(input_text, target_text) in enumerate(zip(input_texts, target_texts)):for t, word in enumerate(input_text.split()):encoder_input_data[i, t] = inputToken_idx[word]for t, word in enumerate(target_text.split()):decoder_input_data[i, t] = outputToken_idx[word]if t > 0:# 解码器的输出比输入提前一个时间步decoder_output_data[i, t-1, outputToken_idx[word]] = 1.

2.1 编码器

from keras.layers import Input, LSTM, Embedding, Dense
from keras.models import Modelembedding_size = 256  # 嵌入维度
rnn_size = 64
# 编码器
encoder_inputs = Input(shape=(None,))
encoder_after_embedding = Embedding(input_dim=num_encoder_tokens,  # 单词个数output_dim=embedding_size)(encoder_inputs)
encoder_lstm = LSTM(units=rnn_size, return_state=True)
# return_state: Boolean. Whether to return
#   the last state in addition to the output.
_, state_h, state_c = encoder_lstm(encoder_after_embedding)
encoder_states = [state_h, state_c] # 思想向量

2.2 解码器

# 解码器
decoder_inputs = Input(shape=(None,))
decoder_after_embedding = Embedding(input_dim=num_decoder_tokens,  # 单词个数output_dim=embedding_size)(decoder_inputs)
decoder_lstm = LSTM(units=rnn_size, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_after_embedding,initial_state=encoder_states)
# 使用 encoder 输出的思想向量初始化 decoder 的 LSTM 的初始状态
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
# 输出词个数,多分类
decoder_outputs = decoder_dense(decoder_outputs)

2.3 模型

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
model.summary()from keras.utils import plot_model
plot_model(model,to_file='model.png')

输出:

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 256)    1465344     input_1[0][0]                    
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, None, 256)    2336256     input_2[0][0]                    
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, 64), (None,  82176       embedding_1[0][0]                
__________________________________________________________________________________________________
lstm_2 (LSTM)                   [(None, None, 64), ( 82176       embedding_2[0][0]                ![在这里插入图片描述](https://img-blog.csdnimg.cn/20201215221559994.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzIxMjAxMjY3,size_16,color_FFFFFF,t_70)lstm_1[0][1]                     lstm_1[0][2]                     
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 9126)   593190      lstm_2[0][0]                     
==================================================================================================
Total params: 4,559,142
Trainable params: 4,559,142
Non-trainable params: 0
__________________________________________________________________________________________________

3. 训练

  • 训练 + 回调函数保存最佳模型
from keras.callbacks import ModelCheckpointfilepath='weights.best.h5'# 有一次提升, 则覆盖一次 save_best_only=True
checkpoint = ModelCheckpoint(filepath, monitor='accuracy', verbose=1,save_best_only=True,mode='max',save_freq=2) 
callbacks_list = [checkpoint]
# https://keras.io/api/callbacks/model_checkpoint/history = model.fit(x=[encoder_input_data, decoder_input_data],y=decoder_output_data,batch_size=128,epochs=200,validation_split=0.1,callbacks=callbacks_list)
model.save('model.h5')
  • 绘制训练曲线
import pandas as pd
from matplotlib import pyplot as plt
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']plt.plot(loss, label='train Loss')
plt.plot(val_loss, label='valid Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid()
plt.show()plt.plot(acc, label='train Acc')
plt.plot(val_acc, label='valid Acc')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid()
plt.show()

4. 推理模型

  • 编码器
encoder_model = Model(encoder_inputs, encoder_states) # 输入(带embedding),输出思想向量
  • 解码器
# 编码器的输出,作为解码器的初始状态
decoder_state_input_h = Input(shape=(rnn_size,))
decoder_state_input_c = Input(shape=(rnn_size,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

初始状态 + embedding 作为输入,经过LSTM,输出 decoder_outputs_inf, state_h_inf, state_c_inf

decoder_outputs_inf, state_h_inf, state_c_inf = decoder_lstm(decoder_after_embedding,initial_state=decoder_states_inputs)
# 作为下一次推理的状态输入 h, c
decoder_states_inf = [state_h_inf, state_c_inf] 
# LSTM的输出,接 FC,预测下一个词是什么
decoder_outputs_inf = decoder_dense(decoder_outputs_inf)
decoder_model = Model([decoder_inputs] + decoder_states_inputs,[decoder_outputs_inf] + decoder_states_inf
)

5. 采样

def decode_sequence(input_seq):# encoder_states = [state_h, state_c]states_value = encoder_model.predict(input_seq) # list 2个 array 1*rnn_sizetarget_seq = np.zeros((1, 1))# 目标输入序列 初始为 'BEGIN_' 的 idxtarget_seq[0, 0] = outputToken_idx['BEGIN_']stop = Falsedecoded_sentence = ''while not stop:output_tokens, h, c = decoder_model.predict([target_seq] + states_value)# output_tokens [1*1*9126]   h,c [1*rnn_size]sampled_token_idx = np.argmax(output_tokens)sampled_word = idx_outputToken[sampled_token_idx]decoded_sentence += ' ' + sampled_wordif sampled_word == '_END' or len(decoded_sentence) > 60:stop = Truetarget_seq = np.zeros((1, 1))target_seq[0, 0] = sampled_token_idx # 作为下一次预测,输入# Update statesstates_value = [h, c] # 作为下一次的状态输入return decoded_sentence# 简单测试 采样
text_to_translate = 'Are you happy ?'
encoder_input_to_translate = np.zeros((1, max_input_seq_len),dtype=np.float32)
for t, word in enumerate(text_to_translate.split()):encoder_input_to_translate[0, t] = inputToken_idx[word]# encoder_input_to_translate [[ids,...,0,0,0,0]]
print(decode_sequence(encoder_input_to_translate))

输出:

text_to_translate = 'Are you happy?'
输出: Sind Sie glücklich? _END  # 你高兴吗?
text_to_translate = 'Where is my car?'
输出: Wo ist mein Auto? _END # 我的车呢?
text_to_translate = 'When I see you, I fall in love with you!'
输出:Sind Sie mit uns gehen. _END # 你跟我们一起去吗?

注意:

  • 待翻译句子长度不能超过最大长度
  • 且不能出现没有出现过的词汇,如 dear 出现过,但是与标点连着写dear!没有出现过,会报错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473363.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【struts2】预定义拦截器

1)预定义拦截器 Struts2有默认的拦截器配置,也就是说,虽然我们没有主动去配置任何关于拦截器的东西,但是Struts2会使用默认引用的拦截器。由于Struts2的默认拦截器声明和引用都在这个Struts-default.xml里面&#xff0…

微信小程序页面跳转方法总结

微信小程序页面跳转目前有以下方法(不全面的欢迎补充): 1. 利用小程序提供的 API 跳转: // 保留当前页面,跳转到应用内的某个页面,使用wx.navigateBack可以返回到原页面。 // 注意:调用 navig…

使用注意力机制建模 - 标准化日期格式

文章目录1. 概述2. 数据3. 模型4. 训练5. 测试参考 基于深度学习的自然语言处理本文使用attention机制的模型,将各种格式的日期转化成标准格式的日期 1. 概述 LSTM、GRU 减少了梯度消失的问题,但是对于复杂依赖结构的长句子,梯度消失仍然存…

微信小程序在当前页面设置其他页面的数据

如果其他页面用到的数据是 globalData, 那么直接在当前页面修改 globalData 数据即可。 如果其他页面用到的数据是 storage, 那么直接在当前页面修改 storage 数据即可。 场景:(由 A 页面跳转到 B 页面) 在 B 页面完…

牛客 数学实验(模拟)

文章目录1. 题目2. 解题1. 题目 链接:https://ac.nowcoder.com/acm/contest/10166/A 来源:牛客网 牛牛在做数学实验。 老师给了牛牛一个数字n,牛牛需要不断地将所有数位上的值做乘法运算,直至最后数字不发生变化为止。 请你帮牛…

css动画之波纹

样式定义: #ContactAbout { height: auto; position: relative; overflow: hidden; } #sectioncontact { display: block; width: 100%; position: relative; height: 700px; z-index: 10; }#sectioncontact .map { width: 370px; height: 280px; position: absolut…

牛客 奇怪的排序问题(单调栈/遍历)

文章目录1. 题目2. 解题1. 题目 链接:https://ac.nowcoder.com/acm/contest/10166/B 来源:牛客网 操场上有n个人排成一队,这n个人身高互不相同,可将他们的身高视为一个1到n的排列。 这时需要把队伍变成升序,也就是从…

Python 中,matplotlib绘图无法显示中文的问题

在python中,默认情况下是无法显示中文的,如下代码: import matplotlib.pyplot as plt# 定义文本框和箭头格式 decisionNode dict(boxstyle "sawtooth", fc "0.8") leafNode dict(boxstyle "round4", fc …

使用Marshal.Copy把Txt行数据转为Struct类型值

添加重要的命名空间: using System.Runtime.InteropServices; 先建立结构相同(char长度相同)的Struct类型用于转换: [StructLayout(LayoutKind.Sequential, Pack 1)]public struct Employee{[MarshalAs(UnmanagedType.ByValArray, SizeConst 6)]public char[] EmployeeId;[Ma…

牛客 XOR和(找规律)

文章目录1. 题目2. 解题1. 题目 链接:https://ac.nowcoder.com/acm/contest/10166/C 来源:牛客网 牛牛最近学会了异或操作,于是他发现了一个函数 f(x)x⊕(x−1)f(x)x\oplus (x-1)f(x)x⊕(x−1),现在牛牛给你一个数 n&#xff0c…

采用contentprivider扫描手机SD卡的图片资源

Intent inten new Intent(Intent.ACTION_PICK,MediaStore.Images.Media.EXTERNAL_CONTENT_URI);startActivityForResult(inten,21);------------------------在onActivityResult中加入-------------------Uri uri data.getData();String[] filePath { MediaStore.Images.Med…

天池 在线编程 数组划分III(计数)

文章目录1. 题目2. 解题1. 题目 https://tianchi.aliyun.com/oj/231188302809557697/235445278655844965 给你一个整数数组和一个整数K,请你判断数组是否可以划分为若干大小为k序列,并满足以下条件: 数组中的每一个数恰恰出现在一个序列中…

详解nohup和 区别

一、nohup nohup 命令运行由 Command参数和任何相关的 Arg参数指定的命令,忽略所有挂断(SIGHUP)信号。在注销后使用 nohup 命令运行后台中的程序。要运行后台中的 nohup 命令,添加 & ( 表示“and”的符号&#xf…

谈谈.NET MVC QMVC高级开发

自从吾修主页上发布了QMVC1.0,非常感兴趣,用了半月的时间学习,真的感觉收益非浅,在此声明非常感谢吾修大哥的分享! 1、轻快简单,框架就几个类,简单,当然代码少也就运行快&#xff01…

天池 在线编程 最小振幅(排序)

文章目录1. 题目2. 解题1. 题目 https://tianchi.aliyun.com/oj/231188302809557697/235445278655844966 给定一个由N个整数组成的数组A,一次移动,我们可以选择此数组中的任何元素并将其替换为任何值。 数组的振幅是数组A中的最大值和最小值之间的差。…

文件系统的类型

文件系统的类型 文件系统类型: ext2 : 早期linux中常用的文件系统 ext3 : ext2的升级版,带日志功能 RAMFS : 内存文件系统,速度很快 NFS : 网络文件系统,由SUN发明&a…

Git中非常重要的一个文件——.gitignore详解

首先要强调一点,这个文件的完整文件名就是“.gitignore”,注意最前面有个“.”。这样没有扩展名的文件在Windows下不太好创建,这里给出win7的创建方法: 创建一个文件,文件名为:“.gitignore.”&#xff0c…

行先知 为您的办公室管理提供方便

《行先知》为您的办公室管理提供方便■省时间和空间 公司人员去向及预定事项一目了然。不管你位置在哪里,不需要回头、翘首去看通知栏。不需要一次次去擦写、修改通知栏。尽管公司人员分布在不同的楼层、不同的建筑,人员去向一目了然。不需要再往纸上写留…

Python把两张图片拼接为一张图片并保存

这里主要用Python扩展库pillow中Image对象的paste()方法把两张图片拼接起来 from os import listdir from PIL import Imagedef pinjie():# 获取当前文件夹中所有JPG图像im_list [Image.open(fn) for fn in listdir() if fn.endswith(.jpg)]# 图片转化为相同的尺寸ims []for…

ubuntu 13.04下MYSQL 5.5环境搭建

解决的问题: 安装mysql server和mysql client 5.5 新建远程账户 远程访问权限 MYSQL默认字符集修改为UTF8 检查防火墙 一、安装 BTW:可以使用查找命令查看安装包 sudo apt- 安装命令 sudo apt-get install mysql-server-5.5 回车 (有一个带core的&…