【深度学习入门篇 ⑩】Seq2Seq模型:语言翻译

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


今天我们进入 Seq2Seq 的领域,了解这种更为复杂且功能强大的模型,它不仅能理解词汇(Word2Vec),还能把这些词汇串联成完整的句子。 

Seq2Seq

Seq2Seq(Sequence-to-Sequence),就是从一个序列到另一个序列的转换。它不仅仅能理解单词之间的关系,而且还能把整个句子的意思打包,并解压成另一种形式的表达。

seq2seq是一种神经网络架构,是由encoder(编码器)decoder(解码器)两个RNN的组成的。其中encoder负责对输入句子的理解,转化为context vector,decoder负责对理解后的句子的向量进行处理,解码,获得输出。  

Seq2seq模型中的encoder接受一个长度为M的序列,得到1个 context vector,之后decoder把这一个context vector转化为长度为N的序列作为输出,从而构成一个M to N的模型,能够处理很多不定长输入输出的问题,比如:文本翻译,问答,文章摘要,关键字写诗等等

  • 编码器的任务是读取并理解输入序列,然后把它转换为一个固定长度的上下文向量,也叫作状态向量。
  • 解码器的任务是接收编码器生成的上下文向量,并基于这个向量生成目标序列。 

可以加入注意力机制(Attention Mechanism):使解码器能够在生成每个输出元素时“关注”输入序列中的不同部分,从而提高模型处理长序列和捕捉复杂依赖关系的能力。 

Seq2Seq模型实现

 任务:

完成一个模型,实现往模型输入一串数字,输出这串数字+0

  •  输入12345678,输出123456780

实现流程

  • 文本转化为序列

  • 使用序列,准备数据集,准备Dataloader

  • 完成编码器

  • 完成解码器

  • 完成seq2seq模型

  • 完成模型训练的逻辑,进行训练

  • 完成模型评估的逻辑,进行模型评估

训练时可以使用GPU训练:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("训练设备为:", device)

文本转化为序列

由于输入的是数字,为了把这写数字和词典中的真实数字进行对应,可以把这些数字理解为字符串

class NumSequence:UNK_TAG = "UNK" PAD_TAG = "PAD" EOS_TAG = "EOS" #句子开始SOS_TAG = "SOS" #句子结束UNK = 0PAD = 1EOS = 2SOS = 3def __init__(self):self.dict = {self.UNK_TAG : self.UNK,self.PAD_TAG : self.PAD,self.EOS_TAG : self.EOS,self.SOS_TAG : self.SOS}# 字符串和数字对应的字典for i in range(10):self.dict[str(i)] = len(self.dict)self.index2word = dict(zip(self.dict.values(),self.dict.keys()))def __len__(self):return len(self.dict)def transform(self,sequence,max_len=None,add_eos=False):sequence_list = list(str(sequence))seq_len = len(sequence_list)+1 if add_eos else len(sequence_list)if add_eos and max_len is not None:assert max_len>= seq_len, "max_len 应该大于seq+eos的长度"_sequence_index = [self.dict.get(i,self.UNK) for i in sequence_list]if add_eos:_sequence_index += [self.EOS]if max_len is not None:sequence_index = [self.PAD]*max_lensequence_index[:seq_len] =  _sequence_indexreturn sequence_indexelse:return _sequence_indexdef inverse_transform(self,sequence_index):result = []for i in sequence_index:if i==self.EOS:breakresult.append(self.index2word.get(int(i),self.UNK_TAG))return resultnum_sequence = NumSequence()if __name__ == '__main__':num_sequence = NumSequence()print(num_sequence.dict)print(num_sequence.index2word)print(num_sequence.transform("232356",add_eos=True))
准备Dataset
from torch.utils.data import Dataset,DataLoader
import numpy as np
from word_sequence import num_sequence
import torch
import configclass RandomDataset(Dataset):def __init__(self):super(RandomDataset,self).__init__()self.total_data_size = 500000np.random.seed(10)self.total_data = np.random.randint(1,100000000,size=[self.total_data_size])def __getitem__(self, idx):input = str(self.total_data[idx])return input, input+ "0",len(input),len(input)+1def __len__(self):return self.total_data_size
准备DataLoader

在准备DataLoader的过程中,可以通过定义的collate_fn来实现对dataset中batch数据的处理

def collate_fn(batch):batch = sorted(batch,key=lambda x:x[3],reverse=True)input,target,input_length,target_length = zip(*batch)input = torch.LongTensor([num_sequence.transform(i,max_len=config.max_len) for i in input])target = torch.LongTensor([num_sequence.transform(i,max_len=config.max_len,add_eos=True) for i in target])input_length = torch.LongTensor(input_length)target_length = torch.LongTensor(target_length)return input,target,input_length,target_lengthdata_loader = DataLoader(dataset=RandomDataset(),batch_size=config.batch_size,collate_fn=collate_fn,drop_last=True)

编码器

目的就是为了对文本进行编码,把编码后的结果交给后续的程序使用,使用Embedding+GRU

import torch.nn as nn
from word_sequence import num_sequence
import configclass NumEncoder(nn.Module):def __init__(self):super(NumEncoder,self).__init__()self.vocab_size = len(num_sequence)self.dropout = config.dropoutself.embedding = nn.Embedding(num_embeddings=self.vocab_size,embedding_dim=config.embedding_dim,padding_idx=num_sequence.PAD)self.gru = nn.GRU(input_size=config.embedding_dim,hidden_size=config.hidden_size,num_layers=1,batch_first=True)def forward(self, input,input_length):embeded = self.embedding(input) embeded = nn.utils.rnn.pack_padded_sequence(embeded,lengths=input_length,batch_first=True)out,hidden = self.gru(embeded)out,outputs_length = nn.utils.rnn.pad_packed_sequence(out,batch_first=True,padding_value=num_sequence.PAD)return out,hidden

解码器

主要负责实现对编码之后结果的处理,得到预测值

import torch
import torch.nn as nn
import config
import random
import torch.nn.functional as F
from word_sequence import num_sequenceclass NumDecoder(nn.Module):def __init__(self):super(NumDecoder,self).__init__()self.max_seq_len = config.max_lenself.vocab_size = len(num_sequence)self.embedding_dim = config.embedding_dimself.dropout = config.dropoutself.embedding = nn.Embedding(num_embeddings=self.vocab_size,embedding_dim=self.embedding_dim,padding_idx=num_sequence.PAD)self.gru = nn.GRU(input_size=self.embedding_dim,hidden_size=config.hidden_size,num_layers=1,batch_first=True,dropout=self.dropout)self.log_softmax = nn.LogSoftmax()self.fc = nn.Linear(config.hidden_size,self.vocab_size)def forward(self, encoder_hidden,target,target_length):decoder_input = torch.LongTensor([[num_sequence.SOS]]*config.batch_size)decoder_outputs = torch.zeros(config.batch_size,config.max_len,self.vocab_size) decoder_hidden = encoder_hidden for t in range(config.max_len):decoder_output_t , decoder_hidden = self.forward_step(decoder_input,decoder_hidden)decoder_outputs[:,t,:] = decoder_output_tuse_teacher_forcing = random.random() > 0.5if use_teacher_forcing:decoder_input =target[:,t].unsqueeze(1) else:value, index = torch.topk(decoder_output_t, 1) decoder_input = indexreturn decoder_outputs,decoder_hiddendef forward_step(self,decoder_input,decoder_hidden):embeded = self.embedding(decoder_input)  out,decoder_hidden = self.gru(embeded,decoder_hidden) out = out.squeeze(0) out = F.log_softmax(self.fc(out),dim=-1)out = out.squeeze(1)return out,decoder_hidden

seq2seq模型

完成模型的搭建

import torch
import torch.nn as nnclass Seq2Seq(nn.Module):def __init__(self,encoder,decoder):super(Seq2Seq,self).__init__()self.encoder = encoderself.decoder = decoderdef forward(self, input,target,input_length,target_length):encoder_outputs,encoder_hidden = self.encoder(input,input_length)decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,target_length)return decoder_outputs,decoder_hidden

完成训练:

import torch
import config
from torch import optim
import torch.nn as nn
from encoder import NumEncoder
from decoder import NumDecoder
from seq2seq import Seq2Seq
from dataset import data_loader as train_dataloader
from word_sequence import num_sequenceencoder = NumEncoder()
decoder = NumDecoder()
model = Seq2Seq(encoder,decoder)
print(model)optimizer =  optim.Adam(model.parameters())
criterion= nn.NLLLoss(ignore_index=num_sequence.PAD,reduction="mean")def get_loss(decoder_outputs,target):target = target.view(-1)decoder_outputs = decoder_outputs.view(config.batch_size*config.max_len,-1)return criterion(decoder_outputs,target)def train(epoch):for idx,(input,target,input_length,target_len) in enumerate(train_dataloader):optimizer.zero_grad()##[seq_len,batch_size,vocab_size] [batch_size,seq_len]decoder_outputs,decoder_hidden = model(input,target,input_length,target_len)loss = get_loss(decoder_outputs,target)loss.backward()optimizer.step()print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, idx * len(input), len(train_dataloader.dataset),100. * idx / len(train_dataloader), loss.item()))torch.save(model.state_dict(), "model/seq2seq_model.pkl")torch.save(optimizer.state_dict(), 'model/seq2seq_optimizer.pkl')if __name__ == '__main__':for i in range(5):train(i)

Seq2Seq优点:能处理输入和输出长度不固定的序列转换任务,灵活性高

Seq2Seq缺点:使用固定上下文长度、训练和推理通常需要逐步处理输入和输出序列,以及参数量较少,面对复杂场景可能受限。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/48300.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Sklearn-混淆矩阵】一文搞懂分类模型的基础评估指标:混淆矩阵ConfusionMatrixDisplay

【Sklearn-混淆矩阵】一文搞懂分类模型的基础评估指标:混淆矩阵ConfusionMatrixDisplay 本次修炼方法请往下查看 🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地! &…

unity渲染人物模型透明度问题

问题1:有独立的手和衣服的模型,但最终只渲染出来半透明衣服 问题2:透明度贴图是正确的但显示却不正确 这上面两个模型的问题都是因为人物模型是一个完整的,为啥有些地方可以正常显示,有些地方透明度却有问题。 其中…

使用C#实现无人超市管理系统——数据结构课设(代码+PPT+说明书)

说明:这是自己做的课程设计作业,得分情况98/100 如果想要获取私信我 本项目采用线性表中的链表来进行本次系统程序的设计。链表分为两条线,分别是存储用户信息和商品信息,并且都设为公共属性,方便对用户信息和商品信息…

艺术与技术的交响曲:CSS绘图的艺术与实践

在前端开发的世界里,CSS(层叠样式表)作为网页布局和样式的基石,其功能早已超越了简单的颜色和间距设置。近年来,随着CSS3的普及,开发者们开始探索CSS在图形绘制方面的潜力,用纯粹的代码创造出令…

UniApp__微信小程序项目实战 实现长列表分页,通过 onReachBottom 方法上划分次加载数据

UniApp 实现长列表分页,通过 onReachBottom 方法上划分次加载数据 项目实战中比较常见,方便下次使用 文章目录 一、应用场景? 二、作用 三、使用步骤?          3.1 实现的整体思路?    …

基于python深度学习遥感影像地物分类与目标识别、分割实践技术应用

目录 专题一、深度学习发展与机器学习 专题二、深度卷积网络基本原理 专题三、TensorFlow与Keras介绍与入门 专题四、PyTorch介绍与入门 专题五、卷积神经网络实践与遥感图像场景分类 专题六、深度学习与遥感图像检测 专题七、遥感图像检测案例 专题八、深度学习与遥感…

字节码编程之bytebuddy结合javaagent支持多种监控方式

写在前面 打印方法执行耗时是监控,获取程序运行的JVM信息是监控,链路追踪也是监控。 本文看下如何实现一个通用的监控解决方案。 1:程序 定义premain: package com.dahuyou.multi.monitor;import com.dahuyou.multi.monitor.…

安卓逆向入门(3)------Frida基础

安装frida pip install frida pip install frida-tools //验证安装成功 frida --versionfrida连接手机 1、Android(已ROOT) frida-server 参考:https://www.jianshu.com/p/c349471bdef7 2、Android(非ROOT) pip3 in…

智能门锁的工作原理

智能门锁的工作原理是一个复杂而精密的过程,它结合了物联网、密码学、身份认证和通信技术等多个领域的先进技术。以下是智能门锁工作原理的详细解析: 一、身份认证 智能门锁通过身份认证机制来确保只有授权的用户才能开启门锁。常见的身份认证方式包括…

数据库内核研发学习之路(五)创建postgres系统表

写在前面 在使用postgres的时候,有很多表是我们一开始安装好数据库就存在的,这些表称为系统表,他们记载一些数据库信息,比如我们做运维工作常用的pg_stat_activity;我们在数据库中查询这张表可以发现他存储了一些数据库连接信息。…

多租户架构的艺术:在SQL Server中实现数据库的多租户

多租户架构的艺术:在SQL Server中实现数据库的多租户 在云计算和SaaS(软件即服务)时代,多租户架构(Multi-Tenancy)成为了数据库设计中的一个关键概念。它允许多个租户(客户)共享相同…

初等数论精解【2】

文章目录 素数基础素数理论互素定义性质应用示例最大公约数方法一:欧几里得算法方法二:列举法(适用于较小的数)欧几里得算法编程实现扩展欧几里得算法概述算法背景算法原理算法步骤应用场景示例代码 结论素数分布素数概述一、定义…

GO:Socket编程

目录 一、TCP/IP协议族和四层模型概述 1.1 互联网协议族(TCP/IP) 1.2 TCP/IP四层模型 1. 网络访问层(Network Access Layer) 2. 网络层(Internet Layer) 3. 传输层(Transport Layer&#…

WPF+Mvvm 项目入门完整教程(一)

WPF+Mvvm 入门完整教程一 创建项目MvvmLight框架安装完善整个项目的目录结构创建自定义的字体资源下载更新和使用字体资源创建项目 打开VS2022,点击创建新项目,选择**WPF应用(.NET Framework)** 创建一个名称为 CommonProject_DeskTop 的项目,如下图所示:MvvmLight框架安装…

机器学习-19-基于交互式web应用框架streamlit和gradio转化数据和机器学习模型

参考Streamlit:简单快速的Python Web应用开发工具 参考Python(Web时代)—— 超简单:一行代码就能搭建网站 参考对比Streamlit和Gradio:选择最适合你的Python交互式应用框架 参考Gradio:构建交互式界面的简单而强大的Python库 参考【吴恩达 X HuggingFace】使用Gradio快速…

【JavaScript 算法】双指针法:高效处理数组问题

🔥 个人主页:空白诗 文章目录 一、算法原理二、算法实现示例问题1:两数之和 II - 输入有序数组示例问题2:反转字符串中的元音字母注释说明: 三、应用场景四、总结 双指针法(Two Pointer Technique&#xff…

sqlalchemy_dm

1、参考文档: https://blog.csdn.net/njcwwddcz/article/details/126554118 https://eco.dameng.com/document/dm/zh-cn/pm/dmpython-dialect-package.html 2、生成工具 sqlalchemy2.0.0.zip 3、安装步骤 conda create --name kes --clone kes1 rz unzip sql…

高等数学用到的初等数学

指数 同指不同底乘法 (ab)xaxbx

如何做到高级Kotlin强化实战?(三)

高级Kotlin强化实战(二) 2.13 constructor 构造器2.14 Get Set 构造器2.15 操作符2.16 换行 2.13 constructor 构造器 //Java public class Utils { private Utils() {} public static int getScore(int value) { return 2 * value;} }//Kotlin class U…

深入理解Java并发线程阻塞唤醒类LockSupport

LockSupprot 用来阻塞和唤醒线程,底层实现依赖于Unsafe类 该类包含一组用于阻塞和唤醒线程的静态方法,这些方法主要是围绕 park 和 unpark 展开 public class LockSupportDemo1 {public static void main(String[] args) {Thread mainThread Thread.cu…