使用LORA微调RoBERTa

模型微调是指在一个已经训练好的模型的基础上,针对特定任务或者特定数据集进行再次训练以提高性能的过程。微调可以在使其适应特定任务时产生显着的结果。

RoBERTa(Robustly optimized BERT approach)是由Facebook AI提出的一种基于Transformer架构的预训练语言模型。它是对Google提出的BERT(Bidirectional Encoder Representations from Transformers)模型的改进和优化。

“Low-Rank Adaptation”(低秩自适应)是一种用于模型微调或迁移学习的技术。一般来说我们只是使用LORA来微调大语言模型,但是其实只要是使用了Transformers块的模型,LORA都可以进行微调,本文将介绍如何利用🤗PEFT库,使用LORA提高微调过程的效率。

LORA可以大大减少了可训练参数的数量,节省了训练时间、存储和计算成本,并且可以与其他模型自适应技术(如前缀调优)一起使用,以进一步增强模型。

但是,LORA会引入额外的超参数调优层(特定于LORA的秩、alpha等)。并且在某些情况下,性能不如完全微调的模型最优,这个需要根据不同的需求来进行测试。

首先我们安装需要的包:

 !pip install transformers datasets evaluate accelerate peft

数据预处理

 import torchfrom transformers import RobertaModel, RobertaTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPaddingfrom peft import LoraConfig, get_peft_modelfrom datasets import load_datasetpeft_model_name = 'roberta-base-peft'modified_base = 'roberta-base-modified'base_model = 'roberta-base'dataset = load_dataset('ag_news')tokenizer = RobertaTokenizer.from_pretrained(base_model)def preprocess(examples):tokenized = tokenizer(examples['text'], truncation=True, padding=True)return tokenizedtokenized_dataset = dataset.map(preprocess, batched=True,  remove_columns=["text"])train_dataset=tokenized_dataset['train']eval_dataset=tokenized_dataset['test'].shard(num_shards=2, index=0)test_dataset=tokenized_dataset['test'].shard(num_shards=2, index=1)# Extract the number of classess and their namesnum_labels = dataset['train'].features['label'].num_classesclass_names = dataset["train"].features["label"].namesprint(f"number of labels: {num_labels}")print(f"the labels: {class_names}")# Create an id2label mapping# We will need this for our classifier.id2label = {i: label for i, label in enumerate(class_names)}data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

训练

我们训练两个模型,一个使用LORA,另一个使用完整的微调流程。这里可以看到LORA的训练时间和训练参数的数量能减少多少

以下是使用完整微调

 training_args = TrainingArguments(output_dir='./results',evaluation_strategy='steps',learning_rate=5e-5,num_train_epochs=1,per_device_train_batch_size=16,)

然后进行训练:

 def get_trainer(model):return  Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,data_collator=data_collator,)full_finetuning_trainer = get_trainer(AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label),)full_finetuning_trainer.train()

下面看看PEFT的LORA

 model = AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label)peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)peft_model = get_peft_model(model, peft_config)print('PEFT Model')peft_model.print_trainable_parameters()peft_lora_finetuning_trainer = get_trainer(peft_model)peft_lora_finetuning_trainer.train()peft_lora_finetuning_trainer.evaluate()

可以看到

模型参数总计:125,537,288,而LORA模型的训练参数为:888,580,我们只需要用LORA训练~0.70%的参数!这会大大减少内存的占用和训练时间。

在训练完成后,我们保存模型:

 tokenizer.save_pretrained(modified_base)peft_model.save_pretrained(peft_model_name)

最后测试我们的模型

 from peft import AutoPeftModelForSequenceClassificationfrom transformers import AutoTokenizer# LOAD the Saved PEFT modelinference_model = AutoPeftModelForSequenceClassification.from_pretrained(peft_model_name, id2label=id2label)tokenizer = AutoTokenizer.from_pretrained(modified_base)def classify(text):inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")output = inference_model(**inputs)prediction = output.logits.argmax(dim=-1).item()print(f'\n Class: {prediction}, Label: {id2label[prediction]}, Text: {text}')# return id2label[prediction]classify( "Kederis proclaims innocence Olympic champion Kostas Kederis today left hospital ahead of his date with IOC inquisitors claiming his ...")classify( "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.")

模型评估

我们还需要对PEFT模型的性能与完全微调的模型的性能进行对比,看看这种方式有没有性能的损失

 from torch.utils.data import DataLoaderimport evaluatefrom tqdm import tqdmmetric = evaluate.load('accuracy')def evaluate_model(inference_model, dataset):eval_dataloader = DataLoader(dataset.rename_column("label", "labels"), batch_size=8, collate_fn=data_collator)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")inference_model.to(device)inference_model.eval()for step, batch in enumerate(tqdm(eval_dataloader)):batch.to(device)with torch.no_grad():outputs = inference_model(**batch)predictions = outputs.logits.argmax(dim=-1)predictions, references = predictions, batch["labels"]metric.add_batch(predictions=predictions,references=references,)eval_metric = metric.compute()print(eval_metric)

首先是没有进行微调的模型,也就是原始模型

 evaluate_model(AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label), test_dataset)

accuracy: 0.24868421052631579‘

下面是LORA微调模型

 evaluate_model(inference_model, test_dataset)

accuracy: 0.9278947368421052

最后是完全微调的模型:

 evaluate_model(full_finetuning_trainer.model, test_dataset)

accuracy: 0.9460526315789474

总结

我们使用PEFT对RoBERTa模型进行了微调和评估,可以看到使用LORA进行微调可以大大减少训练的参数和时间,但是在准确性方面还是要比完整的微调要稍稍下降。

本文代码:

https://avoid.overfit.cn/post/26e401b70f9840dab185a6a83aac06b0

作者:Achilles Moraites

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/682499.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python算法之 Dijkstra 算法

文章目录 基本思想:步骤:复杂度:注意事项:代码实现K 站中转内最便宜的航班 Dijkstra 算法是一种用于解决单源最短路径问题的经典算法。该问题的目标是找到从图中的一个固定顶点(称为源点)到图中所有其他顶点…

Linux命令速查表:简洁高效,一表掌握

Linux是一套免费使用和自由传播的类Unix操作系统,是一个基于POSIX和Unix的多用户、多任务、支持多线程和多CPU的操作系统。它能运行主要的Unix工具软件、应用程序和网络协议。它支持32位和64位硬件。Linux继承了Unix以网络为核心的设计思想,是一个性能稳…

四、OpenAI之文本生成模型(Text Generation)

文本生成模型 OpenAI的文本生成模型(也叫做生成预训练的转换器(Generative pre-trained transformers)或大语言模型)已经被训练成可以理解自然语言、代码和图片的模型。模型提供文本的输出作为输入的响应。对这些模型的输入内容也被称作“提示词”。设计提示词的本质是你如何对…

ELAdmin 配置定时任务

定义方法 在自己的 Module 中写个要执行的方法。 比如获取微信公众号的 accessToken,每两个小时更新一次。这种的其实使用 Spring 的 Scheduled 更方便些,此处仅为演示。 package me.zhengjie.mp.task;import com.alibaba.fastjson.JSON; import lombo…

java的面向对象编程(oop)——认识泛型

前言&#xff1a; 打好基础&#xff0c;daydayup! 泛型 1&#xff0c;认识泛型&#xff1a; 定义类&#xff0c;接口&#xff0c;方法时&#xff0c;同时声明了一个或多个类型变量&#xff08;例&#xff1a;<E>&#xff09;,称为泛型&#xff0c;泛型接口&#xff0c;泛…

leetcode算法-位运算

位运算&#xff0c;直接在二进制上进行的按位操作&#xff0c;位运算的种类如下&#xff1a; 1.按位异或^:异或的含义是操作的两位不同&#xff0c;则结果为1&#xff0c;相同则结果为0&#xff0c;所以两个相同的数异或&#xff0c;结果应该是0&#xff0c;3^3的结果是0,3^4的…

springboot743二手交易平台

springboot743二手交易平台 获取源码——》公主号&#xff1a;计算机专业毕设大全

电子元器件基础2---电容

两个相互靠近的导体&#xff0c;中间夹一层不导电的绝缘介质&#xff0c;这就构成了电容器。当电容器的两个极板之间加上电压时&#xff0c;电容器就会储存电荷。电容器的电容量在数值上等于一个导电极板上的电荷量与两个极板之间的电压之比。电容器的电容量的基本单位是法拉(F…

C语言学习day12:for循环

前面学了dowhile循环&#xff0c;今天我们来学习经常用到的for循环&#xff1a; for循环&#xff1a; 例子&#xff1a; int main() {//int i;for (int i 0; i < 10;i) {printf("%d\n",i);};system("pause");return EXIT_SUCCESS; } 解释&#xff…

SNMP 简单网络管理协议、网络管理

目录 1 网络管理 1.1 网络管理的五大功能 1.2 网络管理的一般模型 1.3 网络管理模型中的主要构件 1.4 被管对象 (Managed Object) 1.5 代理 (agent) 1.6 网络管理协议 1.6.1 简单网络管理协议 SNMP 1.6.2 SNMP 的指导思想 1.6.3 SNMP 的管理站和委托代理 1.6.4 SNMP…

Yann LeCun 小传

以下内容整理自 《科学之路》。 Yann LeCun&#xff0c;中文译名杨立昆&#xff0c;1960年出生于法国巴黎附近。他的父亲是一位航空工程师&#xff0c;业余时间喜欢做一些电子产品。 1968年&#xff0c;8岁的杨立昆看了电影《2001太空漫游》。 1978年&#xff0c;高中毕业&a…

AcWing 122 糖果传递(贪心)

[题目概述] 有 n 个小朋友坐成一圈&#xff0c;每人有 a[i] 个糖果。 每人只能给左右两人传递糖果。 每人每次传递一个糖果代价为 1。 求使所有人获得均等糖果的最小代价。 输入格式 第一行输入一个正整数 n&#xff0c;表示小朋友的个数。 接下来 n 行&#xff0c;每行一个…

2.第一个Electron程序

目录 一、前言二、基本运行结构三、代码详解四、打包 一、前言 原文以及系列文章后续请参考&#xff1a;第一个Electron程序 上一章我们完成了Electron的环境搭建&#xff0c;本章就开始详解如何使用Electron开发一个完整的Electron桌面端程序。 注意开发环境&#xff0c;个…

maven仓库的加载步骤

仓库加载步骤 Maven 在判断资源来自哪个仓库时&#xff0c;是根据 Maven 项目的配置以及 Maven 的工作机制来进行判断的。以下是 Maven 判断资源仓库的一般步骤&#xff1a; 从 Maven 项目的 pom.xml 文件中读取配置信息&#xff0c;包括依赖项、仓库设置以及镜像设置。…

【王道数据结构】【chapter5树与二叉树】【P158t6】

二叉树按二叉链表形式存储&#xff0c;试编写一个判别二叉树是否是完全二叉树的算法 #include <iostream> #include <queue> typedef struct treenode{char data;struct treenode *left;struct treenode *right; }treenode,*ptreenode;ptreenode buytreenode(char …

2024.2.14

二维数组实现杨辉三角形 #include<stdio.h> #include<string.h> int main(int argc, const char *argv[]) {int n;scanf("%d",&n);int a[n][n];for(int i0;i<n;i){for(int j0;j<i;j){if(j0||ij){ a[i][j]1;}else{a[i][j]a[i-1][j]a[i-1][j-…

“反内卷”代码书写原则

有人相爱&#xff0c;有人夜里开车看海。有人看着这些代码一句话也说不出来。这是一个你的项目应该遵循的垃圾代码书写准则&#xff0c;只有这样写了才能让人看不懂&#xff0c;这才是真正的反内卷之道&#xff0c;请恪守以下原则&#xff0c;时刻铭记&#xff0c;切勿反向操作…

docker 2:安装

docker 2&#xff1a;安装 ‍ ubuntu 安装 docker sudo apt install docker.io‍ 把当前用户放进 docker 用户组&#xff0c;避免每次运行 docker 命都要使用 sudo​ 或者 root​ 权限。 sudo usermod -aG docker $USER​id $USER ​看到用户已加入 docker 组 ​​ ‍ …

CSS介绍

本章目标&#xff1a; CSS概述 三种样式表 简单选择器 复合选择器 盒子模型 常用背景样式 浮动 常用文本样式 伪类样式 列表样式 表格样式 定位 一、CSS概述: CSS&#xff1a;cascading style sheets-层叠样式表 专门负责对网页的美化 二、有三种使用方式&…

SpringBoot与虚拟线程,接口吞吐量成倍增加,太爽了!

我们看一下如何在spring-boot中利用loom虚拟线程。 我们将做一些负载测试&#xff0c;看看虚拟线程和普通线程的响应时间如何。 让我们快速设置我们的 Spring Boot 项目。 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http:/…