【NLP练习】Transformer实战-单词预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

任务:自定义输入一段英文文本进行预测

一、定义模型

from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn,Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math, os, torchclass TransformerModel(nn.Module):def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):super().__init__()self.pos_encoder = PositionalEncoding(d_model, dropout)#定义编码器层encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)#定义编码器,pytorch将Transformer编码器进行了打包,这里直接调用即可self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)self.embedding = nn.Embedding(ntoken,d_model)self.d_model = d_modelself.linear = nn.Linear(d_model, ntoken)self.init_weights()#初始化权重def init_weights(self) -> None:initrange = 0.1self.embedding.weight.data.uniform_(-initrange, initrange)self.linear.bias.data.zeros_()self.linear.weight.data.uniform_(-initrange, initrange)def forward(self, src:Tensor, src_mask: Tensor = None) -> Tensor:"""Arguments:src:      Tensor, 形状为[seq_len, batch_size]src_mask: Tensor, 形状为[seq_len, seq_len]Returns:输出的Tensor,形状为[seq_len, batch_size, ntoken]"""src = self.embedding(src) * math.sqrt(self.d_model)src = self.pos_encoder(src)output = self.transformer_encoder(src, src_mask)output = self.linear(output)return output
class PositionalEncoding(nn.Module):def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):super().__init__()self.dropout = nn.Dropout(p = dropout)#生成位置编码的位置张量position = torch.arange(max_len).unsqueeze(1)#计算位置编码的除数项div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))#创建位置编码张量pe = torch.zeros(max_len, 1, d_model)#使用正弦函数计算位置编码中的基数维度部分pe[:, 0, 1::2] = torch.sin(position * div_term)#使用余弦函数计算位置编码中的偶数维度部分pe[:, 0, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe)def forward(self, x: Tensor) -> Tensor:"""Arguments:x:      Tensor, 形状为[seq_len, batch_size, embedding_dim]"""#将位置编码添加到输入张量x = x + self.pe[:x.size(0)]#应用dropoutreturn self.dropout(x)

二、加载数据集

本实验使用torchtext生成Wikitext-2数据集。在此之前,你需要安装下面的包:

  • pip install portalocker
  • pip install torchdata
import torchtext
from torchtext.datasets.wikitext2 import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator#从torchtext库中导入WikiTetx2数据集
train_iter = WikiText2(split = 'train')#获取基本的英语分词器
tokenizer = get_tokenizer('basic_english')
#通过迭代器构建词汇表
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
#将默认索引设置为'<unk>'
vocab.set_default_index(vocab['<unk>'])def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:"""将原始文本转换为扁平的张量"""data = [torch.tensor(vocab(tokenizer(item)),dtype = torch.long) for item in raw_text_iter]return torch.cat(tuple(filer(lambda t: t.numel() > 0, data)))#由于构建词汇表时"train_iter"被使用了,所以需要重新创建
train_iter, val_iter, test_iter = WikiText2()#队训练、验证和测试数据进行处理
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)#检查是否有可用的CUDA设备,将设备设置为GPU或者CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def batchify(data: Tensor, bsz: int) -> Tensor:"""将数据划分为bsz个单独的序列,去除不能完全容纳的额外元素。参数:data: Tensor,形状为``[N]``bsz:int,批大小返回:形状为[N // bsz, bsz]的张量"""seq_len = data.size(0) // bszdata = data[:seq_len * bsz]data = data.view(bsz, seq_len).t().contiguous()return data.to(device)#设置批大小和评估批大小
batch_size = 20
eval_batch_size = 10
#将训练、验证和测试数据进行批处理
train_data = batchify(train_data, batch_size)   #形状为[seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
bptt = 35#获取批次数据
def get_batch(source:Tensor, i: int) -> Tuple[Tensor, Tensor]:"""参数:source: Tensor,形状为``[full_seq_len, batch_size]``i : int, 当前批次索引返回:tuple(data, target),-data形状为[seq_len, batch_size],-target形状为[seq_len * batch_size]"""#计算当前批次的序列长度,最大为bptt,确保不超过source的长度seq_len = min(bptt, len(source) - 1 - i)#获取data,从i开始,长度为seq_lendata = source[i:i+seq_len]#获取target,从i+1开始,长度为seq_len,并将其形状转换为一维张量target = source[i+1:i+1+seq_len].reshape(-1)return data, target

三、初始化实例

ntokend = len(vocab)
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.2
#创建transformer模型
model = TransformerModel(ntokend,emsize,nhead,d_hid,nlayers,dropout).to(device)

四、训练模型

结合使用CrossEntropyLoss与SGD(随机梯度下降优化器)。训练期间,使用torch.nn.utils.clip_grad_norm_来防止梯度爆炸

import time
criterion = nn.CrossEntropyLoss() #定义交叉熵损失函数
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(),lr = lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gama = 0.95)def train(model: nn.Module) -> None:model.train() #开启训练模式total_loss = 0.log_interval = 200 #start_time = time.time()num_batches = len(train_data) // bpttfor batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):data, targets = get_batch(train_data, i)output = model(data)output_flat = output.view(-1, ntokens)loss = criterion(output_flat, targets) #计算损失optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()total_loss += loss.item()if batch % log_interval == 0 and batch > 0:lr = scheduler.get_last_lr()[0]ms_per_batch = (time.time() - start_time) * 1000 / log_intervalcur_loss = total_loss / log_intervalppl = math.exp(cur_loss)print(f'| epoch{epoch:3d} | {batch:5d} / {num_batches:5d} batches |'f'lr{lr:02.2f} | ms/batch {ms_per_batch:5.2f} |'f'loss {cur_loss:5.2f}|ppl{ppl:8.2f}')total_loss = 0start_time = time.time()def evaluate(model:nn.Module, eval_data:Tensor) -> float:model.eval()total_loss = 0.with torch.no_grad():for i in range(0,eval_data.size(0) - 1, bptt):data, targets = get_batch(eval_data,i)seq_len = data.size(0)output = model(data)output_flat = output.view(-1,ntokens)total_loss += seq_len * criterion(output_flat, targets).item()return total_loss / (len(eval_data) - 1)
best_val_loss = float('inf')
epochs = 1with TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, "best_model_params.pt")for epoch in range(1, epochs + 1):epoch_start_time = time.time()train(model)val_loss = evaluate(model, val_data)val_ppl = math_exp(val_loss)elapsed = time.time() - epoch_start_time#打印当前epoch的信息,包括耗时、验证损失和困惑度print('-' * 89)print(f'|end of epoch {epoch:3d} | time:{elapsed: 5.2f}s |'f'valid loss {val_loss:5.2f} | valid ppl {val_ppl: 8.2f}')print('-' * 89)if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), best_model_params_path)scheduler.step()    #更新学习率model.load_state_dict(torch.load(best_model_params_path))

代码输出:
在这里插入图片描述

五、总结

加载数据集时,注意包的版本关联关系。另外,注意结合使用优化器提升优化性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/31689.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI时代:硬件狂欢,软件落寞 华为开发者大会2024

内容提要 分析师表示&#xff0c;目前AI行业大多数的支出都流向用于训练大模型的硬件或云基础设备。相较之下&#xff0c;软件应用商们在AI时代显得停滞不前。尽管软件应用商们十分热衷于构建AI工具&#xff0c;然而其收入状况却并不乐观。 文章正文 AI浪潮之下&#xff0c;英…

AIGC时代,重塑人的核心竞争力?

随着人工智能技术的飞速发展&#xff0c;AIGC&#xff08;人工智能生成内容&#xff09;的时代已经悄然而至。在这个时代&#xff0c;AI不再仅仅是一个技术概念&#xff0c;而是深入到我们生活的方方面面&#xff0c;从创作到生产&#xff0c;从娱乐到工作&#xff0c;AI都在以…

RabbitMQ 相关概念

引言 什么是消息中间件 消息是指在应用间传送的数据&#xff0c;包含文本字符串、JSON等。消息队列中间件&#xff08;MQ&#xff09;指利用高效可靠的消息传递机制进行平台无关的数据交流&#xff0c;并基于数据通信来进行分布式系统的集成。通过提供消息传递和消息排队模型…

剑指offer 算法题(搜索二维矩阵)

剑指offer 第二题 去力扣里测试算法 思路一&#xff1a; 直接暴力遍历二维数组。 class Solution { public:bool searchMatrix(vector<vector<int>>& matrix, int target) {for (unsigned int i{ 0 }; i < matrix.size(); i){for (unsigned int j{ 0 };…

Shell脚本:条件语句(if、case)

目录 硬编码 硬编码的缺点 条件判断 $? 命令行语句 判断指定目录是否存在 判断指定文件是否存在 判断指定对象是否存在 表达式形式语句 判断对象是否存在 判断对象是否有权限 与、或、非 运算 与运算 或运算 非运算 比较大小 判断磁盘利用率实验步骤 字符串…

Java基础之练习(2)

需求: 键盘录入一个字符串,使用程序实现在控制台遍历该字符串 package String;import java.util.Scanner;public class StringDemo5 {public static void main(String[] args) {//录入一个字符串Scanner sc new Scanner(System.in);System.out.println("请输入一个字符串…

1. 基础设计流程(以时钟分频器的设计为例)

1. 准备工作 1. 写有vcs编译命令的run_vcs.csh的shell脚本 2. 装有timescale&#xff0c;设计文件以及仿真文件的flish.f&#xff08;filelist文件&#xff0c;用于VCS直接读取&#xff09; vcs -R -full64 -fsdb -f flist.f -l test.log 2. 写代码&#xff08;重点了解代码…

如何将办公文档压缩成rar格式文件?

压缩包格式是我们生活工作中常用到的文件格式&#xff0c;那么如何得到一个rar格式的压缩文件&#xff1f;或者说如何将文件压缩成rar格式而不是zip格式呢&#xff1f;今天我们来了解一下如何压缩为rar格式文件。 首先&#xff0c;下载并安装WinRAR&#xff0c;然后用鼠标选择需…

【Python】成功解决TypeError: missing 1 required positional argument

【Python】成功解决TypeError: missing 1 required positional argument 下滑即可查看博客内容 &#x1f308; 欢迎莅临我的个人主页 &#x1f448;这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地&#xff01;&#x1f387; &#x1f393; 博主简介&#xff1…

React的服务器端渲染(SSR)和客户端渲染(CSR)有什么区别?

React的服务器端渲染&#xff08;SSR&#xff09;和客户端渲染&#xff08;CSR&#xff09;是两种不同的页面渲染方式&#xff0c;它们各自有不同的特点和适用场景&#xff1a; 服务器端渲染&#xff08;SSR&#xff09; 页面渲染: 页面在服务器上生成&#xff0c;然后将完整的…

复盘最近的面试

这个礼拜一直在面试&#xff0c;想着看看能否拿到不错的offer前去实习&#xff0c;从周一到周四&#xff0c;面了将近10家&#xff0c;特整理此份面经&#xff0c;希望对秋招的各位有所帮助 A公司 一面 面试官人很好&#xff0c;我回答的时候不会他会笑笑然后提醒我 自我介绍~…

数据通信与网络(三)

物理层概述&#xff1a; 物理层是网络体系结构中的最低层 它既不是指连接计算机的具体物理设备&#xff0c;也不是指负责信号传输的具体物理介质&#xff0c; 而是指在连接开放系统的物理媒体上为上一层(指数据链路层)提供传送比特流的一个物理连接。 物理层的主要功能——为…

项目中eventbus和rabbitmq配置后,不起作用

如下&#xff1a;配置了baseService层和SupplyDemand层得RabbitMQ和EventBus 但是在执行订阅事件时&#xff0c;发送得消息在base项目中没有执行&#xff0c;后来发现是虚拟机使用得不是一个&#xff0c;即上图中得EventBus下得VirtualHost&#xff0c;修改成一直就可以了

肆拾玖坊三级众筹模式玩法揭秘,白酒体验馆运作模式

发展至今&#xff0c;肆拾玖坊已积累了数百万忠实用户&#xff0c;拥有100多家分销商、5000多个新零售终端&#xff0c;覆盖全国34个省级行政区域、200余地市、1500个县区。成为中国创业界和酒行业的“现象级”企业。 今天&#xff0c;我们就来深入解析肆拾玖坊的营销模式&…

Linux入门攻坚——26、Web Service基础知识与httpd配置-2

http协议 URL&#xff1a;Uniform Resource Locator&#xff0c;统一资源定位符 URL方案&#xff1a;scheme&#xff0c;如http://&#xff0c;https:// 服务器地址&#xff1a;IP&#xff1a;port 资源路径&#xff1a; 示例&#xff1a;http://www.test.com:80/bbs/…

ios18计算器大更新使用指南,一招掌握新计算器使用技巧!

苹果推出iOS 18系统中&#xff0c;变化较大的之一就是以多年没有更新的计算器应用程序&#xff0c;新增了多个使用的功能&#xff0c;经过小编几天的使用&#xff0c;总结了几个iOS 18计算器的使用技巧和更新点分享给大家。 一、界面布局变化 与iOS 17相比&#xff0c;iOS18的…

Java学习笔记(二)变量原理、常用编码、类型转换

Hi i,m JinXiang ⭐ 前言 ⭐ 本篇文章主要介绍Java变量原理、常用编码、类型转换详细使用以及部分理论知识 🍉欢迎点赞 👍 收藏 ⭐留言评论 📝私信必回哟😁 🍉博主收将持续更新学习记录获,友友们有任何问题可以在评论区留言 1、变量原理 1.1、变量的介绍 变量是程…

生成式AI时代,数据存储管理与成本如何不失控?

无数据&#xff0c;不AI。 由生成式AI掀起的这一次人工智能浪潮&#xff0c;对企业的产品、服务乃至商业模式都有着颠覆性的影响。因此&#xff0c;在多云、大数据、生成式AI等多元技术的驱动下&#xff0c;数据要素变得愈发重要的同时&#xff0c;企业对于数据存储的需求也在…

【Android14 ShellTransitions】(六)SyncGroup完成

这一节的内容在WMCore中&#xff0c;回想我们的场景&#xff0c;是在Launcher启动某一个App&#xff0c;那么参与动画的就是该App对应Task&#xff08;OPEN&#xff09;&#xff0c;以及Launcher App对应的Task&#xff08;TO_BACK&#xff09;。在确定了动画的参与者后&#x…

JVS开源底座与核心引擎的全方位探索,助力IT智能、高效、便捷的进化

引言 JVS产品的诞生背景 JVS是软开企服构建的一站式数字化的解决方案&#xff0c;产生的背景主要来源于如下几个方面&#xff1a; 企业数字化需求的增长&#xff1a;企业对IT建设的依赖程度越来越高&#xff0c;数字化、指标化的经营已经是很多企业的生存的基础和前提&#…