预训练 BERT 使用 Hugging Face 和 PyTorch 在 AMD GPU 上

Pre-training BERT using Hugging Face & PyTorch on an AMD GPU — ROCm Blogs

2024年1月26日,作者:Vara Lakshmi Bayanagari.

这篇博客解释了如何从头开始使用 Hugging Face 库和 PyTorch 后端在 AMD GPU 上为英文语料(WikiText-103-raw-v1)预训练 BERT 基础模型的端到端过程。

你可以在 GitHub folder中找到与这篇博客相关的文件。

BERT简介

BERT是一种在2019年提出的语言表示模型。其模型架构基于一个transformer编码器,其中自注意力层对输入的每个token对进行注意力计算,整合了来自两个方向的上下文(因此称为BERT的“双向”特性)。在此之前,像ELMo和GPT这样的模型只使用从左到右的(单向)架构,这极大地限制了模型的表现力;模型性能依赖于微调。

本博客解释了BERT所采用的预训练任务,这些任务在通用语言理解评估(GLUE)基准测试中取得了最先进的成果。在接下来的章节中,我们将展示在PyTorch中的实现。

这篇BERT paper最先提出了一种新的预训练方法,称为掩码语言建模(MLM)。MLM随机掩盖输入的某些部分,并对一批输入进行训练以预测这些被掩盖的tokens。预训练期间,在对输入进行分词之后,15%的tokens被随机挑选。其中,80%被替换为一个`[MASK]`标记,10%被替换为一个随机标记,10%则保持不变。

在下面的示例中,MLM预处理方法如下:`dog`标记保持不变,`Golden`和`years`标记被掩盖,并且`and`标记被随机替换为`paper`标记。预训练的目标是使用`CategoricalCrossEntropy`损失来预测这些标记,以便模型学习语言的语法、模式和结构。

Input sentence: My dog is a Golden Retriever and his is 5 years oldAfter MLM: My dog is a [MASK] Retriever paper his is 5 [MASK] old

此外,为了捕捉句子之间的关系,超越掩码语言建模任务,论文提出了第二个预训练任务,称为下一个句子预测(NSP)。在不改变架构的情况下,论文证明了NSP有助于提升问答(QA)和自然语言推理(NLI)任务的结果。

这个任务不直接输入token流,而是输入一对句子的token,例如`A`和`B`,以及一个前置分类标记(`[CLS]`)。分类标记指示句对是随机组合的(label=0)还是`B`是`A`的下一个句子(label=1)。因此,NSP预训练是一种二元分类任务。

_IsNext_ Pair: [1] My dog is a Golden Retriever. He is five years old.Not _IsNext_ Pair: [0] My dog is a Golden Retriever. The next chapter in the book is a biography.

总之,数据集首先进行预处理以形成一对句子,然后进行分词,并最终随机掩盖某些tokens。预处理后的输入批次要么*填充*(使用`[PAD]`标记)或*修剪*(到_max_seq_length_超参数),以便所有输入元素在加载到BERT模型中之前都统一为相同的长度。BERT模型配有两个分类头:一个用于MLM(`num_cls_heads = vocab_size),另一个用于NSP(num_cls_heads=2`)。来自两个预训练任务的分类损失之和用于训练BERT。

在多台 AMD GPU 上的实现

在开始之前,确保您已经满足以下要求:

  1. 在搭载 AMD GPU 的设备上安装 ROCm 兼容的 PyTorch。本实验在 ROCm 5.7.0 和 PyTorch 2.0.1 上进行了测试。

  2. 运行命令 pip install datasets transformers accelerate 以安装 Hugging Face 的相关库。

  3.  运行 accelerate config 命令以设置分布式训练参数,详见此处。在本实验中,我们使用了单节点上的八块 GPU 并行计算,运用了 DistributedDataParallel

实现

Hugging Face 使用 Torch 作为大多数模型的默认后端,从而实现了这两个框架的良好结合。为了简化常规训练步骤并避免样板代码,Hugging Face 提供了一个名为 Trainer 的类,该类模仿了 PyTorch 的功能。类似地,Lightning AI 提供了 Trainer 类。此外,对于分布式训练,Hugging Face 可能更方便,因为代码中没有额外的配置设置,系统会根据 accelerate config 自动检测并利用所有 GPU 设备。然而,如果你希望进一步自定义你的模型并对加载预训练检查点做出额外修改,原生的 PyTorch 是更好的选择。这篇博客解释了使用 Hugging Face 的 transformers 库对 BERT 进行端到端预训练,同时提供了简化的数据预处理管道。

使用 Hugging Face 的 Trainer 进行 BERT 预训练可以用几行代码来总结。transformer 编码器、MLM 分类头和 NSP 分类头都打包在 Hugging Face 的 BertForPreTraining 模型中,该模型返回一个累积分类损失,如我们在 介绍 中所解释的。模型使用默认的 BERT base 配置参数(`NUM_LAYERS`、`ACT_FUNC`、`BATCH_SIZE`、`HIDDEN_SIZE`、`EMBED_DIM` 等)进行初始化。你可以从 Hugging Face 的 BertConfig 中导入这些参数。

那就是全部了吗?几乎。训练最关键的部分是数据预处理。预处理分为三个步骤:

  1.  将你的数据集重新组织为每个文档的句子字典。这对于从随机文档中选取随机句子以进行 NSP 任务非常有用。为此,可以对整个数据集使用简单的for循环。

  2. 使用 Hugging Face 的 AutoTokenizer 来对所有句子进行标记化。

  3. 使用另一个 for 循环,创建 50% 随机对和 50% 顺序对的句子对。

我已经对 WikiText-103-raw-v1 语料库(2,500 M单词)进行了上述的预处理步骤,并将生成的验证集放在这里。预处理的训练集已上传到 Hugging Face Hub。

接下来,导入 DataCollatorForLanguageModeling 收集器以运行 MLM 预处理,并获取掩码和句子分类标签。在使用 Trainer 类时,我们只需要访问 torch.utils.data.Dataset 和一个收集函数。与 TensorFlow 不同,Hugging Face 的 Trainer 会从数据集和收集器函数中创建数据加载器。为了演示,我们使用了有 3,000+ 句对的 Wikitext-103-raw-v1 验证集。

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
# tokenized_dataset = datasets.load_from_disk(args.dataset_file)
tokenized_dataset_valid = datasets.load_from_disk('./wikiTokenizedValid.hf')

创建一个 TrainerArguments 实例,并传递所有必需的参数,如以下代码所示。这部分代码有助于在训练模型时抽象样板代码。该类很灵活,因为它提供了 100 多个参数来适应不同的训练模式;有关更多信息,请参阅 Hugging face transformers 页面。

你现在可以使用 t.train() 来训练模型了。你还可以通过将 resume_from_checkpoint=True 参数传递给 t.train() 来恢复训练。trainer 类会提取 output_dir 文件夹中的最新检查点,并继续训练直到达到总共 num_train_epochs

train_args = TrainingArguments(output_dir=args.output_dir, overwrite_output_dir =True, per_device_train_batch_size =args.BATCH_SIZE, logging_first_step=True,logging_strategy='epoch', evaluation_strategy = 'epoch', save_strategy ='epoch', num_train_epochs=args.EPOCHS,save_total_limit=50)
t = Trainer(model, args = train_args, data_collator=collater, train_dataset = tokenized_dataset, optimizers=(optimizer, None), eval_dataset = tokenized_dataset_valid)
t.train()#resume_from_checkpoint=True)

上述模型使用Adam优化器(`learning_rate=2e-5`)和`per_device_train_batch_size=8`进行了大约400个epoch的训练。在一块AMD GPU(MI210,ROCm 5.7.0,PyTorch 2.0.1)上,使用3,000+句对的验证集进行预训练仅需几个小时。训练曲线如图1所示。可以使用最佳模型检查点微调不同的数据集,并在各种NLP任务上测试其表现。

Graph shows loss decreasing at a roughly exponential rate as epochs increase

完整的代码如下:

set_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument('--BATCH_SIZE', type=int, default = 8) # 32 is the global batch size, since I use 8 GPUs
parser.add_argument('--EPOCHS', type=int, default=200)
parser.add_argument('--train', action='store_true')
parser.add_argument('--dataset_file', type=str, default= './wikiTokenizedValid.hf')
parser.add_argument('--lr', default = 0.00005, type=float)
parser.add_argument('--output_dir', default = './acc_valid/')
args = parser.parse_args()accelerator = Accelerator()if args.train:args.dataset_file = './wikiTokenizedTrain.hf'args.output_dir = './acc/'
print(args)tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
tokenized_dataset = datasets.load_from_disk(args.dataset_file)
tokenized_dataset_valid = datasets.load_from_disk('./wikiTokenizedValid.hf')model = BertForPreTraining(BertConfig.from_pretrained("bert-base-cased"))
optimizer = torch.optim.Adam(model.parameters(), lr =args.lr)device = accelerator.device
model.to(accelerator.device)
train_args = TrainingArguments(output_dir=args.output_dir, overwrite_output_dir =True, per_device_train_batch_size =args.BATCH_SIZE, logging_first_step=True,logging_strategy='epoch', evaluation_strategy = 'epoch', save_strategy ='epoch', num_train_epochs=args.EPOCHS,save_total_limit=50)#, lr_scheduler_type=None)
t = Trainer(model, args = train_args, data_collator=collater, train_dataset = tokenized_dataset, optimizers=(optimizer, None), eval_dataset = tokenized_dataset_valid)
t.train()#resume_from_checkpoint=True)

推理

以一个示例文本为例,使用分词器将其转换为输入tokens,并通过collator生成一个掩码输入。

collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15, pad_to_multiple_of=128)
text="The author takes his own advice when it comes to writing: he seeks to ground his claims in clear, concrete examples. He shows specific examples of bad writing to help readers better grasp exactly what he’s critiquing"
tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
inp = collater([tokens])
inp['attention_mask'] = torch.where(inp['input_ids']==0,0,1)

使用预训练的权重初始化模型并进行推理。你将看到模型生成的随机tokens没有上下文意义。

config = BertConfig.from_pretrained('bert-base-cased')
model = BertForPreTraining.from_pretrained('./acc_valid/checkpoint-19600/')
model.eval()
out = model(inp['input_ids'], inp['attention_mask'], labels=inp['labels'])print('Input: ', tokenizer.decode(inp['input_ids'][0][:30]), '\n')
print('Output: ', tokenizer.decode(torch.argmax(out[0], -1)[0][:30]))

输入和输出如下所示。该模型在一个非常小的数据集(3,000多句子)上进行了训练;你可以通过在更大的数据集上训练,例如`wikiText-103-raw-v1`的训练切分数据,来提高性能。

The author takes his own advice when it comes to writing : he [MASK] to ground his claims in clear, concrete examples. He shows specific examples of bad
The Churchill takes his own, when it comes to writing : he continued to ground his claims in clear, this examples. He shows is examples of bad

源代码存储在这个 GitHub 文件夹。

结论

我们所描述的预训练BERT基础模型的过程可以很容易地扩展到不同大小的BERT版本以及不同的数据集。我们使用Hugging Face Trainer和PyTorch后端在AMD GPU上训练了我们的模型。对于训练,我们使用了`wikiText-103-raw-v1`数据集的验证集,但这可以很容易地替换为训练集,只需下载我们在Hugging Face Hub上的仓库中托管的预处理和标记化的训练文件Hugging Face Hub.

在本文中,我们通过MLM和NSP预训练任务复制了BERT的预训练过程,这与许多公共平台上仅使用MLM的方法不同。此外,我们没有使用数据集的小部分,而是预处理并上传了整个数据集到Hub上供您方便使用。在未来的文章中,我们将讨论在多个AMD GPU上使用数据并行和分布式策略来训练各种机器学习应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/58083.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qgis 开发初级 《ToolBox》

Qgis 有个ToolBox 的,在Processing->ToolBox 菜单里面,界面如下。 理论上Qgis这里面的工具都是可以用脚本或者C 代码调用的。界面以Vector overlay 为例子简单介绍下使用方式。Vector overlay 的意思是矢量叠置分析,和arcgis软件类似的。点…

[大模型学习推理]资料

https://juejin.cn/post/7353963878541361192 lancedb是个不错的数据库,有很多学习资料 https://github.com/lancedb/vectordb-recipes/tree/main/tutorials/Multi-Head-RAG-from-Scratch 博主讲了很多讲解,可以参考 https://juejin.cn/post/7362789…

JMeter详细介绍和相关概念

JMeter是一款开源的、强大的、用于进行性能测试和功能测试的Java应用程序。 本篇承接上一篇 JMeter快速入门示例 , 对该篇中出现的相关概念进行详细介绍。 JMeter测试计划 测试计划名称和注释:整个测试脚本保存的名称,以及对该测试计划的注…

【原创】统信UOS如何安装最新版Node.js(20.x)

注意直接使用sudo apt install nodejs命令安装十有八九会预装10.x的老旧版本Node.js,如果已经安装的建议删除后安装如下方法重装。 在统信UOS系统中更新Node.js可以通过以下步骤进行: 1. 卸载当前版本的Node.js 首先,如果系统中已经安装了N…

4.1.2 网页设计技术

文章目录 1. 万维网(WWW)的诞生2. 移动互联网的崛起3. 网页三剑客:HTML、CSS和JavaScriptHTML:网页的骨架CSS:网页的外衣JavaScript:网页的活力 4. 前端框架的演变基于CSS的框架基于JavaScript的框架基于MV…

【Django】继承框架中用户模型基类AbstractUser扩展系统用户表字段

Django项目新建好app之后,通常情况下需要首要考虑的就是可以认为最重要的用户表,即users对应的model,它对于系统来说可以说是最基础的依赖。 实际上,我们在初始进行migration的时候已经同步生成了相应的user表,如下&am…

spygalss cdc 检测的bug(二)

当allow_qualifier_merge设置为strict的时候,sg是要检查门的极性的。 如果qualifier和src经过与门汇聚,在同另一个src1信号或门汇聚,sg是报unsync的。 假设当qualifier为0时,0&&src||src1src1,src1无法被gat…

xss-labs靶场第十七关测试报告

目录 一、测试环境 1、系统环境 2、使用工具/软件 二、测试目的 三、操作过程 1、注入点寻找 2、使用hackbar进行payload测试 3、绕过结果 四、源代码分析 五、结论 一、测试环境 1、系统环境 渗透机:本机(127.0.0.1) 靶 机:本机(127.0.0.…

Jenkins发布vue项目,版本不一致导致build错误

问题一 yarn.lock文件的存在导致在自动化的时候,频频失败问题二 仓库下载的资源与项目资源版本不一致 本地跑好久的一个项目,现在需要部署在Jenkins上面进行自动化打包部署;想着部署后今后可以省下好多时间,遂兴高采烈地去部署&am…

提升数据处理效率:TDengine S3 的最佳实践与应用

在当今数据驱动的时代,如何高效地存储与处理海量数据成为了企业面临的一大挑战。为了解决这一问题,我们在 TDengine 3.2.2.0 首次发布了企业级功能 S3 存储。这一功能经历多个版本的迭代与完善后,逐渐发展成为一个全面和高效的解决方案。 S3…

python 实现一个简单的浏览器引擎

1. 浏览器引擎工作原理 浏览器引擎是用来处理、渲染和显示网页内容的核心组件。其主要任务是将用户输入的URL所代表的网页资源加载并呈现出来,通常包括HTML、CSS、JavaScript以及各种多媒体内容。浏览器引擎的工作原理可以分为以下几个主要步骤: 1.1 U…

软件系统建设方案书(word参考模板)

1 引言 1.1 编写目的 1.2 项目概述 1.3 名词解释 2 项目背景 3 业务分析 3.1 业务需求 3.2 业务需求分析与解决思路 3.3 数据需求分析【可选】 4 项目建设总体规划【可选】 4.1 系统定位【可选】 4.2 系统建设规划 5 建设目标 5.1 总体目标 5.2 分阶段目标【可选】 5.2.1 业务目…

FlinkSQL之temporary join开发

在实时开发中,双流join获取目标对应时刻的属性时,经常使用temporary join。笔者在流量升级的实时迭代中,需要让流量日志精准的匹配上浏览时间里对应的商品属性,使用temporary join开发过程中踩坑不少,将一些经验沉淀在…

【开源免费】基于SpringBoot+Vue.JS网上超市系统(JAVA毕业设计)

本文项目编号 T 037 ,文末自助获取源码 \color{red}{T037,文末自助获取源码} T037,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

研发运营一体化(DevOps)能力成熟度模型

目录 应用设计 安全风险管理 技术运 持续交付 敏捷开发管理 基于微服务的端到端持续交付流水线案例 应用设计 安全风险管理 技术运 持续交付

Android 判断手机放置的方向

#1024程序员节|征文# 文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 需求 老板:我有个手持终端,不能让他倒了,当他倒或者倾斜的时候要发出报警; 程序猿:我这..... 老板…

2024-09-28 地址空间与进程控制

一、进程地址空间 Pt.2 同一个变量,地址相同,其实是虚拟地址相同,内容不同其实是被映射到了不同的物理地址 1. 页表 内存保护与页表标志位 在操作系统中,页表用于管理内存的访问权限。每个页表项通常包含一组标志位&…

二:Python学习笔记--基础知识(1) 变量,关键字,数据类型,赋值运算符,比较运算符

目录 1. 变量 2. python关键字 3. python数据类型 3.1 数字类型 整型 int 浮点型 float 内置函数-type 3.2 字符串类型 3.3 布尔类型 3.4 空类型 3.5 列表类型 3.6 元组类型 3.7 字典类型 4. python赋值运算 5. python比较运算符 1. 变量 组成:必须是数…

基于SSM的BBS社区论坛系统源码

运行环境:ideamysql5.7jdk8maven 使用技术:ssmmysqlshirolayui 功能模块:用户管理、模板管理、帖子管理、公告管理、权限管理等

yolov9目标检测/分割预测报错AttributeError: ‘list‘ object has no attribute ‘device‘常见汇总

这篇文章主要是对yolov9目标检测和目标分割预测测试时的报错,进行解决方案。 在说明解决方案前,严重投诉、吐槽一些博主发的一些文章,压根没用的解决方法,也不知道他们从哪里抄的,误人子弟、浪费时间。 我在解决前&…