NLP Seq2Seq模型

🍨 本文为[🔗365天深度学习训练营学习记录博客🍦 参考文章:365天深度学习训练营🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

Seq2Seq模型是一种深度学习模型,用于处理序列到序列的任务,它由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。

  1. 编码器(Encoder): 编码器负责将输入序列(例如源语言句子)编码成一个固定长度的向量,通常称为上下文向量或编码器的隐藏状态。编码器可以是循环神经网络(RNN)、长短期记忆网络(LSTM)或者变种如门控循环单元(GRU)等。编码器的目标是捕捉输入序列中的语义信息,并将其编码成一个固定维度的向量表示。

  2. 解码器(Decoder): 解码器接收编码器生成的上下文向量,并根据它来生成输出序列(例如目标语言句子)。解码器也可以是RNN、LSTM、GRU等。在训练阶段,解码器一次生成一个词或一个标记,并且其隐藏状态从一个时间步传递到下一个时间步。解码器的目标是根据上下文向量生成与输入序列对应的输出序列。

在训练阶段,Seq2Seq模型的目标是最大化目标序列的条件概率给定输入序列。为了实现这一点,通常使用了一种称为教师强制(Teacher Forcing)的技术,即将目标序列中的真实标记作为解码器的输入。但是,在推理阶段(即模型用于生成新的序列时),解码器则根据先前生成的标记生成下一个标记,直到生成一个特殊的终止标记或达到最大长度为止。

下面演示了如何使用PyTorch实现一个简单的Seq2Seq模型,用于将一个序列翻译成另一个序列。这里我们将使用一个虚构的数据集来进行简单的法语到英语翻译。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset# 定义数据集
class SimpleDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 定义Encoder
class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hidden_dim):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hidden_dim)def forward(self, src):embedded = self.embedding(src)outputs, hidden = self.rnn(embedded)return outputs, hidden# 定义Decoder
class Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hidden_dim):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hidden_dim)self.fc_out = nn.Linear(hidden_dim, output_dim)def forward(self, input, hidden):input = input.unsqueeze(0)embedded = self.embedding(input)output, hidden = self.rnn(embedded, hidden)prediction = self.fc_out(output.squeeze(0))return prediction, hidden# 定义Seq2Seq模型
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):batch_size = trg.shape[1]trg_len = trg.shape[0]trg_vocab_size = self.decoder.fc_out.out_featuresoutputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)encoder_outputs, hidden = self.encoder(src)input = trg[0,:]for t in range(1, trg_len):output, hidden = self.decoder(input, hidden)outputs[t] = outputteacher_force = np.random.rand() < teacher_forcing_ratiotop1 = output.argmax(1) input = trg[t] if teacher_force else top1return outputs# 设置参数
INPUT_DIM = 10
OUTPUT_DIM = 10
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
HID_DIM = 64
N_LAYERS = 1
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5# 实例化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(enc, dec, device).to(device)# 打印模型结构
print(model)# 定义训练函数
def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for i, batch in enumerate(iterator):src, trg = batchsrc = src.to(device)trg = trg.to(device)optimizer.zero_grad()output = model(src, trg)output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(iterator)# 定义测试函数
def evaluate(model, iterator, criterion):model.eval()epoch_loss = 0with torch.no_grad():for i, batch in enumerate(iterator):src, trg = batchsrc = src.to(device)trg = trg.to(device)output = model(src, trg, 0) # 关闭teacher forcingoutput_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(iterator)# 示例数据
train_data = [(torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])),(torch.tensor([4, 5, 6]), torch.tensor([6, 5, 4])),(torch.tensor([7, 8, 9]), torch.tensor([9, 8, 7]))]# 超参数
BATCH_SIZE = 3
N_EPOCHS = 10
LEARNING_RATE = 0.001
CLIP = 1# 数据集与迭代器
train_dataset = SimpleDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)# 定义损失函数与优化器
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()# 训练模型
for epoch in range(N_EPOCHS):train_loss = train(model, train_loader, optimizer, criterion, CLIP)print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/711927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解Linux线程(LWP):概念、结构与实现机制(2)

&#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;会いたい—Naomile 1:12━━━━━━️&#x1f49f;──────── 4:59 &#x1f504; ◀️ ⏸ ▶️ ☰ &a…

Vue3+vite打包后页面空白问题

vite.config.js vite.config.js 增加 base: ./ import { fileURLToPath, URL } from node:url import { defineConfig } from vite import vue from vitejs/plugin-vue// https://vitejs.dev/config/ export default defineConfig({base: ./,resolve: {alias: {: fileURLToPath…

解析短视频美颜SDK:美颜美型技术的深度剖析

美颜并非简单的滤镜叠加&#xff0c;而是依托着先进的图像处理和人工智能技术&#xff0c;才能够达到如此出色的效果。本文将深入探讨短视频美颜SDK背后的技术原理和实现方法&#xff0c;从而揭示其美颜美型技术的深度剖析。 一、美颜SDK的基本原理 美颜SDK的基本原理是通过对…

java 企业培训管理系统Myeclipse开发mysql数据库web结构jsp编程计算机网页项目

一、源码特点 java 企业培训管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql5.0&…

UCWSC

feature fusion neural network based on a decomposition mechanism (FFDM) 辅助信息 作者未提供代码

学习大数据,所必需的java基础(6)

文章目录 集合Set集合介绍HashSet集合的介绍和使用LinkedHashSet的介绍以及使用哈希值哈希值的计算方式HashSet的存储去重的过程 Map集合Map的介绍HashMap的介绍以及使用HashMap的两种遍历方式方式1&#xff1a;获取key&#xff0c;然后再根据key获取value方式2&#xff1a;同时…

【Sql Server】Update中的From语句,以及常见更新操作方式

欢迎来到《小5讲堂》&#xff0c;大家好&#xff0c;我是全栈小5。 这是《Sql Server》系列文章&#xff0c;每篇文章将以博主理解的角度展开讲解&#xff0c; 特别是针对知识点的概念进行叙说&#xff0c;大部分文章将会对这些概念进行实际例子验证&#xff0c;以此达到加深对…

Docker技术概论(4):Docker CLI 基本用法解析

Docker技术概论&#xff08;4&#xff09; Docker CLI 基本用法解析 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:http…

Python实现PPT演示文稿中视频的添加、替换及提取

无论是在教室、会议室还是虚拟会议中&#xff0c;PowerPoint 演示文稿都已成为一种无处不在的工具&#xff0c;用于提供具有影响力的可视化内容。PowerPoint 提供了一系列增强演示的功能&#xff0c;在其中加入视频的功能可以大大提升整体体验。视频可以传达复杂的概念、演示产…

ArkTS中的路由跳转和HTTP数据请求

路由跳转 步骤1&#xff1a;找到箭头所指的文件&#xff0c;在其中添加已创建的页面 步骤2&#xff1a;导包 步骤3&#xff1a; HTTP数据请求 步骤1&#xff1a;导包 > import http from ohos.net.http; 步骤2&#xff1a;&#xff08;如果需要在页面加载前请求&#xf…

TcpServer服务器管理模块(模块十)

目录 类功能 类定义 类实现 编译测试 server.cc gdb测试断点 忽略SIGPIPE信号 类功能 类定义 // TcpServer服务器管理模块(即全部模块的整合) class TcpServer { private:uint64_t _next_id; // 这是一个自动增长的连接IDint _port;i…

Linux学习-C语言-运算符

目录 算术运算符&#xff1a; - * /:不能除0 %:不能对浮点数操作 &#xff1a;自增与运算符 i&#xff1a;先用再加 i:先加再用 --&#xff1a;自减运算符 常量&#xff0c;表达式不可以&#xff0c;--&#xff0c;变量可以 赋值运算符 三目运算符 逗号表达式 size…

alpine创建lnmp环境alpine安装nginx+php5.6+mysql

前言 制作lnmp环境&#xff0c;你可以在alpine基础镜像中安装相关的服务&#xff0c;也可以直接使用Dockerfile创建自己需要的环境镜像。 注意&#xff1a;提前确认自己的alpine版本&#xff0c;本次创建基于alpine3.6进行创建&#xff0c;官方在一些版本中删除了php5 1、拉取…

JS正则02——js正则表达式中常用的方法、常见修饰符的使用详解以及各种方法使用情况示例

JS正则02——js正则表达式中常用的方法、常见修饰符的使用详解以及各种方法使用情况示例 1. 前言1.1 简介1.2 js正则特殊字符即使用示例 2. 创建正则表达式的方式2.1 两种创建正则表达式的方式2.2 关于修饰符 3. 正则表达式中常用的方法3.1 test() 方法——正则表达式对象的方法…

Vue之监测数据的原理(对象)

大家有没有想过&#xff0c;为什么vue可以监测到数据发生改变&#xff1f;其实底层借助了Object.defineProperty&#xff0c;底层有一个Observer的构造函数 让我为大家简单的介绍一下吧&#xff01; 我用对象为大家演示一下 const vm new Vue({el: "#app",data: {ob…

文献速递:帕金森的疾病分享--多模态机器学习预测帕金森病

文献速递&#xff1a;帕金森的疾病分享–多模态机器学习预测帕金森病 Title 题目 Multi-modality machine learning predicting Parkinson’s disease 多模态机器学习预测帕金森病 01 文献速递介绍 对于渐进性神经退行性疾病&#xff0c;早期和准确的诊断是有效开发和使…

【精品】集合list去重

示例一&#xff1a;对于简单类型&#xff0c;比如String public static void main(String[] args) {List<String> list new ArrayList< >();list.add("aaa");list.add("bbb");list.add("bbb");list.add("ccc");list.add(…

网络工程师必备的网络端口大全(建议收藏)

端口是一种数字标识&#xff0c;用于在计算机网络中进行通信&#xff0c;你完全可以把端口简单的理解为是计算机和外界通讯交流的出口。但在网络技术中&#xff0c;端口一般有两种含义&#xff1a; &#xff08;1&#xff09;硬件设备中的端口 如交换机、路由器中用于链接其他…

“金三银四”招聘季,大厂争招鸿蒙人才

在金三银四的招聘季中&#xff0c;各大知名互联网企业纷纷加入了对鸿蒙人才的争夺战。近日&#xff0c;包括淘宝、京东、得物等在内的知名APP均宣布启动鸿蒙星河版原生应用开发计划。这一举措不仅彰显了鸿蒙生态系统的迅猛发展&#xff0c;还催生了人才市场的繁荣景象。据数据显…

遥感影像处理(ENVI+ChatGPT+python+ GEE)处理高光谱及多光谱遥感数据

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境&#xff0c;是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型&#xff0c;在理解和生成人类语言方面表现出了非凡的能力。本文重点介绍ChatGPT在遥感中的应用&#xff0c;人工智能…