昇思大模型平台打卡体验活动:项目4基于MindSpore实现Roberta模型Prompt Tuning

基于MindNLP的Roberta模型Prompt Tuning

本文档介绍了如何基于MindNLP进行Roberta模型的Prompt Tuning,主要用于GLUE基准数据集的微调。本文提供了完整的代码示例以及详细的步骤说明,便于理解和复现实验。

环境配置

在运行此代码前,请确保MindNLP库已经安装。本文档基于大模型平台运行,因此需要进行适当的环境配置,确保代码可以在相应的平台上运行。

模型与数据集加载

在本案例中,我们使用 roberta-large 模型并基于GLUE基准数据集进行Prompt Tuning。GLUE (General Language Understanding Evaluation) 是自然语言处理中的标准评估基准,包括多个子任务,如句子相似性匹配、自然语言推理等。Prompt Tuning是一种新的微调技术,通过插入虚拟的“提示”Token在模型的输入中,以微调较少的参数达到较好的性能。

import mindspore
from tqdm import tqdm
from mindnlp import evaluate
from mindnlp.dataset import load_dataset
from mindnlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from mindnlp.core.optim import AdamW
from mindnlp.transformers.optimization import get_linear_schedule_with_warmup
from mindnlp.peft import (get_peft_model,PeftType,PromptTuningConfig,
)

1. 定义训练参数

首先,定义模型名称、数据集任务名称、Prompt Tuning类型、训练轮数等基本参数。

batch_size = 32
model_name_or_path = "roberta-large"
task = "mrpc"
peft_type = PeftType.PROMPT_TUNING
num_epochs = 20

2. 配置Prompt Tuning

在Prompt Tuning的配置中,选择任务类型为"SEQ_CLS"(序列分类任务),并定义虚拟Token的数量。虚拟Token即为插入模型输入中的“提示”Token,通过这些Token的微调,使得模型能够更好地完成下游任务。

peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
lr = 1e-3

3. 加载Tokenizer

根据模型类型选择padding的侧边,如果模型为GPT、OPT或BLOOM类模型,则从序列左侧填充(padding),否则从序列右侧填充。

if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):padding_side = "left"
else:padding_side = "right"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if getattr(tokenizer, "pad_token_id") is None:tokenizer.pad_token_id = tokenizer.eos_token_id

4. 加载数据集

通过MindNLP加载GLUE数据集,并打印样本以便确认数据格式。在此示例中,我们使用GLUE的MRPC(Microsoft Research Paraphrase Corpus)任务,该任务用于句子匹配,即判断两个句子是否表达相同的意思。

datasets = load_dataset("glue", task)
print(next(datasets['train'].create_dict_iterator()))

5. 数据预处理

为了适配MindNLP的数据处理流程,我们定义了一个映射函数 MapFunc,用于将句子转换为 input_idsattention_mask,并对数据进行padding处理。

from mindnlp.dataset import BaseMapFunctionclass MapFunc(BaseMapFunction):def __call__(self, sentence1, sentence2, label, idx):outputs = tokenizer(sentence1, sentence2, truncation=True, max_length=None)return outputs['input_ids'], outputs['attention_mask'], labeldef get_dataset(dataset, tokenizer):input_colums=['sentence1', 'sentence2', 'label', 'idx']output_columns=['input_ids', 'attention_mask', 'labels']dataset = dataset.map(MapFunc(input_colums, output_columns),input_colums, output_columns)dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return datasettrain_dataset = get_dataset(datasets['train'], tokenizer)
eval_dataset = get_dataset(datasets['validation'], tokenizer)

6. 设置评估指标

我们使用 evaluate 模块加载评估指标(accuracy 和 F1-score)来评估模型的性能。

metric = evaluate.load("./glue.py", task)

7. 加载模型并配置Prompt Tuning

加载 roberta-large 模型,并根据配置进行Prompt Tuning。可以看到,微调的参数量仅为总参数量的0.3%左右,节省了大量计算资源。

model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

模型微调(Prompt Tuning)

在Prompt Tuning中,训练过程中仅微调部分参数(主要是虚拟Token相关的参数),相比于传统微调而言,大大减少了需要调整的参数量,使得模型能够高效适应下游任务。

1. 优化器与学习率调整

使用 AdamW 优化器,并设置线性学习率调整策略。

optimizer = AdamW(params=model.parameters(), lr=lr)# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=0.06 * (len(train_dataset) * num_epochs),num_training_steps=(len(train_dataset) * num_epochs),
)

2. 训练逻辑定义

训练步骤如下:

  1. 构建正向计算函数 forward_fn
  2. 定义梯度计算函数 grad_fn
  3. 定义每一步的训练逻辑 train_step
  4. 遍历数据集进行训练和评估,在每个 epoch 结束时,计算评估指标。
def forward_fn(**batch):outputs = model(**batch)loss = outputs.lossreturn lossgrad_fn = mindspore.value_and_grad(forward_fn, None, tuple(model.parameters()))def train_step(**batch):loss, grads = grad_fn(**batch)optimizer.step(grads)return lossfor epoch in range(num_epochs):model.set_train()train_total_size = train_dataset.get_dataset_size()for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):loss = train_step(**batch)lr_scheduler.step()model.set_train(False)eval_total_size = eval_dataset.get_dataset_size()for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):outputs = model(**batch)predictions = outputs.logits.argmax(axis=-1)predictions, references = predictions, batch["labels"]metric.add_batch(predictions=predictions,references=references,)eval_metric = metric.compute()print(f"epoch {epoch}:", eval_metric)

在每个 epoch 后,程序输出当前模型的评估指标(accuracy 和 F1-score)。从结果中可以看到,模型的准确率和 F1-score 会随着训练的进展逐渐提升。
7797b4532920b53cb41371e07cfa81c6.png
7797b4532920b53cb41371e07cfa81c6.png

总结

本案例通过Prompt Tuning技术,在Roberta模型上进行了微调以适应GLUE数据集任务。通过控制微调参数量,Prompt Tuning展示了较强的高效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/60119.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国药品注册审批数据库- 药品注册信息查询与审评进度查询方法

药品的注册、审评审批进度信息是医药研发相关人员每天都会关注的信息,为了保证药品注册申请受理及审评审批进度信息的公开透明,CDE药审中心提供药品不同注册分类序列及药品注册申请受理的审评审批进度信息查询服务。但因CDE官网的改版导致很大一部分人不…

FMC 扩展子卡6 路 422,8 组 LVDS,8 路 GPIO

FMC 扩展子卡6 路 422,8 组 LVDS,8 路 GPIO 卡是一款支持多路 LVCMOS 和 LVDS 信号互转的 FMC 扩展子板。它能支持 6 路 422 信号的输入 / 输出 ,8 组 LVDS 信号的输入 / 输出和 8 路 GPIO 信号的输入 / 输出。本产品基于一些逻辑转换芯片而设计,能实现差分信号转单…

项目管理-招标文书都有哪些文件且各自作用

招标文书是招标过程中使用的一系列文件,它们定义了招标的条件、规则、程序和要求。 以下是一些常见的招标文书及其作用: 1. 招标公告/招标邀请书: - 作用:公开告知潜在的投标者有关招标项目的相关信息,包括项目…

C++builder中的人工智能(21):Barabási–Albert model(BA)模型

在此之前,大多数网络被想当然的认为是随机的,因此连接度分布可以近似用泊松分布来表示,而巴拉巴西与其学生阿尔伯特、郑浩雄通过对万维网度分布测量的结果却显示万维网度分布服从幂律分布,存在枢纽节点(拥有大量链接的…

MyBatis3-获取参数值的方式、查询功能及特殊SQL执行

目录 准备工作 获取参数值的方式(重点) 查询功能 查询一个实体类对象 查询一个list集合 查询单个数据 查询一条数据为map集合 查询多条数据为map集合 特殊SQL执行 模糊查询 批量删除 动态设置表名 添加功能获取自增的主键 准备工作 模块My…

链表(Linkedlist)

序言 我们都了解链表是一种数据的存储结构,在Java使用中逻辑与c,c语言数据结构别无二致,但主要由于Java中不存在指针的说法,从而导致在实现过程中的代码不同,所以在学习的过程中我们无需过于担心,逻辑都是…

jupyter添加、删除、查看内核

以下操作均在pytorch环境中操作 1.检查是否有ipykernel python -m ipykernel --version2.没有就安装 python -m pip install ipykernel3.查看内核环境列表 jupyter kernelspec list4.添加内核 python -m ipykernel install --user --name 环境名称 --display-name "在…

【分布式事务】二、NET8分布式事务实践: DotNetCore.CAP 框架 、 消息队列(RabbitMQ)、 多类型数据库(MySql、MongoDB)

介绍 DotNetCore.CAP简称CAP, [CAP]是一个用来解决微服务或者分布式系统中分布式事务问题的一个开源项目解决方案, 同样可以用来作为 EventBus 使用,CAP 拥有自己的特色,它不要求使用者发送消息或者处理消息的时候实现或者继承任何接口,拥有非常高的灵活性。我们一直坚信…

利用pythonstudio写的PDF、图片批量水印生成器,可同时为不同读者生成多组水印

现在很多场合需要将PDF或图片加水印,本程序利用pythonstudio编写。 第一步 界面 其中: LstMask:列表框 PopupMenu:PmnMark LstFiles:列表框 PopupMenu:PmnFiles OdFiles:文件选择器 Filter:PDF文件(.PDF)|.PDF|图像文件(.JPG)|.JPG|图像文件(.png…

面试:TCP、UDP如何解决丢包问题

文章目录 一、TCP丢包原因、解决办法1.1 TCP为什么会丢包1.2 TCP传输协议如何解决丢包问题1.3 其他丢包情况(拓展)1.4 补充1.4.1 TCP端口号1.4.2 多个TCP请求的逻辑1.4.3 处理大量TCP连接请求的方法1.4.4 总结 二、UDP丢包2.1 UDP协议2.1.1 UDP简介2.1.2…

Python的函数(补充浅拷贝和深拷贝)

一、定义 函数的定义:实现【特定功能】的代码块。 形参:函数定义时的参数,没有实际意义 实参:函数调用/使用时的参数,有实际意义 函数的作用: 简化代码提高代码重用性便于维护和修改提高代码的可扩展性…

Spring Boot框架的知识分类技术解析

2 开发技术 2.1 VUE框架 Vue.js(读音 /vjuː/, 类似于 view) 是一套构建用户界面的渐进式框架。 Vue 只关注视图层, 采用自底向上增量开发的设计。 Vue 的目标是通过尽可能简单的 API 实现响应的数据绑定和组合的视图组件。 2.2 Mysql数据库 …

Hive详解

1 Hive基本概念 Hive是一个构建在Hadoop上的数据仓库框架。最初,Hive是由Facebook开发,后来移交由Apache软件基金会开发,并作为一个Apache开源项目。 Hive是基于Hadoop的一个数据仓库工具,可以将结构化的数据文件映射为一张数据…

llamaIndex和langchain对比及优劣对比

一. LangChain vs LlamaIndex: 基本描述 LlamaIndex在搜索和检索任务方面表现出色。它是一个强大的数据索引和查询工具,非常适合需要高级搜索的项目。LlamaIndex能够处理大型数据集,从而实现快速准确的信息检索。 LangChain是一个模块化和灵活的工具集框…

《重学Java设计模式》之 工厂方法模式

《重学Java设计模式》之 建造者模式 《重学Java设计模式》之 原型模式 《重学Java设计模式》之 单例模式 模拟发奖多种商品 工程结构 奖品发放接口 package com.yys.mes.design.factory.store;public interface ICommodity {/*** Author Sherry* Date 14:20 2024/11/6**/voi…

十六:Spring Boot依赖 (1)-- spring-boot-starter 依赖详解

1. 简介: spring-boot-starter 是 Spring Boot 项目中的基础启动器依赖,它为开发者提供了 Spring Boot 应用所需的核心功能和自动配置 spring-boot-starter 不是一个具体的功能模块,而是一个基础的启动器。 Spring Boot 提供了一系列的 sta…

leetcode203. Remove Linked List Elements

Given the head of a linked list and an integer val, remove all the nodes of the linked list that has Node.val val, and return the new head. Input: head [1,2,6,3,4,5,6], val 6 Output: [1,2,3,4,5] 递归法 通过递归的方法去删除节点 递归程序会先一路遍历来到节…

【C++笔记】string类的模拟实现

前言 各位读者朋友们大家好!上期我们讲解了string类的基础用法,这期我们来模拟实现一下string类。 目录 前言一. string类的构造函数1. 1 无参构造2.2 带参构造1.3 无参和带参构造结合1.4 拷贝构造1.5 c_str 二. string类的析构函数三. 字符串的遍历3.…

java中ArrayList的使用存储对象的易错点

ArrayList存储对象的易错点 上面这种写法是有逻辑问题的,因为只创建了一个Student对象,因此最后打印出来的结果是三个最后赋值的结果。 下面我们来形象看下存储关系 集合中存储的始终是第一个对象的地址,而每次输入新的名字和年龄&#xf…

Java NIO实现高性能HTTP代理

NIO采用多路复用IO模型,相比传统BIO(阻塞IO),通过轮询机制检测注册的Channel是否有事件发生,可以实现一个线程处理客户端的多个连接,极大提升了并发性能。 在5年前,本人出于对HTTP正向代理的好…