LORA概述: 大语言模型的低阶适应

LORA概述: 大语言模型的低阶适应

  • LORA: 大语言模型的低阶适应
    • 前言
    • 摘要
    • 论文十问
    • 实验
      • RoBERTa
      • DeBERTa
      • GPT-2
      • GPT-3
    • 结论
    • 代码调用

LORA: 大语言模型的低阶适应

前言

LoRA的核心思想在于优化预训练语言模型的微调过程,通过有效地处理权重矩阵的变化(即梯度更新的累积),使其具有“低秩”结构。简而言之,这意味着可以通过低秩分解有效地表示变化矩阵。

具体来说,对于预训练权重矩阵W₀,其更新量可以表示为∆W = BA,其中B和A都是低秩矩阵(例如,秩为r,r明显小于矩阵维度d)。在训练期间,W₀被冻结,而B和A中的参数是可训练的。这明显减少了适应的可训练参数数量。

如下图所示,在原始预训练语言模型旁边添加一个旁路,执行降维再升维的操作,以模拟内在秩。在训练过程中,固定预训练语言模型的参数,只训练降维矩阵A和升维矩阵B。模型的输入输出维度保持不变,输出时将BA与预训练语言模型的参数叠加。矩阵A使用随机高斯分布进行初始化,而矩阵B则使用零矩阵进行初始化,以确保在训练开始时,该旁路矩阵仍然是零矩阵。

在这里插入图片描述

摘要

自然语言处理的一个重要范式包括在通用域数据上进行大规模预训练,以及针对特定任务或域进行适配。随着我们预训练更大的模型,全面微调,即重新训练所有模型参数,变得更加不可行。

以GPT-3 175B为例,单独部署经过微调的独立实例模型,每个实例拥有1750亿个参数,是极其昂贵的。我们提出了低秩适应(LoRA)方法,其中冻结预训练模型权重,并在transformer体系结构的每个层中插入可训练的低秩分解矩阵,从而大大减少下游任务的可训练参数数量。

与使用Adam微调GPT-3 175B相比,LoRA可以将可训练参数数量减少10000倍,GPU内存需求减少3倍。尽管只有更少的可训练参数和更高的训练吞吐量,但LoRA在RoBERTa、DeBERTa、GPT-2和GPT-3上的性能优于或等同于微调。

我们还对语言模型适配中的秩缺失进行了实证研究,这解释了LoRA的功效。我们发布了一个软件包,可以方便地将LoRA与PyTorch模型集成,并为RoBERTa、DeBERTa和GPT-2提供了我们的实现和模型检查点。

论文十问

  1. 论文试图解决什么问题?

这篇论文试图解决大规模预训练语言模型(如GPT-3)微调(fine-tuning)所带来的巨大的存储、部署和任务切换成本的问题。

  1. 这是否是一个新的问题?

这不是一个全新的问题,但随着 transformer 语言模型规模的不断增长(如 GPT-3 175B 参数),这个问题的严重性在增加。论文中也提到了许多相关的已有工作。

  1. 这篇文章要验证一个什么科学假设?

这篇文章的主要科学假设是微调过程中模型参数的变化矩阵具有低秩结构(rank-deficient)。基于这个假设,作者提出了低秩适应(LoRA)方法来有效地适应下游任务。

  1. 有哪些相关研究?如何归类?谁是这一课题在领域内值得关注的研究员?

相关的工作包括适配器模块、prompt tuning等参数高效适应方法。

  1. 论文中提到的解决方案之关键是什么?

LoRA的关键是只训练插入到每个transformer层中的低秩分解矩阵,而保持预训练权重固定。这大大降低了适应的可训练参数数量。

  1. 论文中的实验是如何设计的?

在 RoBERTa、DeBERTa、GPT-2 和 GPT-3 等模型上进行了大量实验比较。实验设计针对性强,测试了性能和参数数量的权衡。

  1. 用于定量评估的数据集是什么?代码有没有开源?

使用的数据集包括 GLUE、WikiSQL、SAMSum 等。实验代码和模型检查点开源。

  1. 论文中的实验及结果有没有很好地支持需要验证的科学假设?

是的,丰富的实验验证了 LoRA 在性能、存储效率、训练速度等方面都优于或匹敌全微调基线,支持了低秩适应的有效性。

  1. 这篇论文到底有什么贡献?

主要贡献是提出 LoRA 方法,大幅降低大模型微调的成本,并给出可复现的实验验证。

  1. 下一步呢?有什么工作可以继续深入?

下一步可以考虑与其他高效适应方法(如prompt tuning)的结合,解释微调过程中模型内部表示的变化,进一步提高 LoRA 的泛化性等。

实验

评估了 LoRA 在 RoBERTa (Liu et al., 2019)、DeBERTa (He et al., 2021) 和 GPT-2 (Radford etal., b) 上的下游任务性能,然后再扩展到 GPT- 3 175B(布朗等人,2020 年)

RoBERTa

RoBERTa是Facebook AI于2019年提出的语言表示模型。相比BERT有更优化的预训练步骤,性能更好,参数规模类似,分Base和Large两个版本。实验中分别使用了1.25亿参数和3.55亿参数的RoBERTa模型。

DeBERTa

微软于2020年提出的改进型BERT模型。采用多任务预训练、增强型注意力机制等技术。实验中使用了极大规模的DeBERTa XXL,包含了1500亿参数。
在这里插入图片描述

GPT-2

OpenAI于2019年提出的基于Transformer的语言生成模型GPT-2。模型架构采用了解码器,支持自回归文本生成。实验分别基于中等规模(3.54亿参数)和大规模(7.74亿参数)的GPT-2进行。
在这里插入图片描述

GPT-3

OpenAI于2020年发布的巨大语言模型, Transformer规模达到了1750亿参数,是当时最大的神经语言模型。论文使用了这个极具挑战性的大模型进行扩展实验。这几种模型的选择,可以让作者全面验证LoRA适配方法在不同规模的Transformer类模型中的有效性。覆盖了目前最典型和最前沿的语言表示与生成模型。
在这里插入图片描述

结论

实际好处:

  1. 内存和存储使用减少: 在使用Adam训练的大型Transformer中,通过使用LoRA,显著减少了VRAM(显存)和存储的使用量。例如,在GPT-3 175B上,将训练期间的VRAM消耗从1.2TB减少到350GB。
  2. 检查点大小减小: 在一定条件下,检查点大小减少了大约10,000倍,从350GB减少到35MB。这降低了GPU训练的硬件需求,并避免了I/O瓶颈。
  3. 任务切换成本降低: LoRA允许在任务之间进行切换,通过仅交换LoRA权重而不是所有参数,降低了部署的成本。这使得可以在机器上动态换入和换出预训练权重,创建自定义模型。
  4. 加速训练: 在GPT-3 175B的训练中,相较于完全微调,观察到25%的加速,因为不需要计算绝大多数参数的梯度。

局限性:

  1. 前向传递复杂性: 吸收不同任务的A和B到W中,以消除额外推理延迟,在单个前向传递中批量输入并不简单。需要考虑不同任务的权重合并和动态选择LoRA模块的复杂性。
  2. 推理延迟问题: 尽管可以动态选择LoRA模块以处理不同任务的推理延迟,但在一些场景中,合并权重可能引入不可避免的问题。

代码调用

使用 🤗 PEFT 训练您的模型

下面的示例是使用 LoRA 进行微调的情况。

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfigpeft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)model = model.to(device)
model.eval()
inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")with torch.no_grad():outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10)print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/198696.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

常用sql记录

备份一张表 PostgreSQL CREATE TABLE new_table AS SELECT * FROM old_table;-- 下面这个比上面好,这个复制表结构时,会把默认值、约束、注释都复制 CREATE TABLE new_table (LIKE old_table INCLUDING ALL) WITHOUT OIDS; INSERT INTO new_table SELE…

极智芯 | 解读国产AI算力 璧仞产品矩阵

欢迎关注我,获取我的更多经验分享 大家好,我是极智视界,本文分享一下 解读国产AI算力 璧仞产品矩阵。 璧仞在国产 AI 芯领域就是 "迷" 一样的存在,你要说它在市场上的 "建树" 泛善可陈的话,它又 "赫然" 在美国芯片禁令名单中。而这一切的一…

Vue2学习笔记(列表渲染)

案例由浅到深&#xff0c;逐步深入。 一 、案例1&#xff1a;v-for的使用方法 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scal…

跨网文件摆渡系统:安全、可控的数字传输桥梁

在企业高度信息化的时代&#xff0c;数据的流通与共享已经成为企业、组织乃至个人之间不可或缺的沟通方式。然而&#xff0c;在数据流通的过程中&#xff0c;我们经常会遇到各种难题和挑战&#xff0c;尤其是当涉及到不同网络环境之间的文件传输。这不仅需要保证文件的安全性&a…

MySQL的多表查询

<!DOCTYPE html> <html> <head> <meta charset"UTF-8" /> <title>多表查询</title> </head> <body> <!-- 多表关系 概述&#xff1a;基本上分为三种&#xff1a; 一对多&#xff08;多对一&#xff09; 案例…

09、pytest多种调用方式

官方用例 # content of myivoke.py import sys import pytestclass MyPlugin:def pytest_sessionfinish(self):print("*** test run reporting finishing")if __name__ "__main__":sys.exit(pytest.main(["-qq"],plugins[MyPlugin()]))# conte…

烤鱼纸包鱼外卖配送小程序商城作用是什么

烤鱼、纸包鱼等餐品是聚会、娱乐、餐食等场景中常见的餐品&#xff0c;到店和外送都有较高需求度&#xff0c;对消费者来说需要找到美味的餐厅和快速享受到美食的流程&#xff1b;对商家来说是如何找到更多消费&#xff0c;并且能快速转化和持续复购及相应的管理。 线下竞争激…

第三方支付原理

1.什么是第三方支付 所谓第三方支付&#xff0c;就是一些和各大银行签约、并具备一定实力和信誉保障的第三方独立机构提供的交易支持平台。在通过第三方支付平台的交易中&#xff0c;买方选购商品后&#xff0c;使用第三方平台提供的账户进行货款支付&#xff0c;由第三方通知卖…

【langchain实战】开源项目-RasaGpt

1、概述 RasaGpt是一个建立在 Rasa 和 Langchain 之上的没有显示界面的LMM聊天机器人平台。它是一个Rasa和Telegram这种利用像Langchain这样的LMM库进行索引、检索和上下文注入的样板及参考实现。 开源地址&#xff1a; GitHub - paulpierre/RasaGPT: &#x1f4ac; RasaGPT is…

接口压测指南

接口压测指南 一、 为什么需要进行接口压测二 、接口压测的目标是什么三、 用什么工具进行接口压测四、 接口压测核心指标4.1 JMeter的报告模板4.2 ApiPost报告模板 五、 接口慢如何排查5.1 大体排查思路5.2 排查工具5.3 压测经验 一、 为什么需要进行接口压测 突然有一天领导…

漏洞复现--万户ezoffice wpsservlet任意文件上传

免责声明&#xff1a; 文章中涉及的漏洞均已修复&#xff0c;敏感信息均已做打码处理&#xff0c;文章仅做经验分享用途&#xff0c;切勿当真&#xff0c;未授权的攻击属于非法行为&#xff01;文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

虚拟机-桥接模式连接

文章目录 1.查看宿主机再用的IP信息2.桥接模式-虚拟机设置VMware设置虚拟机设置重启网络服务 1.查看宿主机再用的IP信息 ipconfig /all 注&#xff1a; 在虚拟机中要设置同网段的ip设置同一个子网掩码设置同一个网关设置同一个DNS服务器 2.桥接模式-虚拟机设置 VMware设置 虚…

JS Object.values()

一、官方定义 返回一个给定对象的自有可枚举字符串键属性值组成的数组 二、语法 Object.values(obj)参数 obj 被返回可枚举属性值的对象。返回值 一个包含对象自身的所有可枚举属性值的数组。描述 Object.values() 返回一个数组&#xff0c;其元素是在对象上找到的可枚举…

Python 高性能 web 框架 - FastApi 全面指南

原文&#xff1a;Python 高性能 web 框架 - FastApi 全面指南 - 知乎 一、简介 FastAPI 是一个用于构建 API 的现代、快速&#xff08;高性能&#xff09;的 web 框架&#xff0c;使用 Python 3.6 并基于标准的 Python 类型提示。 它具有如下这些优点&#xff1a; 快速&…

C盘爆满,python pip无法安装应用

解决方法1 C盘扩容 从其他盘压缩空间&#xff0c;C盘使用压缩的空间进行扩容&#xff0c;治标不治本&#xff0c;以后C盘还会越来越大 解决方法2 转移pip安装目录 1. 获取显示pip安装目录 C:\Users\biewang>pip show pip Name: pip Version: 23.3.1 Summary: The PyPA r…

PHP基础知识和操作

PHP在线运行 https://c.runoob.com/compile/1/ https://www.sotool.net/php80 将驼峰字符串转化为蛇形字符串 <?phpfunction CamelToSnake($camelValue) {$initValue preg_replace(/\s/u, , $camelValue);$snakeValue strtolower(preg_replace(/(.)(?[A-Z])/u, &quo…

【实战技能】 单步运行源码分析,一期视频整明白FreeRTOS内核源码框架和运行机制,RTOS Trace链表功能展示

从源码的角度来看&#xff0c;OS内核源码就是通过各种链表组装起来的&#xff0c;FreeRTOS就是下面几个链表组成的。FreeRTOS的调度&#xff0c;任务切换就是倒腾这几个链表。而其它的几款OS是一个链表就一撸到底了&#xff0c;FreeRTOS是搞了好几个。所以视频里面就重点介绍下…

Docker容器资源限制 CPU / 内存 / 磁盘

在一台物理机上启动了多个docker容器时,就需要对内存及cpu做出相关的限制,以达到容器互不影响的目的。 1.限制容器对内存的使用 ⼀个 docker host 上会运⾏若⼲容器,每个容器都需要 CPU、内存和 IO 资源。对于 KVM,VMware 等虚拟化技 术,⽤户可以控制分配多少 CPU、内存…

代码随想录-刷题第十七天

110.平衡二叉树 题目链接&#xff1a;110. 平衡二叉树 思路&#xff1a;其实是判断左右子树的高度差是否大于1。判断高度的话仍然是采用后序遍历。 求深度可以从上到下去查&#xff0c;所以需要前序遍历&#xff08;中左右&#xff09;&#xff0c;而高度只能从下到上去查&a…

使用Redis做动态页面缓存,提高网页访问速度

目的 本关目的&#xff1a;实现使用Redis缓存网页。 相关知识 本文将教会你掌握&#xff1a;1&#xff0e;SETEX命令&#xff0c;2&#xff0e;hash()方法。 在动态生成网页的时候&#xff0c;通常会使用模板&#xff08;template&#xff09;来简化网页的生成&#xff0c;…