基于LSTM encoder-decoder模型实现英文转中文的翻译机器

前言

神经网络机器翻译(NMT, neuro machine tranlation)是AIGC发展道路上的一个重要应用。正是对这个应用的研究,发展出了注意力机制,在此基础上产生了AIGC领域的霸主transformer。我们今天先把注意力机制这些东西放一边,介绍一个对机器翻译起到重要里程碑作用的模型:LSTM encoder-decoder模型(sutskever et al. 2014)。根据这篇文章的描述,这个模型不需要特别的优化,就可以取得超过其他NMT模型的效果,所以我们也来动手实现一下,看看是不是真的有这么厉害。

模型

原文作者采用了4层LSTM模型,每层有1000个单元(每个单元有输入门,输出门,遗忘门和细胞状态更新共计4组状态),采用1000维单词向量,纯RNN部分,就有64M参数。同时,在encoder的输出,和decoder的输出后放一个长度为80000的softmax层(因为论文的输出字典长80000),用于softmax的参数量为320M。整个模型共计320M + 64M = 384M。该模型用了8GPU的服务器训练了10天。
模型大概长这样:
在这里插入图片描述
按照现在的算力价格,用8张4090的主机训练每小时要花20多块钱,训练一轮下来需要花费小5000,笔者当然没有这么土豪,所以我们会使用一个参数量小得多的模型,主要为了记录整个搭建过程使用到的工具链和技术。另外,由于笔者使用了一个预训练的词向量库,包含了中英文单词共计128万多条,其中中文90多万,英文30多万,要像论文中一样用一个超大的softmax来预测每个词的概率并不现实,因此先使用一个linear层再加上relu来简化,加快训练过程,只求能看到收敛。

笔者的模型看起来像这样:
在这里插入图片描述
该模型的主要参数如下:
词向量维度:300
LSTM隐藏层个数:600
LSTM层数:4
linear层输入:600
linear层输出:300
模型参数个数如下为:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Seq2Seq                                  [1, 11, 300]              --
├─Encoder: 1-1                           [1, 300]                  --
│    └─LSTM: 2-1                         [1, 10, 600]              10,819,200
│    └─Linear: 2-2                       [1, 300]                  180,300
│    └─ReLU: 2-3                         [1, 300]                  --
├─Decoder: 1-2                           [1, 11, 300]              --
│    └─LSTM: 2-4                         [1, 11, 600]              10,819,200
│    └─Linear: 2-5                       [1, 11, 300]              180,300
│    └─ReLU: 2-6                         [1, 11, 300]              --
==========================================================================================
Total params: 21,999,000
Trainable params: 21,999,000
Non-trainable params: 0
Total mult-adds (M): 227.56
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.13
Params size (MB): 88.00
Estimated Total Size (MB): 88.15
==========================================================================================

如果大家希望了解LSTM层的10,819,200个参数如何计算出来,可以参考pytorch源码 pytorch/torch/csrc/api/src/nn/modules/rnn.cpp中方法void RNNImplBase::reset()的实现。笔者如果日后有空也可能会写一写。

3 单词向量及语料

3.1 语料

先说语料,NMT需要大量的平行语料,语料可以从这里获取。另外有个语料天涯网站大全分享给大家。

3.2 词向量

首先需要对句子进行分词,中英文都需要做分词。中文分词工具本例采用jieba,可直接安装。

$ pip install jieba
...
$ python
Python 3.11.6 (tags/v3.11.6:8b6ee5b, Oct  2 2023, 14:57:12) [MSC v.1935 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> for token in jieba.cut("我爱踢足球!", cut_all=True):
...     print(token)
... 
我
爱
踢足球
足球
!

英文分词采用nltk,安装之后,需要下载一个分词模型。

$ pip install nltk
...
$ python
Python 3.11.6 (tags/v3.11.6:8b6ee5b, Oct  2 2023, 14:57:12) [MSC v.1935 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import nltk
>>> nltk.download("punkt")
...
>>> from nltk import word_tokenize
>>> word_tokenize('i love you')
['i', 'love', 'you']

国内有墙,一般下载不了,所以可以到这里找到punkt文件并下载,解压到~/nltk_data/tokenizers/下边。

3.3 加载语料代码

import xml.etree.ElementTree as ETclass TmxHandler():def __init__(self):self.tag=Noneself.lang=Noneself.corpus={}def handleStartTu(self, tag):self.tag=tagself.lang=Noneself.corpus={}def handleStartTuv(self, tag, attributes):if self.tag == 'tu':if attributes['{http://www.w3.org/XML/1998/namespace}lang']:self.lang=attributes['{http://www.w3.org/XML/1998/namespace}lang']else:raise Exception('tuv element must has a xml:lang attribute')self.tag = tagelse:raise Exception('tuv element must go under tu, not ' + tag)def handleStartSeg(self, tag, elem):if self.tag == 'tuv':self.tag = tagif self.lang:if elem.text:self.corpus[self.lang]=elem.textelse:raise Exception('lang must not be none')else:raise Exception('seg element must go under tuv, not ' + tag)def startElement(self, tag, attributes, elem):if tag== 'tu':self.handleStartTu(tag)elif tag == 'tuv':self.handleStartTuv(tag, attributes)elif tag == 'seg':self.handleStartSeg(tag, elem)def endElem(self, tag):if self.tag and self.tag != tag:raise Exception(self.tag + ' could not end with ' + tag)if tag == 'tu':self.tag=Noneself.lang=Noneself.corpus={}elif tag == 'tuv':self.tag='tu'self.lang=Noneelif tag == 'seg':self.tag='tuv'def parse(self, filename):for event, elem in ET.iterparse(filename, events=('start','end')):if event == 'start':self.startElement(elem.tag, elem.attrib, elem)elif event == 'end':if elem.tag=='tu':yield self.corpusself.endElem(elem.tag)

3.4 句子转词向量代码

from gensim.models import KeyedVectors
import torch
import jieba
from nltk import word_tokenize
import numpy as npclass WordEmbeddingLoader():def __init__(self):passdef load(self, fname):self.model = KeyedVectors.load_word2vec_format(fname)def get_embeddings(self, word:str):if self.model:try:return self.model.get_vector(word)except(KeyError):return Noneelse:return Nonedef get_scentence_embeddings(self, scent:str, lang:str):embeddings = []ws = []if(lang == 'zh'):ws = jieba.cut(scent, cut_all=True)elif lang == 'en':ws = word_tokenize(scent)else:raise Exception('Unsupported language ' + lang)for w in ws:embedding = self.get_embeddings(w.lower())if embedding is None:embedding = np.zeros(self.model.vector_size)embedding = torch.from_numpy(embedding).float()embeddings.append(embedding.unsqueeze(0))return torch.cat(embeddings, dim=0)

4 模型代码实现

4.1 encoder

import torch.nn as nnclass Encoder(nn.Module):def __init__(self, device, embeddings=300, hidden_size=600, num_layers=4):super().__init__()self.device = deviceself.hidden_layer_size = hidden_sizeself.n_layers = num_layersself.embedding_size = embeddingsself.lstm = nn.LSTM(embeddings, hidden_size, num_layers, batch_first=True)self.linear = nn.Linear(hidden_size, embeddings)self.relu = nn.ReLU()def forward(self, x):# x: [batch size, seq length, embeddings]# lstm_out: [batch size, x length, hidden size]lstm_out, (hidden, cell) = self.lstm(x)# linear input is the lstm output of the last wordlineared = self.linear(lstm_out[:,-1,:].squeeze(1))out = self.relu(lineared)# hidden: [n_layer, batch size, hidden size]# cell: [n_layer, batch size, hidden size]return out, hidden, cell

4.2 decoder

import torch.nn as nnclass Decoder(nn.Module):def __init__(self, device, embedding_size=300, hidden_size=900, num_layers=4):super().__init__()self.device = deviceself.hidden_layer_size = hidden_sizeself.n_layers = num_layersself.embedding_size = embedding_sizeself.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True)self.linear = nn.Linear(hidden_size, embedding_size)self.relu = nn.ReLU()def forward(self, x, hidden_in, cell_in):# x: [batch_size, x length, embeddings]# hidden: [n_layers, batch size, hidden size]# cell: [n_layers, batch size, hidden size]# lstm_out: [seq length, batch size, hidden size]lstm_out, (hidden,cell) = self.lstm(x, (hidden_in, cell_in))# prediction: [seq length, batch size, embeddings]prediction=self.relu(self.linear(lstm_out))return prediction, hidden, cell

4.3 encoder-decoder

接下来把encoder和decoder串联起来。

import torch
import encoder as enc
import decoder as dec
import torch.nn as nn
import timeclass Seq2Seq(nn.Module):def __init__(self, device, embeddings, hiddens, n_layers):super().__init__()self.device = deviceself.encoder = enc.Encoder(device, embeddings, hiddens, n_layers)self.decoder= dec.Decoder(device, embeddings, hiddens, n_layers)self.embeddings = self.encoder.embedding_sizeassert self.encoder.n_layers == self.decoder.n_layers, "Number of layers of encoder and decoder must be equal!"assert self.decoder.hidden_layer_size==self.decoder.hidden_layer_size, "Hidden layer size of encoder and decoder must be equal!"# x: [batches, x length, embeddings]# x is the source scentences# y: [batches, y length, embeddings]# y is the target scentencesdef forward(self, x, y):# encoder_out: [batches, n_layers, embeddings]# hidden, cell: [n layers, batch size, embeddings]encoder_out, hidden, cell = self.encoder(x)# use encoder output as the first word of the decode sequencedecoder_input = torch.cat((encoder_out.unsqueeze(0), y), dim=1)# predicted: [batches, y length, embeddings]predicted, hidden, cell = self.decoder(decoder_input, hidden, cell)return predicted

5 模型训练

5.1 训练代码


def do_train(model:Seq2Seq, train_set, optimizer, loss_function):step = 0model.train()# seq: [seq length, embeddings]# labels: [label length, embeddings]for seq, labels in train_set:step = step + 1# ignore the last word of the label scentence# because it is to be predictedlabel_input = labels[:-1].unsqueeze(0)# seq_input: [1, seq length, embeddings]seq_input = seq.unsqueeze(0)# y_pred: [1, seq length + 1, embeddings]y_pred = model(seq_input, label_input)# single_loss = loss_function(y_pred.squeeze(0), labels.to(self.device))single_loss = loss_function(y_pred.squeeze(0), labels)optimizer.zero_grad()single_loss.backward()optimizer.step()print_steps = 100if print_steps != 0 and step%print_steps==1:print(f'[step: {step} - {time.asctime(time.localtime(time.time()))}] - loss:{single_loss.item():10.8f}')def train(device, model, embedding_loader, corpus_fname, batch_size:int, batches: int):reader = corpus_reader.TmxHandler()loss = torch.nn.MSELoss()# summary(model, input_size=[(1, 10, 300),(1,10,300)])optimizer = torch.optim.SGD(model.parameters(), lr=0.01)generator = reader.parse(corpus_fname)for _b in range(batches):batch = []try:for _c in range(batch_size):try:corpus = next(generator)if 'en' in corpus and 'zh' in corpus:en = embedding_loader.get_scentence_embeddings(corpus['en'], 'en').to(device)zh = embedding_loader.get_scentence_embeddings(corpus['zh'], 'zh').to(device)batch.append((en,zh))except (StopIteration):breakfinally:print(time.localtime())print("batch: " + str(_b))do_train(model, batch, optimizer, loss)torch.save(model, "./models/seq2seq_" + str(time.time()))if __name__=="__main__":# device = torch.device('cuda')device = torch.device('cpu')embeddings = 300hiddens = 600n_layers = 4embedding_loader = word2vec.WordEmbeddingLoader()print("loading embedding")# a full vocabulary takes too long to load, a baby vocabulary is used for demo purposeembedding_loader.load("../sgns.merge.word.toy")print("load embedding finished")# if there is an existing model, load the existing model from file# model_fname = "./models/_seq2seq_1698000846.3281412"model_fname = Nonemodel = Noneif not model_fname is None:print('loading model from ' + model_fname)model = torch.load(model_fname, map_location=device)print('model loaded')else:model = Seq2Seq(device, embeddings, hiddens, n_layers).to(device)train(device, model, embedding_loader, "../News-Commentary_v16.tmx", 1000, 100)

5.2 使用CPU进行训练

让我们先来体验一下CPU的龟速训练。下图是每100句话的训练输出。每次打印的间隔大约为2-3分钟。

[step: 1 - Thu Oct 26 05:14:13 2023] - loss:0.00952744
[step: 101 - Thu Oct 26 05:17:11 2023] - loss:0.00855174
[step: 201 - Thu Oct 26 05:20:07 2023] - loss:0.00831730
[step: 301 - Thu Oct 26 05:23:09 2023] - loss:0.00032693
[step: 401 - Thu Oct 26 05:25:55 2023] - loss:0.00907284
[step: 501 - Thu Oct 26 05:28:55 2023] - loss:0.00937218
[step: 601 - Thu Oct 26 05:32:00 2023] - loss:0.00823146

5.3 使用GPU进行训练

如果把main函数的第一行中的"cpu"改成“cuda”,则可以使用显卡进行训练。笔者使用的是一张GTX1660显卡,打印间隔缩短为15秒。

[step: 1 - Thu Oct 26 06:38:45 2023] - loss:0.00955237
[step: 101 - Thu Oct 26 06:38:50 2023] - loss:0.00844441
[step: 201 - Thu Oct 26 06:38:56 2023] - loss:0.00820994
[step: 301 - Thu Oct 26 06:39:01 2023] - loss:0.00030389
[step: 401 - Thu Oct 26 06:39:06 2023] - loss:0.00896622
[step: 501 - Thu Oct 26 06:39:11 2023] - loss:0.00929985
[step: 601 - Thu Oct 26 06:39:17 2023] - loss:0.00813591

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/119148.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文心一言简单体验

百度正式发布文心一言,文心一言 这里的插件模式挺有意思: 测试了一下图解说明,随意上传了一张图片: 提供图解让反过来画,抓住了部分重点,但是还是和原图有比较大的差异! 百宝箱 暂未逐个体验&am…

Linux ————​文件权限

(一)文件权限 基础补充 文件基本属性(Linux中万物皆是文件)文件是操作系统用来存储信息的基本结构,是一组信息的集合。文件通过文件名来唯一标识。Linux中的文件名称最长允许255个字符,这些字符可用A~Z、0…

【JAVA学习笔记】46 - (43)第十一章作业

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter11/src/com/yinhai/homework11 1.枚举类 1.创建一个Color枚举类 2.有RED,BLUE,BL ACK,YELLOW,GREEN这个五个枚举值/对象: 3. Color有三 个属性redValue, greenValue, blueValue, 4.创建构…

点击弹出实现模拟百度那样子

<uni-section title"输入框示例" type"line" padding><view class"dialog-box"><text class"dialog-text">输入内容&#xff1a;{{ value }}</text></view><button class"button" type&qu…

PL/SQL工具下载地址

https://www.allroundautomations.com/registered-plsqldev/ 选择需要下载的版本即可

LuaTable转C#的列表List和字典Dictionary

LuaTable转C#的列表List和字典Dictionaty 介绍lua中创建表测试lua中list表表转成List表转成Dictionary 键值对表表转成Dictionary 多类型键值对表表转成Dictionary 总结 介绍 之前基本都是从C#中的List或者Dictionary转成luaTable&#xff0c;很少会把LuaTable转成C#的List或者…

深入浅出排序算法之简单选择排序

目录 1. 原理和执行流程 2. 代码实现 3. 性能分析 4. 双向选择排序&#xff08;了解&#xff09; 1. 原理和执行流程 选择排序包含了堆排序和简单选择排序。 每一次从无序区间选出最大&#xff08;或最小&#xff09;的一个元素&#xff0c;存放在无序区间的最后&#xff0…

道路数据汇总,全国(2021年+2022年)+重点城市(深圳、上海、武汉、杭州、广州、南京、东莞),格式有shp+xlsx

昨天推了上海道路数据&#xff0c;今天把已收集到的道路数据打包推给大家&#xff0c;后续有新数据会持续更新&#xff01; 废话不多说&#xff0c;先给数据地址再介绍数据情况&#xff1a; 2021年全国道路数据&#xff1a; 2021年全国道路数据https://www.xcitybox.com/dat…

uni-app医院智能导诊系统源码

随着科技的迅速发展&#xff0c;人工智能已经逐渐渗透到我们生活的各个领域。在医疗行业中&#xff0c;智能导诊系统成为了一个备受关注的应用。本文将详细介绍智能导诊系统的概念、技术原理以及在医疗领域中的应用&#xff0c;分析其优势和未来发展趋势。 智能导诊系统通过人工…

迭代器的封装与反向迭代器

一、反向迭代器 在list模拟实现的过程中&#xff0c;第一次接触了迭代器的封装&#xff0c;将list的指针封装成了一个新的类型&#xff0c;并且以迭代器的基本功能对其进行了运算符重载 反向迭代器是对正向迭代器的封装&#xff0c;并且体现了泛型编程的思想&#xff0c;任意…

如何在 openSUSE 中使用 Zypper Configuration 设置代理

如何在 openSUSE 中使用 Zypper Configuration 设置代理 首先&#xff0c;确定问题&#xff1a;设置代理服务器以便 Zypper 能够访问互联网并下载软件包。 亲身经验&#xff1a;我曾在使用 openSUSE 时遇到过类似问题&#xff0c;通过设置代理服务器成功解决。 数据和引证&…

C++初阶:C/C++内存管理

一.C/C内存分布 先来回顾一下C语言内存分区示意图如下&#xff1a; 代码区&#xff1a; 程序执行代码一般存放在代码区&#xff0c;字符串常量以及define定义的常量也可能存放在代码区。 常量区&#xff1a; 字符串&#xff0c;数字等常量以及const修饰的全局变量往往存放在…

day51 --动态规划10

121. 买卖股票的最佳时机 122.买卖股票的最佳时机II 第一题&#xff1a;买卖股票的最佳时机 给定一个数组 prices &#xff0c;它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。 你只能选择 某一天 买入这只股票&#xff0c;并选择在 未来的某一个不同的日子 卖出…

员工福利平台设计方案

需求背景&#xff1a; 1、杭州行政希望给员工有一个福利平台&#xff0c;可以通过该福利平台&#xff0c;一方面可以结合公司周围的实体店&#xff0c;给到员工一些福利的商品&#xff0c;员工可以自行去这些商家进行消费。 2、公司可以通过福利平台&#xff0c;给员工账户进…

K8s 部署 CNI 网络组件+k8s 多master集群部署+负载均衡

------------------------------ 部署 CNI 网络组件 ------------------------------ ---------- 部署 flannel ---------- K8S 中 Pod 网络通信&#xff1a; ●Pod 内容器与容器之间的通信 在同一个 Pod 内的容器&#xff08;Pod 内的容器是不会跨宿主机的&#xff09;共享同一…

实验六:DHCP、DNS、Apache、FTP服务器的安装和配置

1. (其它) 掌握Linux下DHCP、DNS、Apache、FTP服务器的安装和配置&#xff0c;在Linux服务器上部署JavaWeb应用 完成单元八的实训内容。 1、安装 JDK 2、安装 MySQL 3、部署JavaWeb应用 安装jdk 教程连接&#xff1a;linux安装jdk8详细步骤-CSDN博客 Jdk来源&#xff1a;linu…

【Django 05】Django-DRF(ModelViewSet)、路由组件、自定义函数

1. Django-DRF&#xff08;ModelViewSet&#xff09; 1.1 DRF是什么&#xff1f; ModelViewSet 是 Django REST framework 提供的一个视图集类&#xff0c;它封装了常见的模型操作方法。 模型类提供了默认的增删改查功能。 它继承自 GenericViewSet、ListModelMixin、Retri…

基于pyenv和virtualenv搭建python多版本虚拟环境

pyenv简介 由于Python的依赖是基于site的&#xff0c;这对于生产环境来说&#xff0c;是一种简单而正确的方式&#xff0c;然而&#xff0c;对于我们的开发环境&#xff0c;基于这样的管理方式&#xff0c;带来了可怕的第三方依赖管理的难题&#xff0c;virtualenv适时出现了&a…

Altium Designer布局技巧

资料 快捷键 PCB导入原理图 验证工程 导入原理图 进入PCB编辑界面&#xff0c;设计→Import Changes from xxxx 多原理图多PCB 创建多个原理图、PCB 略反键点击原理图 勾选高级 选择原理图及目标PCB&#xff0c;点击确定 右键点击列表项&#xff0c;更新原理图&#xff0…

手机桌面待办事项APP推荐

每天&#xff0c;我们每个人都面临着繁琐的事务和任务&#xff0c;而手机成了我们日常生活中不可或缺的伙伴。手机上的待办事项工具像一个可靠的助手&#xff0c;可以帮助我们更好地记录、管理和完成任务。在手机桌面上使用的待办事项APP推荐用哪一个呢&#xff1f; 手机是我们…