深度学习:GPT-1的MindSpore实践

GPT-1简介

GPT-1(Generative Pre-trained Transformer)是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构,具有如下创新点:

  • NLP领域的迁移学习:通过最少的任务专项数据,利用预训练模型出色地完成具体的下游任务。
  • 语言建模作为预训练任务:使用无监督学习和大规模的文本语料库来训练模型
  • 为具体任务微调:采用预训练模型来适应监督任务

和BERT类似,GPT-1同样采取pre-train + fine-tune的思路:先基于大量未标注语料数据进行预训练, 后基于少量标注数据进行微调。但GPT-1在预训练任务思路和模型结构上与BERT有所差别。

GPT-1的目标是在预训练的过程中根据现有的所有词元,预测下一个词元。这个任务被称为“自回归语言建模”。

一个简单的例子:

输入序列为:“The sun rises in the”

训练数据的原句子为:“The sun rises in the east”

所以我们的目标输出为:“east”

将输入序列输入GPT模型,GPT根据输入预测下一个词元(“east”)在语料库中的概率分布

正确词元“east”作为一个“伪标签”来帮助模型训练

模型架构

GPT主要使用Transformer Decoder架构,但因为没有Encoder,所以在Transformer Decoder的基础上移除了计算Encoder与Decoder间注意力分数的Multi-Head Attention Layer。

Masked Multi-HeadSelf-Attention

Masked Multi-Head Self-Attention 是Multi-Head Attetion的变种。 最大的不同来自于MMSA的掩码机制,掩码机制防止模型通过观测未来的词元以进行“作弊”。

一个掩码词元<mask>被用于注意力分数矩阵,所以当前词元只能注意到序列中自己和自己之前的词元。未来的次元的注意力分数将被设为0以确保其在Softmax步骤后的实际贡献为0。

为什么掩码机制非常重要?

对于自回归任务,模型必须线性地生成词元,不能基于未来的信息预测下一个词元。

损失函数

GPT使用Cross-Entropy Loss作为损失函数:\mathcal{L} = - \sum_{t=1}^N \log P(w_t | w_1, w_2, \dots, w_{t-1})

交叉熵损失是这项任务的理想选择,因为它通过测量预测的概率分布与真实分布的距离来惩罚不正确的预测。它自然适于处理多类分类任务,其中模型从大量词汇表中选择一个标记。

模型输入

GPT-1的输入同样为句子或句子对,并添加Special Tokens。

  • [BOS]:表示句子的开始,(论文中给出的token表示为[START]),添加到序列最前;
  • [EOS]:表示序列的结束,(论文中的给出的[EXTRACT]),添加到序列最后,在进行分类任务时,会将 该special token对应的输出接入输出层;我们也可以理解为该token可以学习到整个句子的语义信息;
  • [SEP]:用于间隔句子对中的两个句子;
GPT Embedding 同样分为三类:token Embedding、Position Embedding、Segment Embedding

 

GPT-1模型具体参数

模型架构

  • 12个Transformer Decoder Block
  • hidden_size为768(模型输入和输出的向量纬度)
  • 注意力头数为12
  • FFN维度为3072
  • 词表(Vocab)大小为40000
  • 序列长度为512(上下文窗口长度)

训练过程

  • Adam优化器,超参数为:0.9, 0.99
  • 学习率:最大学习率:2.5x10e-4 使用2000步作为热身,随后线性衰退
  • 批大小:64
  • 梯度剪裁:1.0
  • Dropout率:0.1

训练过程

100000步,大约花费8张NVIDIA V100 GPU训练30天,共有117M参数。使用Xavier初始化,权重衰退为0.01。 

下游任务 

GPT按照生成式的逻辑统一了下游任务的应用模板,使用最后一个token([EOS]or[EXTRACT])对应的hidden state,输出到额外的输出层中,进行分类标签预测。
任务包括:文本分类(情感分类、新闻分类)、文本蕴含(根据前提推出假设)、文本语义相似度、多类选择(在多个next token中进行选择)

基于MindSpore微调GPT-1进行情感分类

# #安装mindnlp 0.4.0套件
# !pip install mindnlp
# !pip uninstall soundfile -y
# !pip install download
# !pip install jieba
# !pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.3.1/MindSpore/unified/aarch64/mindspore-2.3.1-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simpleimport osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp.engine import Trainer# loading dataset
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']imdb_train.get_dataset_size()import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:dataset = dataset.shuffle(batch_size)# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return datasetfrom mindnlp.transformers import OpenAIGPTTokenizer
# tokenizer
gpt_tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)#为方便体验流程,把原本数据集的十分之一拿出来体验训练和评估,
imdb_train, _ = imdb_train.split([0.1, 0.9], randomize=False)# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)# load GPT sequence classification model and set class=2
from mindnlp.transformers import OpenAIGPTForSequenceClassification  # Import the GPT model for sequence classification
from mindnlp import evaluate  # Import the evaluation module from MindNLP
import numpy as np  # Import NumPy for numerical operations# Set up the GPT model for sequence classification with 2 output labels (binary classification).
model = OpenAIGPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)# Set the padding token ID in the model configuration to match the tokenizer's padding token ID.
model.config.pad_token_id = gpt_tokenizer.pad_token_id# Resize the token embedding layer to account for any added tokens (e.g., special tokens).
model.resize_token_embeddings(model.config.vocab_size + 3)from mindnlp.engine import TrainingArguments  # Import training arguments for model training configuration.# Define training arguments.
training_args = TrainingArguments(output_dir="gpt_imdb_finetune",  # Directory to save model checkpoints and outputs.evaluation_strategy="epoch",  # Evaluate the model at the end of each epoch.save_strategy="epoch",  # Save model checkpoints at the end of each epoch.logging_strategy="epoch",  # Log metrics and progress at the end of each epoch.load_best_model_at_end=True,  # Automatically load the best model (based on evaluation metrics) at the end of training.num_train_epochs=1.0,  # Number of training epochs (default is 1 for quick experimentation).learning_rate=2e-5  # Learning rate for the optimizer.
)# Load the accuracy metric for evaluation.
metric = evaluate.load("accuracy")# Define a function to compute metrics during evaluation.
def compute_metrics(eval_pred):logits, labels = eval_pred  # Unpack predictions (logits) and true labels.predictions = np.argmax(logits, axis=-1)  # Convert logits to class predictions using argmax.return metric.compute(predictions=predictions, references=labels)  # Compute accuracy metric.# Initialize the Trainer class with the model, training arguments, datasets, and metric computation function.
trainer = Trainer(model=model,  # The GPT model to be fine-tuned.args=training_args,  # Training configuration arguments.train_dataset=dataset_train,  # Training dataset (must be preprocessed and tokenized).eval_dataset=dataset_val,  # Validation dataset for evaluation.compute_metrics=compute_metrics  # Metric computation function for evaluation.
)# start training
trainer.train()trainer.evaluate(dataset_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/60477.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

day06(单片机高级)PCB设计

目录 PCB设计 PCB设计流程 元器件符号设计 原理图设计 元器件封装设计 元器件库使用 PCB设计 目的&#xff1a;学习从画原理图到PCB设计的整个流程 PCB设计流程 元器件符号设计 元器件符号&#xff1a;这是电子元器件的图形表示&#xff0c;用于在原理图中表示特定的元器件。例…

人工智能(AI)与机器学习(ML)基础知识

目录 1. 人工智能与机器学习的核心概念 什么是人工智能&#xff08;AI&#xff09;&#xff1f; 什么是机器学习&#xff08;ML&#xff09;&#xff1f; 什么是深度学习&#xff08;DL&#xff09;&#xff1f; 2. 机器学习的三大类型 &#xff08;1&#xff09;监督式学…

Java 调用 MULTIPART_FORM_DATA 接口

以 QAnthing 上传文件&#xff08;POST&#xff09;接口为例&#xff0c;展示Java如何调用上传文件接口。 接口文档如下&#xff1a; QAnthign接口文档地址 编码 RestTemplate 版 /** * * param url 接口地址 * param filePath 文件本地路径 */ public void uploadFile(S…

Vue3-小兔鲜项目出现问题及其解决方法(未写完)

基础操作 &#xff08;1&#xff09;使用create-vue搭建Vue3项目 要保证node -v 版本在16以上 &#xff08;2&#xff09;添加pinia到vue项目 npm init vuelatest npm i pinia //导入creatPiniaimport {createPinia} from pinia//执行方法得到实例const pinia createPinia()…

【Vue】 npm install amap-js-api-loader指南

前言 项目中的地图模块突然打不开了 正文 版本太低了&#xff0c;而且Vue项目就应该正经走项目流程啊喂&#xff01; npm i amap/amap-jsapi-loader --save 官方说这样执行完&#xff0c;就这结束啦&#xff01;它结束了&#xff0c;我还没有&#xff0c;不然不可能记录这篇文…

C#桌面应用制作计算器进阶版01

基于C#桌面应用制作计算器做出了少量改动&#xff0c;其主要改动为新增加了一个label控件&#xff0c;使其每一步运算结果由label2展示出来&#xff0c;而当点击“”时&#xff0c;最终运算结果将由label1展示出来&#xff0c;此时label清空。 修改后运行效果 修改后全篇代码 …

Linux下Intel编译器oneAPI安装和链接MKL库编译

参考: https://blog.csdn.net/qq_44263574/article/details/123582481 官网下载: https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html?packagesoneapi-toolkit&oneapi-toolkit-oslinux&oneapi-linoffline 填写邮件和国家,…

文件管理 IV(文件系统)

一、文件系统结构 文件系统&#xff08;File system&#xff09;提供高效和便捷的磁盘访问&#xff0c;以便允许存储、定位、提取数据。文件系统有两个不同的设计问题&#xff1a;第一个问题是&#xff0c;定义文件系统的用户接口&#xff0c;它涉及定义文件及其属性、所允许的…

基于ToLua的C#和Lua内存共享方案保姆级教程

C#和Lua内存共享方案保姆级教程 前言 在介绍C#和Lua内存共享方案之前,先介绍下面两个点来支撑这个方案的必要性 跨语言交互很费 Lua和C#交互最早是基于反射的方式实现的,后来为了提升性能发展成Luajit+C#静态方法导出注入到lua虚拟机的方式至此Lua+Unity的性能才达到了实…

详细描述一下Elasticsearch索引文档的过程?

大家好&#xff0c;我是锋哥。今天分享关于【详细描述一下Elasticsearch索引文档的过程&#xff1f;】面试题。希望对大家有帮助&#xff1b; 详细描述一下Elasticsearch索引文档的过程&#xff1f; Elasticsearch的索引文档过程是其核心功能之一&#xff0c;涉及将数据存储到…

SpringBoot学习记录(六)配置文件参数化

SpringBoot学习记录&#xff08;六&#xff09;配置文件参数化 一、参数提取到配置文件中二、yml配置文件三、ConfigurationProperties注解实现批量属性注入 一、参数提取到配置文件中 定义在代码中的参数的值分散在各个不同的文件中&#xff0c;不便于后期维护管理&#xff0…

# ubuntu 安装的pycharm不能输入中文的解决方法

ubuntu 安装的pycharm不能输入中文的解决方法 一、问题描述&#xff1a; 当在 ubuntu 系统中&#xff0c;安装了 pycharm&#xff08;如&#xff1a;pycharm2016, 或 pycharm2018&#xff09;,打开 pycharm 输入代码时&#xff0c;发现不能正常输入中文&#xff0c;安装的搜狗…

NLP论文速读(CVPR 2024)|使用DPO进行diffusion模型对齐

论文速读|Diffusion Model Alignment Using Direct Preference Optimization 论文信息&#xff1a; 简介&#xff1a; 本文探讨的背景是大型语言模型&#xff08;LLMs&#xff09;通过人类比较数据和从人类反馈中学习&#xff08;RLHF&#xff09;的方法进行微调&#xff0c;以…

<OS 有关> ubuntu 24 不同版本介绍 安装 Vmware tools

原因 想用 apt-get download 存到本地 / NAS上&#xff0c;减少网络流浪。 看到 VMware 上的确实有 ubuntu&#xff0c;只是版本是16。 ubuntu 版本比较&#xff1a;LTS vs RR LTS: Long-Term Support 长周期支持&#xff0c; 一般每 2 年更新&#xff0c;会更可靠与更稳定…

泛微E9与金蝶云星空的集成方案:实现审批流程与财务管理的无缝对接

泛微E9与金蝶云星空的集成方案&#xff1a;实现审批流程与财务管理的无缝对接 背景介绍&#xff1a; 在企业日常运营中&#xff0c;泛微OA-E9和金蝶云星空是两个关键的系统。泛微OA-E9是一款广受企业青睐的办公自动化软件&#xff0c;它通过流程管理、文档管理、协同办公等模…

Redis最终篇分布式锁以及数据一致性

在前三篇我们几乎说完了Redis的所有的基础知识以及Redis怎么实现高可用性,那么在这一篇文章中的话我们主要就是说明如果我们使用Redis出现什么问题以及解决方案是什么,这个如果在未来的工作中也有可能会遇到,希望对看这篇博客的人有帮助,话不多说直接开干 一.Hotkey以及BigKey…

GPT中转站技术架构

本文介绍阿波罗AI中转站&#xff08;https://api.ablai.top/&#xff09;的技术架构&#xff0c;该中转API的技术架构采用了分布式架构、智能调度和API中转等技术&#xff0c;确保了全球范围内的高效访问和稳定运行。以下是对该技术架构的详细分析&#xff1a; 分布式架构 分…

【强化学习的数学原理】第02课-贝尔曼公式-笔记

学习资料&#xff1a;bilibili 西湖大学赵世钰老师的【强化学习的数学原理】课程。链接&#xff1a;强化学习的数学原理 西湖大学 赵世钰 文章目录 一、为什么return重要&#xff1f;如何计算return&#xff1f;二、state value的定义三、Bellman公式的详细推导四、公式向量形式…

[less] Operation on an invalid type

我这个是升级项目的时候遇到的&#xff0c;要从 scss 升级到 less&#xff0c;然后代码中就报了这个错误 我说一下代码的错误过程&#xff0c;但是这里没有复现&#xff0c;因为我原本报错的代码要复杂很多&#xff0c;而且是公司代码&#xff0c;不方便透露&#xff0c;这是我…

ssm面向品牌会员的在线商城小程序

摘要 随着Internet的发展&#xff0c;人们的日常生活已经离不开网络。未来人们的生活与工作将变得越来越数字化&#xff0c;网络化和电子化。它将是直接管理面向品牌会员的在线商城小程序的最新形式。本小程序是以面向品牌会员的在线商城管理为目标&#xff0c;使用 java技术制…