语言模型GPT与HuggingFace应用

受到计算机视觉领域采用ImageNet对模型进行一次预训练,使得模型可以通过海量图像充分学习如何提取特征,然后再根据任务目标进行模型微调的范式影响,自然语言处理领域基于预训练语言模型的方法也逐渐成为主流。以ELMo为代表的动态词向量模型开启了语言模型预训练的大门,此后以GPT 和BERT为代表的基于Transformer 的大规模预训练语言模型的出现,使得自然语言处理全面进入了预训练微调范式新时代。

利用丰富的训练语料、自监督的预训练任务以及Transformer 等深度神经网络结构,预训练语言模型具备了通用且强大的自然语言表示能力,能够有效地学习到词汇、语法和语义信息。将预训练模型应用于下游任务时,不需要了解太多的任务细节,不需要设计特定的神经网络结构,只需要“微调”预训练模型,即使用具体任务的标注数据在预训练语言模型上进行监督训练,就可以取得显著的性能提升。

OpenAI 公司在2018 年提出的生成式预训练语言模型(Generative Pre-Training,GPT)是典型的生成式预训练语言模型之一。GPT 模型结构如图2.3所示,由多层Transformer 组成的单向语言模型,主要分为输入层,编码层和输出层三部分。
接下来我将重点介绍GPT 无监督预训练、有监督下游任务微调以及基于HuggingFace 的预训练语言模型实践。

一、 无监督预训练

GPT 采用生成式预训练方法,单向意味着模型只能从左到右或从右到左对文本序列建模,所采用的Transformer 结构和解码策略保证了输入文本每个位置只能依赖过去时刻的信息。
给定文本序列w = w1w2...wn,GPT 首先在输入层中将其映射为稠密的向量:

其中,

是词wi 的词向量,

 是词wi 的位置向量,vi 为第i 个位置的单词经过模型输入层(第0层)后的输出。GPT 模型的输入层与前文中介绍的神经网络语言模型的不同之处在于其需要添加

图1.1 GPT 预训练语言模型结构

位置向量,这是Transformer 结构自身无法感知位置导致的,因此需要来自输入层的额外位置信息。经过输入层编码,模型得到表示向量序列v = v1...vn,随后将v 送入模型编码层。编码层由L 个Transformer 模块组成,在自注意力机制的作用下,每一层的每个表示向量都会包含之前位置表示向量的信息,使每个表示向量都具备丰富的上下文信息,并且经过多层编码后,GPT 能得到每个单词层次化的组合式表示,其计算过程表示如下:

其中

 表示第L 层的表示向量序列,n 为序列长度,d 为模型隐藏层维度,L 为模型总层数。GPT 模型的输出层基于最后一层的表示h(L),预测每个位置上的条件概率,其计算过程可以表示为:

其中,

 为词向量矩阵,|V| 为词表大小。单向语言模型是按照阅读顺序输入文本序列w,用常规语言模型目标优化w 的最大似然估计,使之能根据输入历史序列对当前词能做出准确的预测:

其中θ 代表模型参数。也可以基于马尔可夫假设,只使用部分过去词进行训练。预训练时通常使用随机梯度下降法进行反向传播优化该负似然函数。

二、 有监督下游任务微调

通过无监督语言模型预训练,使得GPT 模型具备了一定的通用语义表示能力。下游任务微调(Downstream Task Fine-tuning)的目的是在通用语义表示基础上,根据下游任务的特性进行适配。下游任务通常需要利用有标注数据集进行训练,数据集合使用D 进行表示,每个样例由输入长度为n 的文本序列x = x1x2...xn 和对应的标签y 构成。
首先将文本序列x 输入GPT 模型,获得最后一层的最后一个词所对应的隐藏层输出h(L)n ,在此基础上通过全连接层变换结合Softmax 函数,得到标签预测结果。

其中

为全连接层参数,k 为标签个数。通过对整个标注数据集D 优化如下目标函数
精调下游任务:

下游任务在微调过程中,针对任务目标进行优化,很容易使得模型遗忘预训练阶段所学习到的通用语义知识表示,从而损失模型的通用性和泛化能力,造成灾难性遗忘(Catastrophic Forgetting)问题。因此,通常会采用混合预训练任务损失和下游微调损失的方法来缓解上述问题。在实际应用中,通常采用如下公式进行下游任务微调:

其中λ 取值为[0,1],用于调节预训练任务损失占比。

三、基于HuggingFace 的预训练语言模型实践

HuggingFace 是一个开源自然语言处理软件库。其的目标是通过提供一套全面的工具、库和模型,使得自然语言处理技术对开发人员和研究人员更加易于使用。HuggingFace 最著名的贡献之一是Transformer 库,基于此研究人员可以快速部署训练好的模型以及实现新的网络结构。除此之外,HuggingFace 还提供了Dataset 库,可以非常方便地下载自然语言处理研究中最常使用的基准数据集。本节中,将以构建BERT 模型为例,介绍基于Huggingface 的BERT 模型构建和使用方法。

3.1. 数据集合准备

常见的用于预训练语言模型的大规模数据集都可以在Dataset 库中直接下载并加载。例如,如果使用维基百科的英文语料集合,可以直接通过如下代码完成数据获取:

复制代码

from datasets import concatenate_datasets, load_dataset
bookcorpus = load_dataset("bookcorpus", split="train")
wiki = load_dataset("wikipedia", "20230601.en", split="train")
# 仅保留'text' 列
wiki = wiki.remove_columns([col for col in wiki.column_names if col != "text"])
dataset = concatenate_datasets([bookcorpus, wiki])
# 将数据集合切分为90% 用于训练,10% 用于测试
d = dataset.train_test_split(test_size=0.1)

接下来将训练和测试数据分别保存在本地文件中

def dataset_to_text(dataset, output_filename="data.txt"):"""Utility function to save dataset text to disk,useful for using the texts to train the tokenizer(as the tokenizer accepts files)"""with open(output_filename, "w") as f:for t in dataset["text"]:print(t, file=f)
# save the training set to train.txt
dataset_to_text(d["train"], "train.txt")
# save the testing set to test.txt
dataset_to_text(d["test"], "test.txt")

3.2. 训练词元分析器(Tokenizer)

如前所述,BERT 采用了WordPiece 分词,根据训练语料中的词频决定是否将一个完整的词切分为多个词元。因此,需要首先训练词元分析器(Tokenizer)。可以使用transformers 库中的BertWordPieceTokenizer 类来完成任务,代码如下所示:

special_tokens = [
"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "<S>", "<T>"
]#
if you want to train the tokenizer on both sets
# files = ["train.txt", "test.txt"]
# training the tokenizer on the training set
files = ["train.txt"]
# 30,522 vocab is BERT's default vocab size, feel free to tweak
vocab_size = 30_522
# maximum sequence length, lowering will result to faster training (when increasing batch size)
max_length = 512
# whether to truncate
truncate_longer_samples = False
# initialize the WordPiece tokenizer
tokenizer = BertWordPieceTokenizer()
# train the tokenizer
tokenizer.train(files=files, vocab_size=vocab_size, special_tokens=special_tokens)
# enable truncation up to the maximum 512 tokens
tokenizer.enable_truncation(max_length=max_length)
model_path = "pretrained-bert"
# make the directory if not already there
if not os.path.isdir(model_path):os.mkdir(model_path)
# save the tokenizer
tokenizer.save_model(model_path)
# dumping some of the tokenizer config to config file,
# including special tokens, whether to lower case and the maximum sequence length
with open(os.path.join(model_path, "config.json"), "w") as f:tokenizer_cfg = {"do_lower_case": True,"unk_token": "[UNK]","sep_token": "[SEP]","pad_token": "[PAD]","cls_token": "[CLS]","mask_token": "[MASK]","model_max_length": max_length,"max_len": max_length,}json.dump(tokenizer_cfg, f)
# when the tokenizer is trained and configured, load it as BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained(model_path)

3.3. 预处理语料集合

在启动整个模型训练之前,还需要将预训练语料根据训练好的Tokenizer 进行处理。如果文档长度超过512 个词元(Token),那么就直接进行截断。数据处理代码如下所示:

def encode_with_truncation(examples):"""Mapping function to tokenize the sentences passed with truncation"""return tokenizer(examples["text"], truncation=True, padding="max_length",max_length=max_length, return_special_tokens_mask=True)
def encode_without_truncation(examples):"""Mapping function to tokenize the sentences passed without truncation"""return tokenizer(examples["text"], return_special_tokens_mask=True)
# the encode function will depend on the truncate_longer_samples variable
encode = encode_with_truncation if truncate_longer_samples else encode_without_truncation
# tokenizing the train dataset
train_dataset = d["train"].map(encode, batched=True)
# tokenizing the testing dataset
test_dataset = d["test"].map(encode, batched=True)
if truncate_longer_samples:# remove other columns and set input_ids and attention_mask as PyTorch tensorstrain_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
else:# remove other columns, and remain them as Python liststest_dataset.set_format(columns=["input_ids", "attention_mask", "special_tokens_mask"])train_dataset.set_format(columns=["input_ids", "attention_mask", "special_tokens_mask"])

truncate_longer_samples 布尔变量来控制用于对数据集进行词元处理的encode() 回调函数。如果设置为True,则会截断超过最大序列长度(max_length)的句子。否则,不会截断。如果设为truncate_longer_samples 为False,需要将没有截断的样本连接起来,并组合成固定长度的向量。

3.4. 模型训练

在构建了处理好的预训练语料之后,就可以开始模型训练。代码如下所示:

# initialize the model with the config
model_config = BertConfig(vocab_size=vocab_size, max_position_embeddings=max_length)
model = BertForMaskedLM(config=model_config)
# initialize the data collator, randomly masking 20% (default is 15%) of the tokens
# for the Masked Language Modeling (MLM) task
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.2
)
training_args = TrainingArguments(output_dir=model_path, # output directory to where save model checkpointevaluation_strategy="steps", # evaluate each `logging_steps` stepsoverwrite_output_dir=True,num_train_epochs=10, # number of training epochs, feel free to tweakper_device_train_batch_size=10, # the training batch size, put it as high as your GPU memory fitsgradient_accumulation_steps=8, # accumulating the gradients before updating the weightsper_device_eval_batch_size=64, # evaluation batch sizelogging_steps=1000, # evaluate, log and save model checkpoints every 1000 stepsave_steps=1000,# load_best_model_at_end=True, # whether to load the best model (in terms of loss)# at the end of training# save_total_limit=3, # whether you don't have much space so you# let only 3 model weights saved in the disk
)
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset,eval_dataset=test_dataset,
)
# train the model
trainer.train()

开始训练后,可以如下输出结果:

[10135/79670 18:53:08 < 129:35:53, 0.15 it/s, Epoch 1.27/10]
Step Training Loss Validation Loss
1000 6.904000 6.558231
2000 6.498800 6.401168
3000 6.362600 6.277831
4000 6.251000 6.172856
5000 6.155800 6.071129
6000 6.052800 5.942584
7000 5.834900 5.546123
8000 5.537200 5.248503
9000 5.272700 4.934949
10000 4.915900 4.549236

3.5. 模型使用

基于训练好的模型,可以针对不同应用需求进行使用。

# load the model checkpoint
model = BertForMaskedLM.from_pretrained(os.path.join(model_path, "checkpoint-10000"))
# load the tokenizer
tokenizer = BertTokenizerFast.from_pretrained(model_path)
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
# perform predictions
examples = [
"Today's most trending hashtags on [MASK] is Donald Trump",
"The [MASK] was cloudy yesterday, but today it's rainy.",
]
for example in examples:for prediction in fill_mask(example):print(f"{prediction['sequence']}, confidence: {prediction['score']}")print("="*50)

可以得到如下输出:

today's most trending hashtags on twitter is donald trump, confidence: 0.1027069091796875
today's most trending hashtags on monday is donald trump, confidence: 0.09271949529647827
today's most trending hashtags on tuesday is donald trump, confidence: 0.08099588006734848
today's most trending hashtags on facebook is donald trump, confidence: 0.04266013577580452
today's most trending hashtags on wednesday is donald trump, confidence: 0.04120611026883125
==================================================
the weather was cloudy yesterday, but today it's rainy., confidence: 0.04445931687951088
the day was cloudy yesterday, but today it's rainy., confidence: 0.037249673157930374
the morning was cloudy yesterday, but today it's rainy., confidence: 0.023775646463036537
the weekend was cloudy yesterday, but today it's rainy., confidence: 0.022554103285074234
the storm was cloudy yesterday, but today it's rainy., confidence: 0.019406016916036606
==================================================

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/209128.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在线教育小程序正在成为教育行业的新生力量

教育数字化转型是目前教育领域的一个热门话题&#xff0c;那么到底什么是教育数字化转型&#xff1f;如何做好教育数字化转型&#xff1f; 教育数字化转型是利用信息技术和数字工具改变和优化教育的过程。主要特征包括技术整合、在线学习、个性化学习、大数据分析、云计算、虚拟…

【C++学习手札】基于红黑树封装模拟实现map和set

​ &#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 &#x1f49c;本文前置知识&#xff1a; 红黑树 ♈️今日夜电波&#xff1a;漂流—菅原纱由理 2:55━━━━━━️&#x1f49f;──────── 4:29 …

Appium获取toast方法封装

一、前置说明 toast消失的很快&#xff0c;并且通过uiautomatorviewer也不能获取到它的定位信息&#xff0c;如下图&#xff1a; 二、操作步骤 toast的class name值为android.widget.Toast&#xff0c;虽然toast消失的很快&#xff0c;但是它终究是在Dom结构中出现过&…

【计算机网络】HTTP请求

目录 前言 HTTP请求报文格式 一. 请求行 HTTP请求方法 GET和POST的区别 URL 二. 请求头 常见的Header 常见的额请求体数据类型 三. 请求体 结束语 前言 HTTP是应用层的一个协议。实际我们访问一个网页&#xff0c;都会像该网页的服务器发送HTTP请求&#xff0c;服务…

使用Java将图片添加到Excel的几种方式

1、超链接 使用POI&#xff0c;依赖如下 <dependency><groupId>org.apache.poi</groupId><artifactId>poi</artifactId><version>4.1.2</version></dependency>Java代码如下,运行该程序它会在桌面创建ImageLinks.xlsx文件。 …

GPT-4V 在机器人领域的应用

在科技的浩渺宇宙中&#xff0c;OpenAI如一颗璀璨的星辰&#xff0c;于2023年9月25日&#xff0c;以一种全新的方式&#xff0c;向世界揭示了其最新的人工智能力作——GPT-4V模型。这次升级&#xff0c;为其旗下的聊天机器人ChatGPT装配了语音和图像的新功能&#xff0c;使得用…

『Linux升级路』进度条小程序

&#x1f525;博客主页&#xff1a;小王又困了 &#x1f4da;系列专栏&#xff1a;Linux &#x1f31f;人之为学&#xff0c;不日近则日退 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、预备知识 &#x1f4d2;1.1缓冲区 &#x1f4d2;1.2回车和换行 二、倒计…

修改正点原子综合实验的NES模拟器按键控制加横屏

​​​​​​​ 开发板&#xff1a;stm32f407探索者开发板V2 屏幕是4.3寸-800-480-MCU屏 手头没有V3开发板&#xff0c;只有V2&#xff0c;所以没法测试 所以只讲修改哪里&#xff0c;请自行修改 先改手柄部分&#xff0c;把手柄改成按键 找到左边的nes文件夹中的nes_mai…

采用轨到轨输出设计 LTC6363HMS8-2、LTC6363HMS8-1、LTC6363HRD、LTC6363IDCB差分放大器I

产品详情 LTC6363 系列包括四个全差分、低功耗、低噪声放大器&#xff0c;具有经优化的轨到轨输出以驱动 SAR ADC。LTC6363 是一款独立的差分放大器&#xff0c;通常使用四个外部电阻设置其增益。LTC6363-0.5、LTC6363-1 和 LTC6363-2 都有内部匹配电阻&#xff0c;可分别创建…

C++数据结构:B树

目录 一. 常见的搜索结构 二. B树的概念 三. B树节点的插入和遍历 3.1 插入B树节点 3.2 B树遍历 四. B树和B*树 4.1 B树 4.2 B*树 五. B树索引原理 5.1 索引概述 5.2 MyISAM 5.3 InnoDB 六. 总结 一. 常见的搜索结构 表示1为在实际软件开发项目中&#xff0c;常用…

博途PLC SCL间接寻址编程应用

这篇博客里我们将要学习Pointer和Any指针&#xff0c;PEEK和POKE指令&#xff0c;当然我们还可以数组类型数据实现数组指针寻址&#xff0c;具体应用介绍请参考下面文章链接&#xff1a; https://rxxw-control.blog.csdn.net/article/details/134761364https://rxxw-control.b…

一文讲解如何从 Clickhouse 迁移数据至 DolphinDB

ClickHouse 是 Yandex 公司于2016年开源的 OLAP 列式数据库管理系统&#xff0c;主要用于 WEB 流量分析。凭借面向列式存储、支持数据压缩、完备的 DBMS 功能、多核心并行处理的特点&#xff0c;ClickHouse 被广泛应用于广告流量、移动分析、网站分析等领域。 DolphinDB 是一款…

【Hadoop_02】Hadoop运行模式

1、Hadoop的scp与rsync命令&#xff08;1&#xff09;本地运行模式&#xff08;2&#xff09;完全分布式搭建【1】利用102将102的文件推到103【2】利用103将102的文件拉到103【3】利用103将102的文件拉到104 &#xff08;3&#xff09;rsync命令&#xff08;4&#xff09;xsync…

使用 HTML 地标角色提高可访问性

请务必确保所有用户都可以访问您的网站&#xff0c;包括使用屏幕阅读器等辅助技术的用户。 一种方法是使用 ARIA 地标角色来帮助屏幕阅读器用户轻松浏览您的网站。使用地标角色还有其他好处&#xff0c;例如改进 HTML 的语义并更轻松地设置网站样式。在这篇博文中&#xff0c;我…

深度探索Linux操作系统 —— 构建initramfs

系列文章目录 深度探索Linux操作系统 —— 编译过程分析 深度探索Linux操作系统 —— 构建工具链 深度探索Linux操作系统 —— 构建内核 深度探索Linux操作系统 —— 构建initramfs 文章目录 系列文章目录前言一、为什么需要 initramfs二、initramfs原理探讨三、构建基本的init…

tomcat篇---第二篇

系列文章目录 文章目录 系列文章目录前言一、tomcat容器是如何创建servlet类实例?用到了什么原理?二、tomcat 如何优化?三、熟悉tomcat的哪些配置?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女…

Web应用JSON数据保护(密码算法、密钥、数字签名和数据加密)

1.JSON&#xff08;JavaScript Object Notation&#xff09; JSON是一种轻量级的数据交换格式&#xff0c;采用完全独立于编程语言的文本格式来存储和表示数据。JSON通过简单的key-value键值对来描述数据&#xff0c;可以被广泛用于网络通信、数据存储等各种应用场景&#xff0…

Python面向对象基础

Python面向对象基础 一、概念1.1面向对象的设计思想1.2 面向过程和面向对象1.2.1 面向过程1.2.2 面向对象1.2.3 面向过程和面向对象的优缺点 二、类和对象2.1 概念2.2 类的定义2.3 对象的创建2.3.1 类中未定义构造函数2.3.2 类中定义构造函数 2.4 类的设计 三、类中的成员3.1 变…

Python教程-数组

作为软件开发者&#xff0c;我们总是努力编写干净、简洁、高效的代码。在本文中&#xff0c;我们将探索 Python 数组的各种特性和功能。我们将学习如何在 Python 中创建、操作和使用数组&#xff0c;以及数组与 Python 编程语言中的其他数据结构有何不同。我们的目标是提供有关…

资源文件、布局管理器、样式表拓展

QT 资源文件 提供了和本地路径无关的资源管理。 图片资源的获取&#xff1a;阿里巴巴矢量图库&#xff08;&#x1f448; 安全链接&#xff0c;放心跳转&#xff09; widget.ui .qrc widget.h #ifndef WIDGET_H #define WIDGET_H#include <QtWidgets>namespace Ui { c…