PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型预测)

文章目录

  • 引言
  • 数据预处理
    • 加载字典对象`en2id`和`zh2id`
    • 文本分词
  • 加载训练好的Seq2Seq模型
  • 模型预测完整代码
  • 结束语

引言

随着全球化的深入,翻译需求日益增长。传统的人工翻译方式虽然质量高,但效率低,成本高。机器翻译的出现,为解决这一问题提供了可能。英译中机器翻译任务是机器翻译领域的一个重要分支,旨在将英文文本自动翻译成中文。本博客以《PyTorch自然语言处理入门与实战》第九章的Seq2seq模型处理英译中翻译任务作为基础,附上模型预测模块。

模型的训练及验证模块的详细解析见PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型训练及验证)

数据预处理

加载字典对象en2idzh2id

在预测阶段中,需要加载模型训练及验证阶段保存的字典对象en2idzh2id

代码如下:

import picklewith open("en2id.pkl", 'rb') as f:en2id = pickle.load(f)
with open("zh2id.pkl", 'rb') as f:zh2id = pickle.load(f)

文本分词

在对输入文本进行预测时,需要先将文本进行分词操作。参考代码如下:

def extract_words(sentence):  """  从给定的英文句子中提取单词,并去除单词后的标点符号。  Args:  sentence (str): 要提取单词的英文句子。  Returns:  List[str]: 提取并处理后的单词列表。  """  en_words = []  for w in sentence.split(' '):  # 将英文句子按空格分词  w = w.replace('.', '').replace(',', '')  # 去除跟单词连着的标点符号  w = w.lower()  # 统一单词大小写  if w:  en_words.append(w)  return en_words  # 测试函数  
sentence = 'I am Dave Gallo.'  
print(extract_words(sentence))

运行结果:

加载训练好的Seq2Seq模型

代码如下:

import torch
import torch.nn as nnclass Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.dropout = nn.Dropout(dropout)def forward(self, src):# src = (src len, batch size)embedded = self.dropout(self.embedding(src))# embedded = (src len, batch size, emb dim)outputs, (hidden, cell) = self.rnn(embedded)# outputs = (src len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# rnn的输出总是来自顶部的隐藏层return hidden, cellclass Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.output_dim = output_dimself.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.fc_out = nn.Linear(hid_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden, cell):# 各输入的形状# input = (batch size)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# LSTM是单向的  ==> n directions == 1# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)output, (hidden, cell) = self.rnn(embedded, (hidden, cell))# LSTM理论上的输出形状# output = (seq len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# 解码器中的序列长度 seq len == 1# 解码器的LSTM是单向的 n directions == 1 则实际上# output = (1, batch size, hid dim)# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)prediction = self.fc_out(output.squeeze(0))# prediction = (batch size, output dim)return prediction, hidden, cellclass Seq2Seq(nn.Module):def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,encode_dropout, decode_dropout, device):""":param input_word_count:    英文词表的长度     34737:param output_word_count:   中文词表的长度     4015:param encode_dim:          编码器的词嵌入维度:param decode_dim:          解码器的词嵌入维度:param hidden_dim:          LSTM的隐藏层维度:param n_layers:            采用n层LSTM:param encode_dropout:      编码器的dropout概率:param decode_dropout:      编码器的dropout概率:param device:              cuda / cpu"""super().__init__()self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)self.device = devicedef forward(self, src):# src = (src len, batch size)# 编码器的隐藏层输出将作为解码器的第一个隐藏层输入hidden, cell = self.encoder(src)# 解码器的第一个输入应该是起始标识符<sos>input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引pred = [0] # 预测的第一个输出应该是起始标识符top1 = 0while top1 != 1 and len(pred) < 100:# 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states# 解码器的输出包括:输出张量(predictions) and new hidden and cell statesoutput, hidden, cell = self.decoder(input, hidden, cell)top1 = output.argmax(dim=1)  # (batch size, )pred.append(top1.item())input = top1return preddevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU
# Seq2Seq模型实例化
source_word_count = 34737  # 英文词表的长度     34737
target_word_count = 4015  # 中文词表的长度     4015
encode_dim = 256  # 编码器的词嵌入维度
decode_dim = 256  # 解码器的词嵌入维度
hidden_dim = 512  # LSTM的隐藏层维度
n_layers = 2  # 采用n层LSTM
encode_dropout = 0.5  # 编码器的dropout概率
decode_dropout = 0.5  # 编码器的dropout概率
model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,decode_dropout, device).to(device)# 加载训练好的模型
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

模型预测完整代码

提示预测代码是我们基于训练及验证代码进行改造的,不一定完全正确,可以参考后自行修改~

import torch
import torch.nn as nn
import pickleclass Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.dropout = nn.Dropout(dropout)def forward(self, src):# src = (src len, batch size)embedded = self.dropout(self.embedding(src))# embedded = (src len, batch size, emb dim)outputs, (hidden, cell) = self.rnn(embedded)# outputs = (src len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# rnn的输出总是来自顶部的隐藏层return hidden, cellclass Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.output_dim = output_dimself.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.fc_out = nn.Linear(hid_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden, cell):# 各输入的形状# input = (batch size)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# LSTM是单向的  ==> n directions == 1# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)output, (hidden, cell) = self.rnn(embedded, (hidden, cell))# LSTM理论上的输出形状# output = (seq len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# 解码器中的序列长度 seq len == 1# 解码器的LSTM是单向的 n directions == 1 则实际上# output = (1, batch size, hid dim)# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)prediction = self.fc_out(output.squeeze(0))# prediction = (batch size, output dim)return prediction, hidden, cellclass Seq2Seq(nn.Module):def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,encode_dropout, decode_dropout, device):""":param input_word_count:    英文词表的长度     34737:param output_word_count:   中文词表的长度     4015:param encode_dim:          编码器的词嵌入维度:param decode_dim:          解码器的词嵌入维度:param hidden_dim:          LSTM的隐藏层维度:param n_layers:            采用n层LSTM:param encode_dropout:      编码器的dropout概率:param decode_dropout:      编码器的dropout概率:param device:              cuda / cpu"""super().__init__()self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)self.device = devicedef forward(self, src):# src = (src len, batch size)# 编码器的隐藏层输出将作为解码器的第一个隐藏层输入hidden, cell = self.encoder(src)# 解码器的第一个输入应该是起始标识符<sos>input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引pred = [0] # 预测的第一个输出应该是起始标识符top1 = 0while top1 != 1 and len(pred) < 100:# 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states# 解码器的输出包括:输出张量(predictions) and new hidden and cell statesoutput, hidden, cell = self.decoder(input, hidden, cell)top1 = output.argmax(dim=1)  # (batch size, )pred.append(top1.item())input = top1return predif __name__ == '__main__':sentence = 'I am Dave Gallo.'en_words = []for w in sentence.split(' '):  # 英文内容按照空格字符进行分词# 按照空格进行分词后,某些单词后面会跟着标点符号 "." 和 “,”w = w.replace('.', '').replace(',', '')  # 去掉跟单词连着的标点符号w = w.lower()  # 统一单词大小写if w:en_words.append(w)print(en_words)with open("en2id.pkl", 'rb') as f:en2id = pickle.load(f)with open("zh2id.pkl", 'rb') as f:zh2id = pickle.load(f)device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU# Seq2Seq模型实例化source_word_count = 34737  # 英文词表的长度     34737target_word_count = 4015  # 中文词表的长度     4015encode_dim = 256  # 编码器的词嵌入维度decode_dim = 256  # 解码器的词嵌入维度hidden_dim = 512  # LSTM的隐藏层维度n_layers = 2  # 采用n层LSTMencode_dropout = 0.5  # 编码器的dropout概率decode_dropout = 0.5  # 编码器的dropout概率model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,decode_dropout, device).to(device)model.load_state_dict(torch.load("best_model.pth"))model.eval()src = [0] # 0 --> 起始标识符的编码for i in range(len(en_words)):src.append(en2id[en_words[i]])src = src + [1] # 1 --> 终止标识符的编码text_input = torch.LongTensor(src)text_input = text_input.unsqueeze(-1).to(device)text_output = model(text_input)print(text_output)id2zh = dict()for k, v in zh2id.items():id2zh[v] = ktext_output = [id2zh[index] for index in text_output]text_output = " ".join(text_output)print(text_output)

结束语

  • 亲爱的读者,感谢您花时间阅读我们的博客。我们非常重视您的反馈和意见,因此在这里鼓励您对我们的博客进行评论。
  • 您的建议和看法对我们来说非常重要,这有助于我们更好地了解您的需求,并提供更高质量的内容和服务。
  • 无论您是喜欢我们的博客还是对其有任何疑问或建议,我们都非常期待您的留言。让我们一起互动,共同进步!谢谢您的支持和参与!
  • 我会坚持不懈地创作,并持续优化博文质量,为您提供更好的阅读体验。
  • 谢谢您的阅读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/581687.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PYTHON基础:数据可视化绘图

python数据可视化入门 –常见的四种数据图形绘制 数据可视化在数据分析和数据科学中起着重要的作用。它可以帮助我们更直观地理解和解释数据&#xff0c;发现数据中的模式、趋势和异常。 在数据可视化中&#xff0c;常用的图表类型包括折线图、散点图、直方图和饼图&#xff…

为什么要运营海外社媒?海外云手机能发挥什么作用?

基于海外社媒在全球范围内拥有的大量流量&#xff0c;海外社媒运营成为了品牌推广、内容创作和用户互动的重要途径。本文将探讨海外社媒运营的重要性&#xff0c;并介绍海外云手机在这一过程中的卓越帮助。 海外社媒运营的重要性 首先&#xff0c;海外社媒运营有助于企业扩大品…

Qt高质量的开源项目合集

文章目录 1.Qt官网下载/文档2.第三方开源 1.Qt官网下载/文档 Qt Downloads Qt 清华大学开源软件镜像站 Qt 官方博客 2.第三方开源 记录了平常项目开发中用到的第三方库&#xff0c;以及一些值得参考的项目&#xff01; Qt AV 基于Qt和FFmpeg的跨平台高性能音视频播放框…

EasyExcel导出

1.简介 官网&#xff1a;EasyExcel官方文档 - 基于Java的Excel处理工具 | Easy Excel 2.案例 2.1 实现的效果 效果图如下&#xff1a; 2.2 实现步骤 三种情景&#xff0c;主要是表头和数据有区别&#xff0c;简列实现步骤如下&#xff1a; 2.3 具体实现 2.3.1 前置-依赖导入…

【LeetCode-剑指offer】--3.比特位计数

3.比特位计数 class Solution {public int[] countBits(int n) {int[] bites new int[n 1];for(int i 0 ; i < n;i){bites[i] Count(i);}return bites;}public int Count(int x){int count 0;while(x > 0){x & (x - 1);count;}return count;} }

Python入门学习篇(十)——函数定义函数传参方式

1 相关定义和概念 1.1 函数的理解 一段被封装的可以重复调用的代码。 1.2 函数定义语法结构 def 函数名(形参1,形参2):要封装的逻辑代码 # 注意:函数可以有返回值也可以没有返回值,没有返回值的结果是None1.3 函数调用的语法结构 函数名(形参1,形参2)1.4 简单实例 1.4.1 …

同义词替换降低论文相似度的注意事项 papergpt

大家好&#xff0c;今天来聊聊同义词替换降低论文相似度的注意事项&#xff0c;希望能给大家提供一点参考。 以下是针对论文重复率高的情况&#xff0c;提供一些修改建议和技巧&#xff0c;可以借助此类工具&#xff1a; 标题&#xff1a;同义词替换降低论文相似度的注意事项 …

计算 10亿 的和,js 和 c 的处理时长对比

计算 10亿 的和&#xff0c;js 和 c 的处理时长对比 js 4.17s let sum 0; let start new Date().getTime(); for (let i0;i<1000000000; ii1){sum sum i; } let stop new Date().getTime(); console.log((stop - start)/1000, sum);结果&#xff1a; c 3.65s #in…

Windows搭建FTP服务器教学以及计算机端口介绍

目录 一. FTP服务器介绍 FTP服务器是什么意思&#xff1f; 二.Windows Service 2012 搭建FTP服务器 1.开启防火墙 2.创建组 ​编辑3.创建用户 4.用户绑定组 5.安装ftp服务器 ​编辑6.配置ftp服务器 7.配置ftp文件夹的权限 8.连接测试 三.计算机端口介绍 什么是网络…

解决ELement-UI懒加载三级联动数据不回显(天坑)

最老是遇到这类问题头有点大,最后也是解决了,为铁铁们总结了一下几点 一.查看数据类型是否一致 未选择下 选择下 二.处理数据时使用this.$set方法来动态地设置实例中的属性&#xff0c;以确保其响应式 三.绑定v-if 确保每次重新加载 四.绑定key 五.完整代码

进行VMware日志管理

随着公司转向虚拟化其 IT 空间&#xff0c;虚拟环境日志监控正在占据日志管理的很大一部分,除了确保网络安全外&#xff0c;虚拟机日志监控还有助于管理虚拟化工具&#xff0c;这是最复杂的任务之一。 对虚拟环境日志的监控分析 当今公司中最受欢迎的虚拟平台之一是 VMware。…

图像处理-周期噪声

周期噪声 对于具有周期性的噪声被称为周期噪声&#xff0c;其中周期噪声在频率域会出现关于中心对称的性质&#xff0c;如下图所示 带阻滤波器 为了消除周期性噪声&#xff0c;由此设计了几种常见的滤波器&#xff0c;其中 W W W表示带阻滤波器的带宽 理想带阻滤波器 H ( u …

二维码能转成链接吗?具体步骤是什么样的?

将二维码分解成链接来使用&#xff0c;是经常会出现的一种需求&#xff0c;分解成的链接可以放在电脑浏览器上&#xff0c;就可以在电脑上查看二维码的内容。那么如何将二维码图片做解码处理呢&#xff1f;最简单也是很多人会选择使用的一种方法就是使用二维码解码器来处理&…

如何提高代码质量:5 个基本步骤

软件开发团队有时会遇到各种挑战&#xff0c;导致他们难以按时生产高质量的项目。在这里&#xff0c;我们讨论了通过持续测试快速保证质量的五种策略。 每个人都想要更高质量、更快的软件。对现代软件开发团队的要求是巨大的——从日益激烈的竞争和市场压力、不断增加的功能和…

php-ssrf

漏洞描述&#xff1a; SSRF(Server-Side Request Forgery:服务器端请求伪造) 是一种由攻击者构造形成由服务端发起请求的一个安全漏洞。 一般情况下&#xff0c;SSRF攻击的目标是从外网无法访问的内部系统。&#xff08;正是因为它是由服务端发起的&#xff0c;所以它能够请求…

POI根据表头模板导出excel数据,并指定单个单元格样式,多sheet。

最近的公司需求&#xff0c;因为Excel表头样式较为复杂&#xff0c;不易直接用poi写出。 需要的Excel为这种&#xff1a; 直接模板导出不能成为这样。 public void exportCheckCsdn(HttpServletResponse response) {//获取到MNR 和 MNR-DT 的List// 此处写 获取到指定li…

是德科技E9304A功率传感器

是德科技E9304A二极管功率传感器测量频率范围为9 kHz至6 GHz的平均功率&#xff0c;功率范围为-60至20 dBm。该传感器非常适合甚低频(VLF)功率测量。E系列E9304A功率传感器有两个独立的测量路径&#xff0c;设计用于EPM系列功率计。功率计自动选择合适的功率电平路径。为了避免…

【微信小程序二维码配置】微信公众平台配置二维码,小程序测试二维码,小程序动态二维码,然后扫码打开对应页面进行操作

微信小程序二维码 操作添加二维码地址配置配置项 生成二维码动态二维码生成 操作 微信公众平台地址&#xff1a;微信公众平台 选择 开发管理 – 开发设置 – 扫普通链接二维码打开小程序 添加二维码地址 配置 配置项 二维码规则: URL 为内含下载校验文件的服务器 URL, 可以…

HTML-基础知识-基本结构,注释,文档说明,字符编码(一)

1.超文本标记语言不分大小写。 2.超文本标签属性名和属性值不区分大小写。 3.超文本标签属性值重复&#xff0c;听取第一个。 4.html结构 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"vi…

uniapp项目如何引用安卓原生aar插件(避坑指南三)

官方文档说明&#xff1a;uni小程序SDK 【彩带- 避坑知识点】 如果引用原生aar插件&#xff0c;都配置好之后&#xff0c;云打包&#xff0c;报不包含此插件&#xff0c;除了检查以下步骤流程外&#xff0c;还要检查一下是否上打包的原生插件aar流程有问题。 1.第一步在uniapp项…