使用LORA微调RoBERTa

模型微调是指在一个已经训练好的模型的基础上,针对特定任务或者特定数据集进行再次训练以提高性能的过程。微调可以在使其适应特定任务时产生显着的结果。

RoBERTa(Robustly optimized BERT approach)是由Facebook AI提出的一种基于Transformer架构的预训练语言模型。它是对Google提出的BERT(Bidirectional Encoder Representations from Transformers)模型的改进和优化。

“Low-Rank Adaptation”(低秩自适应)是一种用于模型微调或迁移学习的技术。一般来说我们只是使用LORA来微调大语言模型,但是其实只要是使用了Transformers块的模型,LORA都可以进行微调,本文将介绍如何利用🤗PEFT库,使用LORA提高微调过程的效率。

LORA可以大大减少了可训练参数的数量,节省了训练时间、存储和计算成本,并且可以与其他模型自适应技术(如前缀调优)一起使用,以进一步增强模型。

但是,LORA会引入额外的超参数调优层(特定于LORA的秩、alpha等)。并且在某些情况下,性能不如完全微调的模型最优,这个需要根据不同的需求来进行测试。

首先我们安装需要的包:

 !pip install transformers datasets evaluate accelerate peft

数据预处理

 import torchfrom transformers import RobertaModel, RobertaTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPaddingfrom peft import LoraConfig, get_peft_modelfrom datasets import load_datasetpeft_model_name = 'roberta-base-peft'modified_base = 'roberta-base-modified'base_model = 'roberta-base'dataset = load_dataset('ag_news')tokenizer = RobertaTokenizer.from_pretrained(base_model)def preprocess(examples):tokenized = tokenizer(examples['text'], truncation=True, padding=True)return tokenizedtokenized_dataset = dataset.map(preprocess, batched=True,  remove_columns=["text"])train_dataset=tokenized_dataset['train']eval_dataset=tokenized_dataset['test'].shard(num_shards=2, index=0)test_dataset=tokenized_dataset['test'].shard(num_shards=2, index=1)# Extract the number of classess and their namesnum_labels = dataset['train'].features['label'].num_classesclass_names = dataset["train"].features["label"].namesprint(f"number of labels: {num_labels}")print(f"the labels: {class_names}")# Create an id2label mapping# We will need this for our classifier.id2label = {i: label for i, label in enumerate(class_names)}data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

训练

我们训练两个模型,一个使用LORA,另一个使用完整的微调流程。这里可以看到LORA的训练时间和训练参数的数量能减少多少

以下是使用完整微调

 training_args = TrainingArguments(output_dir='./results',evaluation_strategy='steps',learning_rate=5e-5,num_train_epochs=1,per_device_train_batch_size=16,)

然后进行训练:

 def get_trainer(model):return  Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,data_collator=data_collator,)full_finetuning_trainer = get_trainer(AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label),)full_finetuning_trainer.train()

下面看看PEFT的LORA

 model = AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label)peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)peft_model = get_peft_model(model, peft_config)print('PEFT Model')peft_model.print_trainable_parameters()peft_lora_finetuning_trainer = get_trainer(peft_model)peft_lora_finetuning_trainer.train()peft_lora_finetuning_trainer.evaluate()

可以看到

模型参数总计:125,537,288,而LORA模型的训练参数为:888,580,我们只需要用LORA训练~0.70%的参数!这会大大减少内存的占用和训练时间。

在训练完成后,我们保存模型:

 tokenizer.save_pretrained(modified_base)peft_model.save_pretrained(peft_model_name)

最后测试我们的模型

 from peft import AutoPeftModelForSequenceClassificationfrom transformers import AutoTokenizer# LOAD the Saved PEFT modelinference_model = AutoPeftModelForSequenceClassification.from_pretrained(peft_model_name, id2label=id2label)tokenizer = AutoTokenizer.from_pretrained(modified_base)def classify(text):inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")output = inference_model(**inputs)prediction = output.logits.argmax(dim=-1).item()print(f'\n Class: {prediction}, Label: {id2label[prediction]}, Text: {text}')# return id2label[prediction]classify( "Kederis proclaims innocence Olympic champion Kostas Kederis today left hospital ahead of his date with IOC inquisitors claiming his ...")classify( "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.")

模型评估

我们还需要对PEFT模型的性能与完全微调的模型的性能进行对比,看看这种方式有没有性能的损失

 from torch.utils.data import DataLoaderimport evaluatefrom tqdm import tqdmmetric = evaluate.load('accuracy')def evaluate_model(inference_model, dataset):eval_dataloader = DataLoader(dataset.rename_column("label", "labels"), batch_size=8, collate_fn=data_collator)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")inference_model.to(device)inference_model.eval()for step, batch in enumerate(tqdm(eval_dataloader)):batch.to(device)with torch.no_grad():outputs = inference_model(**batch)predictions = outputs.logits.argmax(dim=-1)predictions, references = predictions, batch["labels"]metric.add_batch(predictions=predictions,references=references,)eval_metric = metric.compute()print(eval_metric)

首先是没有进行微调的模型,也就是原始模型

 evaluate_model(AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label), test_dataset)

accuracy: 0.24868421052631579‘

下面是LORA微调模型

 evaluate_model(inference_model, test_dataset)

accuracy: 0.9278947368421052

最后是完全微调的模型:

 evaluate_model(full_finetuning_trainer.model, test_dataset)

accuracy: 0.9460526315789474

总结

我们使用PEFT对RoBERTa模型进行了微调和评估,可以看到使用LORA进行微调可以大大减少训练的参数和时间,但是在准确性方面还是要比完整的微调要稍稍下降。

本文代码:

https://avoid.overfit.cn/post/26e401b70f9840dab185a6a83aac06b0

作者:Achilles Moraites

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/682499.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python算法之 Dijkstra 算法

文章目录 基本思想:步骤:复杂度:注意事项:代码实现K 站中转内最便宜的航班 Dijkstra 算法是一种用于解决单源最短路径问题的经典算法。该问题的目标是找到从图中的一个固定顶点(称为源点)到图中所有其他顶点…

四、OpenAI之文本生成模型(Text Generation)

文本生成模型 OpenAI的文本生成模型(也叫做生成预训练的转换器(Generative pre-trained transformers)或大语言模型)已经被训练成可以理解自然语言、代码和图片的模型。模型提供文本的输出作为输入的响应。对这些模型的输入内容也被称作“提示词”。设计提示词的本质是你如何对…

ELAdmin 配置定时任务

定义方法 在自己的 Module 中写个要执行的方法。 比如获取微信公众号的 accessToken,每两个小时更新一次。这种的其实使用 Spring 的 Scheduled 更方便些,此处仅为演示。 package me.zhengjie.mp.task;import com.alibaba.fastjson.JSON; import lombo…

java的面向对象编程(oop)——认识泛型

前言&#xff1a; 打好基础&#xff0c;daydayup! 泛型 1&#xff0c;认识泛型&#xff1a; 定义类&#xff0c;接口&#xff0c;方法时&#xff0c;同时声明了一个或多个类型变量&#xff08;例&#xff1a;<E>&#xff09;,称为泛型&#xff0c;泛型接口&#xff0c;泛…

springboot743二手交易平台

springboot743二手交易平台 获取源码——》公主号&#xff1a;计算机专业毕设大全

电子元器件基础2---电容

两个相互靠近的导体&#xff0c;中间夹一层不导电的绝缘介质&#xff0c;这就构成了电容器。当电容器的两个极板之间加上电压时&#xff0c;电容器就会储存电荷。电容器的电容量在数值上等于一个导电极板上的电荷量与两个极板之间的电压之比。电容器的电容量的基本单位是法拉(F…

C语言学习day12:for循环

前面学了dowhile循环&#xff0c;今天我们来学习经常用到的for循环&#xff1a; for循环&#xff1a; 例子&#xff1a; int main() {//int i;for (int i 0; i < 10;i) {printf("%d\n",i);};system("pause");return EXIT_SUCCESS; } 解释&#xff…

SNMP 简单网络管理协议、网络管理

目录 1 网络管理 1.1 网络管理的五大功能 1.2 网络管理的一般模型 1.3 网络管理模型中的主要构件 1.4 被管对象 (Managed Object) 1.5 代理 (agent) 1.6 网络管理协议 1.6.1 简单网络管理协议 SNMP 1.6.2 SNMP 的指导思想 1.6.3 SNMP 的管理站和委托代理 1.6.4 SNMP…

AcWing 122 糖果传递(贪心)

[题目概述] 有 n 个小朋友坐成一圈&#xff0c;每人有 a[i] 个糖果。 每人只能给左右两人传递糖果。 每人每次传递一个糖果代价为 1。 求使所有人获得均等糖果的最小代价。 输入格式 第一行输入一个正整数 n&#xff0c;表示小朋友的个数。 接下来 n 行&#xff0c;每行一个…

2.第一个Electron程序

目录 一、前言二、基本运行结构三、代码详解四、打包 一、前言 原文以及系列文章后续请参考&#xff1a;第一个Electron程序 上一章我们完成了Electron的环境搭建&#xff0c;本章就开始详解如何使用Electron开发一个完整的Electron桌面端程序。 注意开发环境&#xff0c;个…

【王道数据结构】【chapter5树与二叉树】【P158t6】

二叉树按二叉链表形式存储&#xff0c;试编写一个判别二叉树是否是完全二叉树的算法 #include <iostream> #include <queue> typedef struct treenode{char data;struct treenode *left;struct treenode *right; }treenode,*ptreenode;ptreenode buytreenode(char …

2024.2.14

二维数组实现杨辉三角形 #include<stdio.h> #include<string.h> int main(int argc, const char *argv[]) {int n;scanf("%d",&n);int a[n][n];for(int i0;i<n;i){for(int j0;j<i;j){if(j0||ij){ a[i][j]1;}else{a[i][j]a[i-1][j]a[i-1][j-…

docker 2:安装

docker 2&#xff1a;安装 ‍ ubuntu 安装 docker sudo apt install docker.io‍ 把当前用户放进 docker 用户组&#xff0c;避免每次运行 docker 命都要使用 sudo​ 或者 root​ 权限。 sudo usermod -aG docker $USER​id $USER ​看到用户已加入 docker 组 ​​ ‍ …

CSS介绍

本章目标&#xff1a; CSS概述 三种样式表 简单选择器 复合选择器 盒子模型 常用背景样式 浮动 常用文本样式 伪类样式 列表样式 表格样式 定位 一、CSS概述: CSS&#xff1a;cascading style sheets-层叠样式表 专门负责对网页的美化 二、有三种使用方式&…

SpringBoot与虚拟线程,接口吞吐量成倍增加,太爽了!

我们看一下如何在spring-boot中利用loom虚拟线程。 我们将做一些负载测试&#xff0c;看看虚拟线程和普通线程的响应时间如何。 让我们快速设置我们的 Spring Boot 项目。 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http:/…

【开源图床】使用Typora+PicGo+Gitee搭建个人博客图床

准备工作&#xff1a; 首先电脑得提前完成安装如下&#xff1a; 1. nodejs环境(node ,npm):【安装指南】nodejs下载、安装与配置详细教程 2. Picgo:【安装指南】图床神器之Picgo下载、安装与配置详细教程 3. Typora:【安装指南】markdown神器之Typora下载、安装与无限使用详细教…

JavaScript中如何判断数据类型

在JavaScript中&#xff0c;判断数据类型是我们在日常开发中经常会遇到的问题。正确地判断数据类型不仅有助于我们编写出更加健壮的代码&#xff0c;还可以提高程序的可读性和可维护性。本文将为大家介绍几种判断数据类型的方法 使用typeof运算符&#xff1a; typeof运算符可以…

[NSSRound#17 Basic]WEB

1.真签到 看robots.txt 密码先base32再base64得到md5加密的密文&#xff0c;在线解得到密码为Nss hint用16进制转字符串&#xff0c;提示新生赛遇到过 是一个敲击码加密 账号是ctfer,登录之后源码提示在F111n4l.php 要求nss参数若比较等于732339662&#xff0c;但是不能是数…

关于idea无法检测出lombok,导致代码爆红的处理

为啥需要本地安装lombok插件&#xff1f; 编译错误提示&#xff1a;Lombok 使用注解来自动生成代码&#xff0c;这些代码在编译时会由 Lombok 插件进行处理。如果没有安装 Lombok 插件&#xff0c;IDEA 在编译过程中可能会报告错误&#xff0c;因为它无法识别并处理 Lombok 注解…

【RL】Bellman Optimality Equation(贝尔曼最优等式)

Lecture3: Optimal Policy and Bellman Optimality Equation Definition of optimal policy state value可以被用来去评估policy的好坏&#xff0c;如果&#xff1a; v π 1 ( s ) ≥ v π 2 ( s ) for all s ∈ S v_{\pi_1}(s) \ge v_{\pi_2}(s) \;\;\;\;\; \text{for all…