Transformers参数高效微调之LoRA

简介

LoRA: Low-Rank Adaptation of Large Language Models是微软研究人员为处理微调大语言模型的问题而引入的一项新技术。具有数十亿个参数的强大模型(例如 GPT-3)为了适应特定任务或领域而进行微调的成本非常高。LoRA 建议冻结预先训练的模型权重并注入可训练层 (秩分解矩阵) 在每个变压器块中.这大大减少了可训练参数的数量和 GPU 内存要求,因为不需要为大多数模型权重计算梯度。研究人员发现,通过专注于大语言模型的Transformer注意力块,LoRA的微调质量与完整模型微调相当,同时速度更快需要的计算更少。

LoRA概念

比如你有一个装满乐高积木的大盒子。
这个大盒子里有各种积木,你可以用它们建造房子、汽车、太空飞船等等。
但因为这个盒子又大又重,搬运起来很麻烦。而且,大多数时候你并不需要用到所有积木来完成你的建造任务。
于是,你挑出一小盒你最喜欢、最常用的积木。这小盒子更轻便,便于携带,虽然不如大盒子功能全面,但它足够用来完成大多数你想建造的东西。
在这个比喻中,大盒子的乐高积木就像一个大型语言模型(如 GPT-4):
功能强大,可以处理很多任务,但也需要巨大的计算资源。
而小盒子的乐高积木就像 LoRA(低秩适应):
这是经过特定任务适配后,更小、更轻量化的模型版本。虽然功能不如完整模型全面,但更高效,也更容易使用。

什么是 LoRA(低秩适应)?
LoRA 的全称是 “Low-Rank Adaptation”(低秩适应)。
“低秩”(Low-Rank)是一个数学术语,用来描述生成这种更小、更轻量模型的方法。
你可以把“低秩”理解为只阅读一本书中标注的重点部分。
而“全秩”(Full-Rank)就像是把整本书从头到尾全部读完(全参训练)。

LoRA原理

在这里插入图片描述
这张图是网上实现LoRA原理图

简单理解:在模型的Linear层的旁边,增加一个“旁支”,这个“旁支”的作用,就是代替原有的参数矩阵W进行训练(左侧部分)。

通过图我们来直观地理解一下这个过程,输入 x x x,具有维度 d d d,举个例子,在普通的transformer模型中,这个 x x x可能是embedding的输出,也有可能是上一层transformer layer的输出,而 d d d一般就是768(大多数Bert的输出维度是768)。按照原本的路线,它应该只走左边的部分,也就是原有的模型部分。

而在LoRA的策略下,增加了右侧的“旁支”,也就是先用一个Linear层A,将数据从 d d d维降到 r r r维,这个 r r r也就是LORA的秩,是LoRA中最重要的一个超参数。一般会远远小于 d d d (见的比较多的是4、8),尤其是对于现在的大模型, d d d已经不止是768或者1024,例如LLaMA-7B,每一层transformer有32个head,这样一来 d d d就达到了4096.

接着再用第二个Linear层B,将数据从 r r r变回 d d d维。最后再将左右两部分的结果相加融合,就得到了输出的hidden_state。左侧是会冻结不做计算的。

这样使用LoRA的策略下,计算量会小得多,以小的代价微调出特定任务的模型。

代码实现

LoRA的微调是使用peft中的LoraConfig实现。
那些先看下里面的参数吧。
LoraConfig()参数说明
1、task_type:
描述: 用来指定 LoRA 要适用于的任务类型。不同的任务类型会影响模型中的哪些部分应用 LoRA 以及如何配置 LoRA。根据不同的任务,LoRA 的配置方式可能会有所不同,特别是在模型的某些特定模块(如自注意力层)上。

可选值:

  • “CAUSAL_LM”: 自回归语言模型(Causal Language Modeling)。适用于像 GPT 这样的自回归语言模型,这类模型通常在生成任务上使用。
  • “SEQ_2_SEQ_LM”: 序列到序列语言模型(Sequence-to-Sequence Language Modeling)。适用于像 T5 或 BART 这样的序列到序列模型,这类模型通常用于翻译、摘要生成等任务。
  • “TOKEN_CLS”: 标注任务(Token Classification)。适用于命名实体识别(NER)、词性标注等任务。
  • “SEQ_CLS”: 序列分类(Sequence Classification)。适用于句子分类、情感分析等任务。
  • “QUESTION_ANSWERING”: 问答任务(Question Answering)。适用于问答模型,如 SQuAD 等数据集中的任务。
  • “OTHER”: 适用于其他自定义任务,或者模型的任务类型不明确时。
    2、target_modules:
  • 描述: 指定应用 LoRA 的目标模型模块或层的名称。这些是模型中应用 LoRA 低秩分解的参数,通常是网络中的线性层(如 query, value 矩阵)。
  • 数据类型:Union[List[str], str]
  • 默认值: None
  • 典型值: [“query”, “value”] 或类似参数,具体依赖于模型结构。
    3、r(Rank Reduction Factor):
  • 描述:LoRA 的低秩矩阵的秩(rank)。r 是低秩矩阵的秩,表示将原始权重矩阵分解成两个更小的矩阵,其乘积近似原始权重矩阵。r 越小,模型的计算开销越低。
  • 数据类型:int
  • 典型值:通常在 4 到 64 之间。
    4、lora_alpha:
  • 描述:缩放因子,用于缩放 LoRA 的输出。通常在 LoRA 层的输出会被 lora_alpha / r 缩放,用来平衡学习效率和模型收敛速度。
  • 数据类型:int
  • 典型值:r 的 2 到 32 倍之间。
    5、lora_dropout:
  • 描述:应用于 LoRA 层的 dropout 概率。这个参数用来防止过拟合,特别是在小数据集上训练时,使用 dropout 可以提高模型的泛化能力。
  • 数据类型:float
  • 典型值:0.1 或者更低。
    6、bias:
  • 描述:用于控制是否训练模型的偏置项(bias)。可以设置为 none(不训练 bias)、all(训练所有 bias)、或者 lora_only(仅对 LoRA 层的偏置项进行训练)。
  • 数据类型:str
  • 典型值:none 或 lora_only。
    7、modules_to_save :
  • 描述: 指定除了 LoRA 层之外,还需要保存哪些额外的模块。这通常用于微调时只保存 LoRA 层的权重,同时保存某些特殊的模块(例如全连接层)。
  • 数据类型:Optional[List[str]]
  • 默认值: None
  • 典型值: [“classifier”, “pooler”] 或类似参数。
    8、init_lora_weights :
  • 描述: 控制 LoRA 层的权重是否在初始化时进行随机初始化。如果设置为 True,则会使用标准初始化方法;否则,将不进行初始化。
  • 数据类型:bool
  • 默认值: True
    9、inference_mode :
  • 描述: 如果设置为 True,则模型只在推理阶段使用 LoRA。此模式下,LoRA 的权重会被冻结,不会进行训练。适用于将微调后的模型用于推理场景。
  • 数据类型:bool
  • 默认值: False
代码实现
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model# 分词器
tokenizer = AutoTokenizer.from_pretrained("langboat_bloom-1b4-zh")# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}if __name__ == "__main__":# 加载数据集dataset = Dataset.load_from_disk("./alpaca_data_zh/")# 处理数据tokenized_ds = dataset.map(process_func, remove_columns=dataset.column_names)# print(tokenizer.decode(tokenized_ds[1]["input_ids"]))# print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))# 创建模型model = AutoModelForCausalLM.from_pretrained("langboat_bloom-1b4-zh", low_cpu_mem_usage=True)#打印模型参数,用于下部target_modules参数选用for name, parameter in model.named_parameters():print(name)#LoraConfig不同模型这里使用的参数是不一样,这块代码不是通用的config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=".*\.1.*query_key_value", modules_to_save=["word_embeddings"])print("config:", config)model = get_peft_model(model, config)print("model:", model)print("print_trainable_parameters:", model.print_trainable_parameters())# 训练参数args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=1,gradient_accumulation_steps=8,logging_steps=10,num_train_epochs=1)# trainertrainer = Trainer(model=model,args=args,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True))# 训练模型trainer.train()model = model.cuda()ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)print(tokenizer.decode(model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64143.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【原生js案例】如何让你的网页实现图片的按需加载

按需加载,这个词应该都不陌生了。我用到你的时候,你才出现就可以了。对于一个很多图片的网站,按需加载图片是优化网站性能的一个关键点。减少无效的http请求,提升网站加载速度。 感兴趣的可以关注下我的系列课程【webApp之h5端实…

博弈论1:拿走游戏(take-away game)

假设你和小红打赌,玩“拿走游戏”,输的人请对方吃饭.... 你们面前有21个筹码,放成一堆;每轮你或者小红可以从筹码堆中拿走1个/2个/3个;第一轮你先拿,第二轮小红拿,你们两个人交替进行;拿走筹码堆…

【论文阅读】IC-Light(ICLR 2025 满分论文)

Scaling In-the-Wild Training for Diffusion-based Illumination Harmonization and Editing by Imposing Consistent Light Transport 原始论文:https://openreview.net/pdf?idu1cQYxRI1H 补充材料:https://openreview.net/attachment?idu1cQYxRI1H&…

Unix 传奇 | 谁写了 Linux | Unix birthmark

注:本文为 “左耳听风”陈皓的 unix 相关文章合辑。 皓侠已走远,文章有点“年头”,但值得一阅。 文中部分超链已沉寂。 Unix 传奇 (上篇) 2010 年 04 月 09 日 陈皓 了解过去,我们才能知其然,更知所以然。总结过去…

记一个framebuffer显示混乱的低级错误

记一个framebuffer显示混乱的低级错误 由于framebuffer的基础知识不扎实,这个任务上我多卡了两天,差点把我搞死,于此记录为后鉴。 打算用awtk做一个多进程项目,计划把framebuffer的内容通过websocket输出到浏览器上去显示画面, …

常用的前端框架介绍

在前端开发中,有许多流行的框架能够帮助开发者更高效地构建用户界面和交互 1. React: • React是一个由Facebook开发的JavaScript库,用于构建用户界面。 • 它使用组件化的思想,将UI拆分成可复用的组件,每个组件都有自…

Kaggler日志-Day4

进度24/12/14 昨日复盘: Pandas课程完成 Intermediate Mechine Learning2/7 今日记录: Intermediate Mechine Learning之类型变量 读两篇讲解如何提问的文章,在提问区里发起一次提问 实战:自己从头到尾首先Housing Prices Compe…

【常考前端面试题总结】---2025

React fiber架构 1.为什么会出现 React fiber 架构? React 15 Stack Reconciler 是通过递归更新子组件 。由于递归执行,所以更新一旦开始,中途就无法中断。当层级很深时,递归更新时间超过了 16ms,用户交互就会卡顿。对于特别庞…

二三(Node2)、Node.js 模块化、package.json、npm 软件包管理器、nodemon、Express、同源、跨域、CORS

1. Node.js 模块化 1.1 CommonJS 标准 utils.js /*** 目标:基于 CommonJS 标准语法,封装属性和方法并导出*/ const baseURL "http://hmajax.itheima.net"; const getArraySum (arr) > arr.reduce((sum, item) > (sum item), 0);mo…

Java爬虫设计:淘宝商品详情接口数据获取

1. 概述 淘宝商品详情接口(如Taobao.item_get)允许开发者通过编程方式,以JSON格式实时获取淘宝商品的详细信息,包括商品标题、价格、销量等。本文档将介绍如何设计一个Java爬虫来获取这些数据。 2. 准备工作 在开始之前&#x…

LeetCode-hot100-73

https://leetcode.cn/problems/largest-rectangle-in-histogram/description/?envTypestudy-plan-v2&envIdtop-100-liked 84. 柱状图中最大的矩形 已解答 困难 相关标签 相关企业 给定 n 个非负整数,用来表示柱状图中各个柱子的高度。每个柱子彼此相邻&#x…

【docker】springboot 服务提交至docker

准备docker (不是docker hub或者harbor,就是可以运行docker run的服务),首先确保docker已经安装。 本文以linux下举例说明: systemctl stats docker ● docker.service - Docker Application Container EngineLoaded…

通过ajax的jsonp方式实现跨域访问,并处理响应

一、场景描述 现有一个项目A,需要请求项目B的某个接口,并根据B接口响应结果A处理后续逻辑。 二、具体实现 1、前端 前端项目A发送请求,这里通过jsonp的方式实现跨域访问。 $.ajax({ url:http://10.10.2.256:8280/ssoCheck, //请求的u…

Unity 沿圆周创建Sphere

思路 取圆上任意一点连接圆心即为半径,以此半径为斜边作直角三角形。当已知圆心位置与半径长度时,即可得该点与圆心在直角三角形两直角边方向上的位置偏移,从而得出该点的位置。 实现 核心代码 offsetX radius * Mathf.Cos(angle * Mathf…

9. 高效利用Excel设置归档Tag

高效利用Excel设置归档Tag 1. Excle批量新建/修改归档Tag2. 趋势记录模型批量导入归档Tag(Method1)2. 趋势记录模型批量导入归档Tag(Method2)3. 趋势记录控件1. Excle批量新建/修改归档Tag Fcatory Talk常常需要归档模拟量,对于比较大的项目工程会有成千上万个重要数据需…

网页端web内容批注插件:

感觉平时每天基本上90%左右的时间都在浏览器端度过,按理说很多资料都应该在web端输入并且输出,但是却有很多时间浪费到了各种桌面app中,比如说什么notion、语雀以及各种笔记软件中,以及导入到ipad的gn中,这些其实都是浪…

数据结构——栈的模拟实现

大家好,今天我要介绍一下数据结构中的一个经典结构——栈。 一:栈的介绍 与顺序表和单链表不同的是: 顺序表和单链表都可以在头部和尾部插入和删除数据,但是栈的结构就锁死了(栈的底部是堵死的)栈只能从…

基于springboot+vue的高校校园交友交流平台设计和实现

文章目录 系统功能部分实现截图 前台模块实现管理员模块实现 项目相关文件架构设计 MVC的设计模式基于B/S的架构技术栈 具体功能模块设计系统需求分析 可行性分析 系统测试为什么我? 关于我项目开发案例我自己的网站 源码获取: 系统功能 校园交友平台…

让文案生成更具灵活性/chatGPT新功能canvas画布编辑

​ ​ OpenAI最近在2024年12月发布了canvas画布编辑功能,这是一项用途广泛的创新工具,专为需要高效创作文案的用户设计。 无论是职场人士、学生还是创作者,这项功能都能帮助快速生成、优化和编辑文案,提升效率的同时提高内容质量…

递归问题(c++)

递归设计思路 数列递归 : 如果一个数列的项与项之间存在关联性,那么可以使用递归实现 ; 原理 : 如果一个函数可以求A(n),那么该函数就可以求A(n-1),就形成了递归调用 ; 注意: 一般起始项是不需要求解的,是已知条件 这就是一个典型…