从头开始构建和训练 Transformer(下)

导 读

上一篇推文从头开始构建和训练 Transformer(上)icon-default.png?t=N7T8https://blog.csdn.net/weixin_46287760/article/details/136048418介绍了构建和训练Transformer的过程和构建每个组件的代码示例。本文将使用数据对该架构进行代码演示,验证其模型性能。

本期『数据+代码』已上传百度网盘。

有需要的朋友关注公众号【小Z的科研日常】,回复关键词[Transformer]获取

01、加载数据集

对于此任务,我们将使用🤗Hugging Face 上提供的OpusBooks 数据集。该数据集由两个特征组成,idtranslation。该translation功能包含不同语言的句子对,例如西班牙语和葡萄牙语、英语和法语等。

我首先尝试将句子从英语翻译成葡萄牙语,但是这对句子只有 1.4k个示例,因此在该模型的当前配置中结果并不令人满意。然后,我尝试使用英语-法语对,因为它的示例数量较多(127k),但使用当前配置进行训练需要很长时间。

我们首先定义get_all_sentences函数来迭代数据集并根据定义的语言对提取句子。

# 迭代数据集,提取原句及其译文
def get_all_sentences(ds, lang):for pair in ds:yield pair['translation'][lang]

get_ds函数定义为加载和准备数据集以进行训练和验证。在此函数中,我们构建或加载分词器、拆分数据集并创建 DataLoader,以便模型可以成功地批量迭代数据集。这些函数的结果是源语言和目标语言的标记器以及 DataLoader 对象。

def get_ds(config):# 语言对将在我们稍后创建的 "配置 "字典中定义。ds_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split = 'train') # 为源语言和目标语言构建或加载标记符tokenizer_src = build_tokenizer(config, ds_raw, config['lang_src'])tokenizer_tgt = build_tokenizer(config, ds_raw, config['lang_tgt'])# 分割数据集进行训练和验证train_ds_size = int(0.9 * len(ds_raw)) # 90% for trainingval_ds_size = len(ds_raw) - train_ds_size # 10% for validationtrain_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size]) # Randomly splitting the dataset# 使用双语数据集(BilingualDataset)类处理数据,我们将在下面定义该类train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])# 对整个数据集进行迭代,并打印在源语言和目标语言句子中找到的最大长度max_len_src = 0max_len_tgt = 0for pair in ds_raw:src_ids = tokenizer_src.encode(pair['translation'][config['lang_src']]).idstgt_ids = tokenizer_src.encode(pair['translation'][config['lang_tgt']]).idsmax_len_src = max(max_len_src, len(src_ids))max_len_tgt = max(max_len_tgt, len(tgt_ids))print(f'Max length of source sentence: {max_len_src}')print(f'Max length of target sentence: {max_len_tgt}')# 为训练集和验证集创建数据加载器# 在训练和验证过程中,使用数据加载器分批迭代数据集train_dataloader = DataLoader(train_ds, batch_size = config['batch_size'], shuffle = True) # Batch size will be defined in the config dictionaryval_dataloader = DataLoader(val_ds, batch_size = 1, shuffle = True)return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt # Returning the DataLoader objects and tokenizers

我们定义casual_mask函数来为解码器的注意力机制创建掩码。此掩码可防止模型获得有关序列中未来元素的信息。

我们首先制作一个充满 1 的方形网格。我们用参数确定网格大小size。然后,我们将主对角线上方的所有数字更改为零。一侧的每个数字都变成零,而其余的仍然是1。然后该函数翻转所有这些值,将 1 变为 0,将 0 变为 1。这个过程对于预测序列中未来标记的模型至关重要。

02、验证循环

我们现在将为验证循环创建两个函数。验证循环对于评估模型从训练期间未见过的数据翻译句子的性能至关重要。

我们将定义两个函数。第一个函数 ,greedy_decode通过获取最可能的下一个标记为我们提供模型的输出。第二个函数run_validation负责运行验证过程,在该过程中我们解码模型的输出并将其与目标句子的参考文本进行比较。

class BilingualDataset(Dataset):def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:super().__init__()self.seq_len = seq_lenself.ds = dsself.tokenizer_src = tokenizer_srcself.tokenizer_tgt = tokenizer_tgtself.src_lang = src_langself.tgt_lang = tgt_langself.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)def __len__(self):return len(self.ds)def __getitem__(self, index: Any) -> Any:src_target_pair = self.ds[index]src_text = src_target_pair['translation'][self.src_lang]tgt_text = src_target_pair['translation'][self.tgt_lang]enc_input_tokens = self.tokenizer_src.encode(src_text).idsdec_input_tokens = self.tokenizer_tgt.encode(tgt_text).idsenc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # Subtracting the two '[EOS]' and '[SOS]' special tokensdec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1 # Subtracting the '[SOS]' special tokenif enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:raise ValueError('Sentence is too long')encoder_input = torch.cat([self.sos_token, # inserting the '[SOS]' tokentorch.tensor(enc_input_tokens, dtype = torch.int64), # Inserting the tokenized source textself.eos_token, # Inserting the '[EOS]' tokentorch.tensor([self.pad_token] * enc_num_padding_tokens, dtype = torch.int64) # Addind padding tokens])decoder_input = torch.cat([self.sos_token, # inserting the '[SOS]' token torch.tensor(dec_input_tokens, dtype = torch.int64), # Inserting the tokenized target texttorch.tensor([self.pad_token] * dec_num_padding_tokens, dtype = torch.int64) # Addind padding tokens])label = torch.cat([torch.tensor(dec_input_tokens, dtype = torch.int64), # Inserting the tokenized target textself.eos_token, # Inserting the '[EOS]' token torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype = torch.int64) # Adding padding tokens])assert encoder_input.size(0) == self.seq_lenassert decoder_input.size(0) == self.seq_lenassert label.size(0) == self.seq_lenreturn {'encoder_input': encoder_input,'decoder_input': decoder_input, 'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & casual_mask(decoder_input.size(0)), 'label': label,'src_text': src_text,'tgt_text': tgt_text}  

03、训练循环

我们已准备好在 OpusBook 数据集上训练 Transformer 模型,以执行英语到意大利语翻译任务。我们首先通过调用我们之前定义的get_model函数来定义加载模型的函数。build_transformer该函数使用config字典来设置一些参数。

def get_model(config, vocab_src_len, vocab_tgt_len):model = build_transformer(vocab_src_len, vocab_tgt_len, config['seq_len'], config['seq_len'], config['d_model'])return model 

下面我们将定义两个函数来配置我们的模型和训练过程。

get_config函数中,我们定义了训练过程的关键参数。batch_size一次迭代中使用的训练示例的数量、num_epochs整个数据集通过 Transformer 向前和向后传递的次数、lr优化器的学习率等。我们最终还将定义来自 OpusBook 数据集的对,'lang_src': 'en'用于选择英语作为源语言以及'lang_tgt': 'it'选择意大利语作为目标语言。

get_weights_file_path函数构建用于保存或加载任何特定时期的模型权重的文件路径。

def get_config():return{'batch_size': 8,'num_epochs': 20,'lr': 10**-4,'seq_len': 350,'d_model': 512, 'lang_src': 'en','lang_tgt': 'it','model_folder': 'weights','model_basename': 'tmodel_','preload': None,'tokenizer_file': 'tokenizer_{0}.json','experiment_name': 'runs/tmodel'}def get_weights_file_path(config, epoch: str):model_folder = config['model_folder'] model_basename = config['model_basename'] model_filename = f"{model_basename}{epoch}.pt" return str(Path('.')/ model_folder/ model_filename)

我们最终定义了最后一个函数 ,train_model它将config参数作为输入。

在此函数中,我们将为训练设置一切。我们将模型及其必要组件加载到 GPU 上以加快训练速度,设置Adam优化器并配置CrossEntropyLoss函数来计算模型输出的翻译与数据集中的参考翻译之间的差异。

迭代训练批次、执行反向传播和计算梯度所需的每个循环都在此函数中。我们还将使用它来运行验证函数并保存模型的当前状态。

def train_model(config):# 设置设备在 GPU 上运行,以加快训练速度device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"Using device {device}")# 创建模型目录以存储权重Path(config['model_folder']).mkdir(parents=True, exist_ok=True)# 使用 "get_ds "函数检索源语言和目标语言的数据加载器和标记器train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)# 使用 "get_model "函数在 GPU 上初始化模型model = get_model(config,tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)# Tensorboardwriter = SummaryWriter(config['experiment_name'])# 使用'# config'字典中的指定学习率和ε值设置优化器optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps = 1e-9)# 初始化和全局步长变量initial_epoch = 0global_step = 0if config['preload']:model_filename = get_weights_file_path(config, config['preload'])print(f'Preloading model {model_filename}')state = torch.load(model_filename) # Loading modelinitial_epoch = state['epoch'] + 1optimizer.load_state_dict(state['optimizer_state_dict'])global_step = state['global_step']loss_fn = nn.CrossEntropyLoss(ignore_index = tokenizer_src.token_to_id('[PAD]'), label_smoothing = 0.1).to(device)for epoch in range(initial_epoch, config['num_epochs']):batch_iterator = tqdm(train_dataloader, desc = f'Processing epoch {epoch:02d}')for batch in batch_iterator:model.train() # Train the modelencoder_input = batch['encoder_input'].to(device)decoder_input = batch['decoder_input'].to(device)encoder_mask = batch['encoder_mask'].to(device)decoder_mask = batch['decoder_mask'].to(device)encoder_output = model.encode(encoder_input, encoder_mask)decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)proj_output = model.project(decoder_output)label = batch['label'].to(device)loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))batch_iterator.set_postfix({f"loss": f"{loss.item():6.3f}"})writer.add_scalar('train loss', loss.item(), global_step)writer.flush()loss.backward()optimizer.step()optimizer.zero_grad()global_step += 1 run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)model_filename = get_weights_file_path(config, f'{epoch:02d}')torch.save({'epoch': epoch, # Current epoch'model_state_dict': model.state_dict(),# Current model state'optimizer_state_dict': optimizer.state_dict(), # Current optimizer state'global_step': global_step # Current global step }, model_filename)

现在开始训练我们的模型!

if __name__ == '__main__':warnings.filterwarnings('ignore') # 忽略警告config = get_config() # 检索配置设置train_model(config) # 使用配置参数训练模型

结果如下:

Using device cuda
Downloading builder script:
6.08k/? [00:00<00:00, 391kB/s]
Downloading metadata:
161k/? [00:00<00:00, 11.0MB/s]
Downloading and preparing dataset opus_books/en-it (download: 3.14 MiB, generated: 8.58 MiB, post-processed: Unknown size, total: 11.72 MiB) to /root/.cache/huggingface/datasets/opus_books/en-it/1.0.0/e8f950a4f32dc39b7f9088908216cd2d7e21ac35f893d04d39eb594746af2daf...
Downloading data: 100%
3.30M/3.30M [00:00<00:00, 10.6MB/s]
Dataset opus_books downloaded and prepared to /root/.cache/huggingface/datasets/opus_books/en-it/1.0.0/e8f950a4f32dc39b7f9088908216cd2d7e21ac35f893d04d39eb594746af2daf. Subsequent calls will reuse this data.
Max length of source sentence: 309
Max length of target sentence: 274
....................................................................

04、结论

在本文中,我们深入探索了原始 Transformer 架构,如《Attention Is All You Need》研究论文中所述。我们使用 PyTorch 在语言翻译任务上逐步实现它,使用 OpusBook 数据集进行英语到意大利语的翻译。

Transformer 是向当今最先进模型(例如 OpenAI 的 GPT-4 模型)迈出的革命性一步。这就是为什么理解这种架构如何工作以及它可以实现什么如此重要。

参考论文:“Attention Is All You Need”

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/670733.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2-1 动手学深度学习v2-Softmax回归-笔记

回归 VS 分类 回归估计一个连续值分类预测一个离散类别 从回归到多类分类 回归 单连续数值输出输出的区间&#xff1a;自然区间 R \mathbb{R} R损失&#xff1a;跟真实值的区别 分类 通常多个输出&#xff08;这个输出的个数是等于类别的个数&#xff09;输出的第 i i i…

MATLAB知识点:矩阵的除法

​讲解视频&#xff1a;可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇&#xff08;数学建模清风主讲&#xff0c;适合零基础同学观看&#xff09;_哔哩哔哩_bilibili 节选自第3章 3.4.2 算术运算 下面我们再来介绍矩阵的除法。事…

企业数字化转型面临什么挑战?

数字化转型是一个复杂且持续的过程&#xff0c;涉及将数字技术集成到组织的各个方面&#xff0c;从根本上改变组织的运营方式和为客户提供价值的方式。虽然具体的挑战可能因企业的性质和规模而异&#xff0c;但一些常见的挑战包括&#xff1a; 1.抵制变革&#xff1a; 文化阻…

Java入门之JavaSe(韩顺平p1-p?)

学习背景&#xff1a; 本科搞过一段ACM、研究生搞了一篇B会后&#xff0c;本人在研二要学Java找工作啦~~&#xff08;宇宙尽头是Java&#xff1f;&#xff09;爪洼纯小白入门&#xff0c;C只会STL、python只会基础Pytorch、golang参与了一个Web后端项目&#xff0c;可以说项目小…

Flink-CDC实时读Postgresql数据

前言 CDC,Change Data Capture,变更数据获取的简称,使用CDC我们可以从数据库中获取已提交的更改并将这些更改发送到下游,供下游使用。这些变更可以包括INSERT,DELETE,UPDATE等。 用户可以在如下的场景使用cdc: 实时数据同步:比如将Postgresql库中的数据同步到我们的数仓中…

Python初学者学习记录——python基础综合案例:数据可视化——动态柱状图

一、案例效果 通过pyecharts可以实现数据的动态显示&#xff0c;直观的感受1960~2019年世界各国GDP的变化趋势 二、通过Bar构建基础柱状图 反转x轴和y轴 标签数值在右侧 from pyecharts.charts import Bar from pyecharts.options import LabelOpts# 构建柱状图对象 bar Bar()…

二进制安全虚拟机Protostar靶场(7)heap2 UAF(use-after-free)漏洞

前言 这是一个系列文章&#xff0c;之前已经介绍过一些二进制安全的基础知识&#xff0c;这里就不过多重复提及&#xff0c;不熟悉的同学可以去看看我之前写的文章 heap2 程序静态分析 https://exploit.education/protostar/heap-two/#include <stdlib.h> #include &…

环境配置:Ubuntu18.04 ROS Melodic安装

前言 不同版本的Ubuntu与ROS存在对应关系。 ROS作为目前最受欢迎的机器人操作系统&#xff0c;其核心代码采用C编写&#xff0c;并以BSD许可发布。ROS起源于2007年&#xff0c;是由斯坦福大学与机器人技术公司Willow Garage合作的Switchyard项目。2012年&#xff0c;ROS团队从…

力扣面试题 05.03. 翻转数位(前、后缀和)

Problem: 面试题 05.03. 翻转数位 文章目录 题目描述思路及解法复杂度Code 题目描述 思路及解法 1.将十进制数转换为二进制数&#xff08;每次按位与1求与&#xff0c;并且右移&#xff09;&#xff1b; 2.依次求取二进制数中每一位的前缀1的数量和&#xff0c;和后缀1的数量和…

计算机项目SpringBoot项目 办公小程序开发

从零构建后端项目、利用UNI-APP创建移动端项目 实现注册与登陆、人脸考勤签到、实现系统通知模块 实现会议管理功能、完成在线视频会议功能、 发布Emos在线办公系统 项目分享&#xff1a; SpringBoot项目 办公小程序开发https://pan.baidu.com/s/1sYPLOAMtaopJCFHAWDa2xQ?…

极狐GitLab 使用阿里云作为 OmniAuth 身份验证 provider

使用阿里云作为 OmniAuth 身份验证 provider 您可以启用阿里云 OAuth 2.0 OmniAuth provider并使用您的阿里云账户登录极狐GitLab。 创建阿里云应用 登录阿里云平台&#xff0c;在上面创建一个应用。阿里云会生成一个 client ID and secret key 供您使用。 登录到阿里云平台…

PHP实现DESede/ECB/PKCS5Padding加密算法兼容Java SHA1PRNG

这里写自定义目录标题 背景JAVA代码解决思路PHP解密 背景 公司PHP开发对接一个Java项目接口&#xff0c;接口返回数据有用DESede/ECB/PKCS5Padding加密&#xff0c;并且key也使用了SHA1PRNG加密了&#xff0c;网上找了各种办法都不能解密&#xff0c;耗了一两天的时间&#xf…

C语言:内存函数

创作不易&#xff0c;友友们给个三连吧&#xff01;&#xff01; C语言标准库中有这样一些内存函数&#xff0c;让我们一起学习吧&#xff01;&#xff01; 一、memcpy函数的使用和模拟实现 void * memcpy ( void * destination, const void * source, size_t num ); 1.1 使…

微信小程序(三十四)搜索框-带历史记录

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.搜索框基本模板 2.历史记录基本模板 3.细节处理 源码&#xff1a; index.wxml <!-- 1.点击搜索按钮a.非空判断b.历史记录&#xff08;去重&#xff09;c.清空搜索框d.去除前后多余空格2.删除搜索 3.无搜索…

Golang 学习(一)基础知识

面向对象 Golang 也支持面向对象编程(OOP)&#xff0c;但是和传统的面向对象编程有区别&#xff0c;并不是纯粹的面向对象语言。 Golang 没有类(class)&#xff0c;Go 语言的结构体(struct)和其它编程语言的类(class)有同等的地位&#xff0c;Golang 是基于 struct 来实现 OOP…

部署 Zabbix 监控平台

部署 Zabbix 监控平台 目录 部署 Zabbix 监控平台一、 Zabbix简介Zabbix 特性Zabbix监控功能 二、Zabbix 概述Server数据库Web 界面ProxyAgent数据流Zabbix serverZabbix agentzabbix配置文件 三、部署Zabbix1&#xff1a;部署监控服务器1.1安装 LNMP 环境1.2 修改 Nginx 配置文…

Unity类银河恶魔城学习记录1-14 AttackDirection源代码 P41

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释&#xff0c;可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili PlayerPrimaryAttackState.cs using System.Collections; using System.Co…

C语言的malloc(0)问题

malloc(0)详解 首先来解释malloc&#xff08;0&#xff09;的问题&#xff0c;这个语法是对的&#xff0c;而且确实也分配了内存&#xff0c;但是内存空间是0&#xff0c;就是说返回给你的指针是不能用的&#xff0c;感觉奇怪吧&#xff1f;但是从操作系统的原理来解释就不奇怪…

6-2、T型加减速计算简化【51单片机+L298N步进电机系列教程】

↑↑↑点击上方【目录】&#xff0c;查看本系列全部文章 摘要&#xff1a;本节介绍简化T型加减速计算过程&#xff0c;使其适用于单片机数据处理。简化内容包括浮点数转整型数计算、加减速对称处理、预处理计算 一、浮点数转整型数计算 根据上一节内容已知 常用的晶振大小…

【Vue3】项目实战前基本知识

Vue3ViteTypeScriptpinia Vue3更新点新建项目方式一新建项目方式二vite-demo目录讲解安装常用扩展 vue3书写风格动态css也可以这样使用 虚拟DOMRef全家桶ref小知识1ref小知2&#xff0c;可以直接操作Dom recative全家桶数组赋值方式一数组赋值方式二 to系列全家桶Vue3的响应式原…