Transformer实战 单词预测

  •    🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

一、定义模型

from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import Dataset
import math, os, torchclass TransformerModel(nn.Module):def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):super().__init__()self.pos_encoder = PositionalEncoding(d_model, dropout)# 编码器层堆栈encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)# 编码器堆栈. pytorch已经实现了Transformer编码器层的堆栈,这里直接调用即可self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)self.embedding = nn.Embedding(ntoken, d_model)self.d_model = d_modelself.linear = nn.Linear(d_model, ntoken)self.init_weights()# 初始化权重def init_weights(self) -> None:initrange = 0.1self.embedding.weight.data.uniform_(-initrange, initrange)self.linear.bias.data.zero_()self.linear.weight.data.uniform_(-initrange, initrange)def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:"""Arguments:src: Tensor, 形状为 [seq_len, batch_size]src_mask: Tensor, 形状为 [seq_len, seq_len]Returns:最终的 Tensor, 形状为 [seq_len, batch_size, ntoken]"""src = self.embedding(src) * math.sqrt(self.d_model)src = self.pos_encoder(src)output = self.transformer_encoder(src, src_mask)output = self.linear(output)return outputclass PositionalEncoding(nn.Module):def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):super().__init__()self.dropout = nn.Dropout(p=dropout)# 位置编码器的初始化部分position = torch.arange(max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))pe = torch.zeros(max_len, 1, d_model)pe[:, 0, 0::2] = torch.sin(position * div_term)pe[:, 0, 1::2] = torch.cos(position * div_term)# 注册为持久状态变量,不参与参数优化self.register_buffer('pe', pe)def forward(self, x: Tensor) -> Tensor:"""Arguments:x: Tensor, 形状为 [seq_len, batch_size, embedding_dim]Returns:最终的 Tensor, 形状为 [seq_len, batch_size, embedding_dim]"""x = x + self.pe[:x.size(0)]return self.dropout(x)

wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

二、加载数据集

wikitext2_dir = "d:/wikitext-2-v1/wikitext-2"# Modify the data processing function to read from the local file
def data_process(file_path: str) -> Tensor:with open(file_path, 'r', encoding='utf-8') as file:data = [torch.tensor(vocab(tokenizer(line)), dtype=torch.long) for line in file]return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))# Load train, validation, and test data from local files
train_file = os.path.join(wikitext2_dir, "wiki.train.tokens")
val_file = os.path.join(wikitext2_dir, "wiki.valid.tokens")
test_file = os.path.join(wikitext2_dir, "wiki.test.tokens")train_data = data_process(train_file)
val_data = data_process(val_file)
test_data = data_process(test_file)# 使用数据处理函数处理数据集
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)# 设置设备优先使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 将数据集分批的函数
def batchify(data: Tensor, bsz: int) -> Tensor:# 计算批次大小nbatch = data.size(0) // bsz# 裁剪掉多余的部分使得能够完全分为批次data = data.narrow(0, 0, nbatch * bsz)# 重新整理数据维度为[批次, 批次大小]data = data.view(bsz, -1).t().contiguous()# 将数据移动到指定设备return data.to(device)# 批次大小
batch_size = 20
eval_batch_size = 10# 应用batchify函数分批处理训练集、验证集和测试集
train_data = batchify(train_data, batch_size)  # 结果为 [序列长度, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
bptt = 35def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:"""生成批次数据参数:source: Tensor, 形状为 `[full_seq_len, batch_size]`i: int, 批次索引值返回:tuple (data, target).- data包含输入 [seq_len, batch_size],- target包含标签 [seq_len * batch_size]"""seq_len = min(bptt, len(source) - 1 - i)data = source[i:i+seq_len]target = source[i+1:i+1+seq_len].reshape(-1)return data, target

三、实例初始化

ntokens = len(vocab) # 词汇表的大小
emsize = 200         # 嵌入维度
nhid = 200           # nn.TransformerEncoder 中间层的维度
nlayers = 2          # nn.TransformerEncoder层的数量
nhead = 2            # nn.MultiheadAttention 头的数量
dropout = 0.2        # 丢弃率# 初始化 Transformer 模型,并将其发送到指定设备
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

四、训练模型

import time# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
lr = 5.0 # 学习率
optimizer = torch.optim.SGD(model.parameters(), lr=lr) # 使用SGD优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.95) # 学习率衰减# 训练函数
def train(model: nn.Module) -> None:model.train() # 开启训练模式total_loss = 0.log_interval = 200 # 每隔200个batch打印一次日志start_time = time.time()num_batches = len(train_data) // bpttfor batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):data, targets = get_batch(train_data, i)output = model(data)output_flat = output.view(-1, ntokens)loss = criterion(output_flat, targets)optimizer.zero_grad() # 梯度清零loss.backward() # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # 梯度裁剪optimizer.step() # 更新参数total_loss += loss.item()if batch % log_interval == 0 and batch > 0:lr = scheduler.get_last_lr()[0] # 获取当前学习率ms_per_batch = (time.time() - start_time) * 1000 / log_intervalcur_loss = total_loss / log_intervalppl = math.exp(cur_loss)print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | 'f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | 'f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')total_loss = 0start_time = time.time()# 评估函数
def evaluate(model: nn.Module, eval_data: Tensor) -> float:model.eval() # 开启评估模式total_loss = 0.with torch.no_grad():for i in range(0, eval_data.size(0) - 1, bptt):data, targets = get_batch(eval_data, i)output = model(data)output_flat = output.view(-1, ntokens)total_loss += criterion(output_flat, targets).item()return total_loss / (len(eval_data) - 1)

训练函数train通过多次迭代数据,并使用梯度下降来更新模型的权重。它还包括了每个日志间隔打印损失和困惑度(perplexity,常用于语言模型的评估指标)。评估函数evaluate用于计算模型在验证集或测试集上的性能,但不会进行参数更新。代码还展示了如何使用学习率调度器来随着训练进行逐步减小学习率。

best_val_loss = float('inf') # 初始设置最佳验证集损失为无穷大
epochs = 1 # 设置训练的总轮数为1
best_model_params = None # 用于存储最佳模型参数# 使用临时目录存储模型参数
with TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, "best_model_params.pt")# 循环遍历每个epochfor epoch in range(1, epochs + 1):epoch_start_time = time.time()train(model)val_loss = evaluate(model, val_data) # 在验证集上评估当前模型print('-' * 89)elapsed = time.time() - epoch_start_timeprint(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | valid loss {val_loss:5.2f} | 'f'valid ppl {math.exp(val_loss):8.2f}')print('-' * 89)# 检查当前epoch的验证集损失是否为最佳if val_loss < best_val_loss:best_val_loss = val_loss # 更新最佳验证集损失best_model_params = model.state_dict() # 保存最佳模型参数# 保存有最佳验证集损失的模型参数torch.save(best_model_params, best_model_params_path)scheduler.step() # 更新学习率# 加载最佳模型参数,以便在测试集上进行评估或进一步训练
model.load_state_dict(torch.load(best_model_params_path))

五、评估模型

test_loss = evaluate(model, test_data)
test_ppl = math.exp(test_loss)print('-' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | 'f'test ppl {test_ppl:8.2f}')
print('-' * 89)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/4102.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ruoyi-nbcio-plus基于vue3的flowable为了适配文件上传改造VForm3的代码记录

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码&#xff1a; https://gitee.com/nbacheng/ruoyi-nbcio 演示地址&#xff1a;RuoYi-Nbcio后台管理系统 http://218.75.87.38:9666/ 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码&#xff1a; h…

java如何使用webService方式调用对接第三方平台

实际使用记录&#xff0c;做个记录&#xff1a; 1、需要对方提供wsdl文件,该文件中有接口的Ip地址&#xff0c;方法名、参数等详细信息&#xff0c; wsdl文档中targetNamespace为命名空间 <xsd:element name"searchBGDMIInfo">标签中name是方法名&#xff1…

数据结构-树和森林之间的转化

从树的二叉链表的定义可知&#xff0c;任何一棵和树对应的二叉树&#xff0c;其根节点的右子树必为空。这里我们举三个树&#xff0c;将这个由三个树组成的森林组成二叉树是这个样子的。 下面我们说明一下详细过程&#xff0c;首先将每个树转化为二叉的状态&#xff0c;如图所示…

NAT网络地址转换实验(华为)

思科设备参考&#xff1a;NAT网络地址转换实验&#xff08;思科&#xff09; 一&#xff0c;技术简介 NAT&#xff08;Network Address Translation&#xff09;&#xff0c;即网络地址转换技术&#xff0c;是一种在现代计算机网络中广泛应用的技术&#xff0c;主要用于有效管…

汇编语言(详解)

汇编语言安装指南 第一步&#xff1a;在github上下载汇编语言的安装包 网址&#xff1a;GitHub - HaiPenglai/bilibili_assembly: B站-汇编语言-pdf、代码、环境等资料B站-汇编语言-pdf、代码、环境等资料. Contribute to HaiPenglai/bilibili_assembly development by creat…

李廉洋:4.27黄金原油下周一行情分析及走势策略。

金价将出现六周来的首次单周下跌&#xff0c;因投资者在金价上涨数月后获利了结。自2月中旬的低点以来&#xff0c;金价已经上涨了约17%&#xff0c;尽管对美联储放松政策的预期正在减弱&#xff0c;但金价仍屡创新高。周五公布的最新通胀数据强化了高利率将暂时维持的观点。“…

MATLAB的几种边缘检测算子(Sobel、Prewitt、Laplacian)

MATLAB的几种边缘检测算子(Sobel、Prewitt、Laplacian) clc;close all;clear all;warning off;%清除变量 rand(seed, 100); randn(seed, 100); format long g;% 读取图像 image imread(lena.png); % 转换为灰度图像 gray_image rgb2gray(image); % 转换为double类型以进行计算…

Git泄露和hg泄露原理理解和题目实操

一.Git泄露 1.简介 Git是一个开源的分布式版本控制系统&#xff0c;它可以实现有效控制应用版本&#xff0c;但是在一旦在代码发布的时候&#xff0c;存在不规范的操作及配置&#xff0c;就很可能将源代码泄露出去。那么&#xff0c;一旦攻击者发现这个问题之后&#xff0c;就…

论文速览 | IEEE Symposium on Security and Privacy (SP), 2023 | FMCW雷达反射阵列欺骗攻击

注1:本文系"计算成像最新论文速览"系列之一,致力于简洁清晰地介绍、解读非视距成像领域最新的顶会/顶刊论文(包括但不限于 Nature/Science及其子刊; CVPR, ICCV, ECCV, SIGGRAPH, TPAMI; Light‑Science & Applications, Optica 等)。 本次介绍的论文是:<I…

MariaDB 修改用户密码的 SQL

有时候我们希望能够修改数据库中访问用户的密码。 但是我们只能 SQL 登录服务器后才能进行修改。 修改的 SQL 为&#xff1a; ALTER USER root% IDENTIFIED WITH mysql_native_password BY 123;针对实际上数据的配置情况&#xff0c;上面的 SQL 是需要进行一些调整的。 MySQ…

鸿蒙云函数调试坑点

如果你要本地调试请使用 const {payload, action} event.body/** 本地调试不需要序列化远程需要序列化 */ // const {payload, action} JSON.parse(event.body) const {payload, action} event.body 注意: 只要修改云函数&#xff0c;必须上传云函数 如果使用 const {pay…

25计算机考研院校数据分析 | 南京大学

南京大学&#xff08;Nanjing University&#xff09;&#xff0c;简称“南大”&#xff0c;是中华人民共和国教育部直属、中央直管副部级建制的全国重点大学&#xff0c;国家首批“双一流”、“211工程”、“985工程”重点建设高校&#xff0c;入选首批“珠峰计划”、“111计划…

WordPress AI Engine 插件 文件上传致RCE漏洞复现(CVE-2023-51409)

0x01 产品简介 AI Engine插件是WordPress中的AI一体化解决方案,包括创建聊天机器人、生成内容和图像、推荐标题和帖子摘录、支持多种人工智能引擎等功能,可以节省用户时间。 0x02 漏洞概述 WordPress AI Engine 插件upload接口存在文件上传漏洞,未经身份验证的远程攻击者…

(四)Servlet教程——Maven的安装与配置

1.在C盘根目录下新建一个Java文件夹,该文件夹用来放置以下步骤下载的Maven&#xff1b; 2. 下载Maven的来源有清华大学开源软件镜像站和Apache Maven的官网&#xff0c;由于清华大学开源软件镜像站上只能下载3.8.8版本以上的Maven&#xff0c;我们选择在Apache Maven的官网上下…

codeforce#933 题解

E. Rudolf and k Bridges 题意不讲了&#xff0c;不如去题干看图。 传统dp&#xff0c;每个点有两个选择&#xff0c;那么建桥要么不建。需要注意的是在状态转移的时候&#xff0c;桥是有长度的&#xff0c;如果不建需要前d格中建桥花费最少的位置作为状态转移的初态。 #incl…

深度学习论文: MobileNetV4 - Universal Models for the Mobile Ecosystem及其PyTorch实现

深度学习论文: MobileNetV4 - Universal Models for the Mobile Ecosystem及其PyTorch实现 MobileNetV4 - Universal Models for the Mobile Ecosystem PDF: https://arxiv.org/pdf/2404.10518.pdf PyTorch代码: https://github.com/shanglianlm0525/CvPytorch PyTorch代码: ht…

swagger xss漏洞复现

swagger xss漏洞复现 文章目录 swagger xss漏洞复现漏洞介绍影响版本实现原理漏洞复现修复建议: 漏洞介绍 Swagger UI 有一个有趣的功能&#xff0c;允许您提供 API 规范的 URL - 一个 yaml 或 json 文件&#xff0c;将被获取并显示给用户 根本原因非常简单 - 一个过时的库Dom…

高级控件5-RecyclerView

与ViewPager类似的一个滑动的高级控件是RecyclerView&#xff0c;使用更加灵活。 第1步&#xff1a;添加依赖 打开mvn官网&#xff0c;检索recyclerview&#xff0c;选择使用人数较多的版本&#xff0c;复制依赖&#xff0c;放入项目中即可 快捷方法&#xff08;复制下面的代…

科普:PD协议、QC协议、三星AFC、华为SCP是什么,怎么获取这些协议及协议通讯原理

PD协议是什么 PD协议是由 USB-IF 组织制定的一种快速充电规范&#xff0c;它一般使用Type-C接口&#xff0c;所以常见的Type-C接口充电器一般都是支持PD协议。 USB Power Delivery(USB PD)是目前主流的快充协议之一&#xff0c;USB PD 通过Type-C电缆和连接器增加电力输送&…

【Unity动画系统】动画基本原理与Avater骨骼复用

动画基本原理 动画片段文件是一个描述物体变化状态的文本文件 在Unity中创建的资源文件大多都是YAML语言编写的文本文件 Curves表示一种变化状态&#xff0c;为空的话则没有记录任何内容 位置变化后的旋转变化状态&#xff1a; 动画文件里的Path名字要相同才能播放相同的动画 …