【LLM】PISSA:一种高效的微调方法

前言

介绍PISSA前,先简单过一下LLMs微调经常采用的LoRA(Low-Rank Adaptation)微调的方法,LoRA 假设权重更新的过程中有一个较低的本征秩,对于预训练的权重参数矩阵 W 0 ∈ R d × k W_0 ∈ R^{d×k} W0Rd×k,( d d d 为上一层输出维度, k k k 为下一层输入维度),使用低秩分解来表示其更新:

在训练过程中, W 0 W_0 W0冻结不更新, A A A B B B 包含可训练参数。

则 LoRA 的前向传递函数为:

初始化时,常将低秩矩阵 A A A高斯初始化, B B B初始化为0。这样在训练初期AB接近于零,不会影响模型的输出。

LoRA微调架构

PISSA

三种微调方式架构

从图中可以看出,PISSA和LoRA主要的区别是初始化方式不同:

  • LoRA:使用随机高斯分布初始化 A A A B B B初始化为零。过程中只训练了低秩矩阵 A A A B B B
  • PISSA:同样基于低秩特性的假设,但PISSA不是去近似 ∆ W ∆W W,而是直接对 W W W进行操作。PiSSA使用奇异值分解(SVD)将 W W W分解为两个矩阵 A A A B B B的乘积加上一个残差矩阵 W r e s W^{res} Wres A A A B B B使用 W W W的主奇异值和奇异向量进行初始化,而 W r e s W^{res} Wres则使用剩余的奇异值和奇异向量初始化,并在微调过程中保持不变。也就能保证初始化时和基座模型一样。因此,和LoRA一样,PISSA的训练中也只训练了低秩矩阵 A A A B B B,而 W r e s W^{res} Wres保持冻结

初始化A和B矩阵:使用主要的奇异值和奇异向量初始化两个可训练的矩阵:


构建残差矩阵 W r e s W^{res} Wres:使用残差奇异值和奇异向量构建残差矩阵:

实验

PISSA微调

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer
from datasets import load_datasetmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = LoraConfig(# init_lora_weights="pissa", # Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model.init_lora_weights="pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few seconds.
)
peft_model = get_peft_model(model, lora_config)peft_model.print_trainable_parameters()dataset = load_dataset("imdb", split="train[:1%]")trainer = SFTTrainer(model=peft_model,train_dataset=dataset,dataset_text_field="text",max_seq_length=128,tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("pissa-llama-2-7b")

pissa加载

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# Performs SVD again to initialize the residual model and loads the state_dict of the fine-tuned PiSSA modules.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b")

将 PiSSA 转换为 LoRA

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# No SVD is performed during this step, and the base model remains unaltered.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")

总结

PiSSA是一种高效的微调方法,它通过奇异值分解提取大型语言模型中的关键参数,并仅对这些参数进行更新,以实现与全参数微调相似的性能,同时显著降低计算成本和参数数量。

参考文献

  • PISSA: PRINCIPAL SINGULAR VALUES AND SINGULAR VECTORS ADAPTATION OF LARGE LANGUAGE MODELS,https://arxiv.org/pdf/2404.02948
  • https://github.com/GraphPKU/PiSSA
  • LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS,https://arxiv.org/pdf/2106.09685
  • https://github.com/microsoft/LoRA

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/32144.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux_内核缓冲区

目录 1、用户缓冲区概念 2、用户缓冲区刷新策略 3、用户缓冲区的好处 4、内核缓冲区 5、验证内核缓冲区 6、用户缓冲区存放的位置 7、全缓冲 结语 前言: Linux下的内核缓冲区存在于系统中,该缓冲区和用户层面的缓冲区不过同一个概念&#x…

数据结构与算法引入(Python)

华子目录 引入第一次尝试第二次尝试 算法的概念算法的五大特性 算法效率衡量执行时间单靠时间值绝对可信吗? 时间复杂度与 "大O记法"如何理解 “大O记法” 最坏时间复杂度时间复杂度的几条基本计算规则 算法分析常见的时间复杂度常见时间复杂度之间的关系…

2024最新版DataGrip安装教程-全网最全教程!!!

1.DataGrip下载安装 1.打开DataGrip官网,选择自己需要的版本下载即可: 2.进行安装: 3.重启打开: 我这个是正版激活码激活的,需要教程可以关注留言

[Redis]持久化机制

众所周知,Redis是内存数据库,也就是把数据存在内存上,读写速度很快,但是,内存的数据容易丢失,为了数据的持久性,还得把数据存储到硬盘上 也就是说,内存有一份数据,硬盘也…

RuoYi Swagger请求401

问题描述: 提示:这里简述项目相关背景: 使用ruoyi-vue分离版,访问swagger,发现接口都调用失败:401 解决方案: 最终解决问题如下步骤: 1、 调用swagger中的接口,报错&a…

【Mysql】DQL操作单表、创建数据库、排序、聚合函数、分组、limit关键字

DQL操作单表 1.1 创建数据库 •创建一个新的数据库 db2 CREATE DATABASE db2 CHARACTER SET utf8;•将db1数据库中的 emp表 复制到当前 db2数据库 ** 1.2 排序** 通过 ORDER BY 子句,可以将查询出的结果进行排序 (排序只是显示效果,不会影响真实数据) 语法结构:…

算法:渐进记号的含义及时间复杂度计算

渐进记号及时间复杂度计算 渐近符号渐近记号 Ω \Omega Ω渐进记号 Θ \Theta Θ渐进记号小 ο \omicron ο渐进记号小 ω \omega ω渐进记号大 O \Omicron O常见的时间复杂度关系 时间复杂度计算:递归方程代入法迭代法套用公式法 渐近符号 渐近记号 Ω \Omega Ω …

每天写java到期末考试--接口1--基础--6.22

规则: 练习: 抽象类的抽象方法 动物类Animal package 期末复习;public abstract class Animal {private String name;private int age;//1.空构造public Animal(){}public Animal(String name,int age){this.ageage;this.namename;}public String getNa…

【C++提高编程-11】----C++ STL常用集合算法

🎩 欢迎来到技术探索的奇幻世界👨‍💻 📜 个人主页:一伦明悦-CSDN博客 ✍🏻 作者简介: C软件开发、Python机器学习爱好者 🗣️ 互动与支持:💬评论 &…

Nginx 负载均衡实现上游服务健康检查

Nginx 负载均衡实现上游服务健康检查 Author:Arsen Date:2024/06/20 目录 Nginx 负载均衡实现上游服务健康检查 前言一、Nginx 部署并新增模块二、健康检查配置2.1 准备 nodeJS 应用程序2.2 Nginx 配置负载均衡健康检查 小结 前言 如果你使用云负载均衡…

【Linux】 yum学习

yum介绍 在Linux系统中,yum(Yellowdog Updater, Modified)是一个用于管理软件包的命令行工具,特别适用于基于RPM(Red Hat Package Manager)的系统,如CentOS、Fedora和Red Hat Enterprise Linux…

【Arduino】实验使用ESP32单片机根据光线变化控制LED小灯开关(图文)

今天小飞鱼继续来实验ESP32的开发,这里使用关敏电阻来配合ESP32做一个我们平常接触比较多的根据光线变化开关灯的实验。当白天时有太阳光,则把小灯关闭;当光线不好或者黑天时,自动打开小灯。 int value;void setup() {pinMode(34…

音视频开发29 FFmpeg 音频编码- 流程以及重要API,该章节使用AAC编码说明

此章节的一些参数,需要先掌握aac的一些基本知识:​​​​​​aac音视频开发13 FFmpeg 音频 --- 常用音频格式AAC,AAC编码器, AAC ADTS格式 。_ffmpeg aac data数据格式-CSDN博客 目的: 从本地⽂件读取PCM数据进⾏AAC格…

【CARD】多变化字幕的上下文感知差异提炼(ACL 2024)

摘要 Multi-change captioning旨在用自然语言描述图像对中的复杂变化。和图像字幕相比,这个任务要求模型具有更高层次的认知能力来推理任意数量的变化。本文提出一种新的上下文感知差异提取网络(CARD)。给定一个图像对,CARD首先解…

Multigranularity and MultiscaleProgressive Contrastive Learning

这篇文章将一张图片划分为四个不同细粒度大小的图片,然后输出四个神经网络,这四个神经网络共享权重,得到四个输出,将这四个输出求交叉熵损失和对比学习损失,共同监督模型学习。 通过对比学习,最大化一个Bat…

Microsoft Edge无法启动搜索问题的解决

今天本来想清一下电脑,看到visual studio2022没怎么用了就打算卸载掉。然后看到网上有篇文章说进入C盘的ProgramFiles(x86)目录下的microsoft目录下的microsoft visual studio目录下的install目录中,双击InstallCleanup.exe&#…

Windows环境利用 OpenCV 中 CascadeClassifier 分类器识别人脸 c++

Windows环境中配置OpenCV 关于在Windows环境中配置opencv的说明,具体可以参考:VS2022 配置OpenCV开发环境详细教程。 CascadeClassifier 分类器 CascadeClassifier 是 OpenCV 库中的一个类,它用于实现一种快速的物体检测算法,称…

API接口技术开发分享;按关键字搜索淘宝、天猫商品API返回值接入说明

淘宝数据API的接入流程主要包括注册key账号、创建开发者应用、获取ApiKey和ApiSecret、申请API权限等步骤。淘通过这些接口可以获取商品、订单、用户、营销和物流管理等多方面的数据。以下是关于淘宝数据API接入流程的相关介绍: 注册key账号:进行账号注册…

JAVA医院绩效考核系统源码 功能特点:大型医院绩效考核系统源码

JAVA医院绩效考核系统源码 功能特点:大型医院绩效考核系统源码 医院绩效管理系统主要用于对科室和岗位的工作量、工作质量、服务质量进行全面考核,并对科室绩效工资和岗位绩效工资进行核算的系统。医院绩效管理系统开发主要用到的管理工具有RBRVS、DRGS…

AUCell和AddModuleScore函数进行基因集评分

AUCell 和AddModuleScore 分析是两种主流的用于单细胞RNA测序数据的基因集活性分析的方法。这些基因集可以来自文献、数据库或者根据具体研究问题进行自行定义。 AUCell分析原理: 1、AUCell分析可以将细胞中的所有基因按表达量进行排序,生成一个基因排…