【LLM】PISSA:一种高效的微调方法

前言

介绍PISSA前,先简单过一下LLMs微调经常采用的LoRA(Low-Rank Adaptation)微调的方法,LoRA 假设权重更新的过程中有一个较低的本征秩,对于预训练的权重参数矩阵 W 0 ∈ R d × k W_0 ∈ R^{d×k} W0Rd×k,( d d d 为上一层输出维度, k k k 为下一层输入维度),使用低秩分解来表示其更新:

在训练过程中, W 0 W_0 W0冻结不更新, A A A B B B 包含可训练参数。

则 LoRA 的前向传递函数为:

初始化时,常将低秩矩阵 A A A高斯初始化, B B B初始化为0。这样在训练初期AB接近于零,不会影响模型的输出。

LoRA微调架构

PISSA

三种微调方式架构

从图中可以看出,PISSA和LoRA主要的区别是初始化方式不同:

  • LoRA:使用随机高斯分布初始化 A A A B B B初始化为零。过程中只训练了低秩矩阵 A A A B B B
  • PISSA:同样基于低秩特性的假设,但PISSA不是去近似 ∆ W ∆W W,而是直接对 W W W进行操作。PiSSA使用奇异值分解(SVD)将 W W W分解为两个矩阵 A A A B B B的乘积加上一个残差矩阵 W r e s W^{res} Wres A A A B B B使用 W W W的主奇异值和奇异向量进行初始化,而 W r e s W^{res} Wres则使用剩余的奇异值和奇异向量初始化,并在微调过程中保持不变。也就能保证初始化时和基座模型一样。因此,和LoRA一样,PISSA的训练中也只训练了低秩矩阵 A A A B B B,而 W r e s W^{res} Wres保持冻结

初始化A和B矩阵:使用主要的奇异值和奇异向量初始化两个可训练的矩阵:


构建残差矩阵 W r e s W^{res} Wres:使用残差奇异值和奇异向量构建残差矩阵:

实验

PISSA微调

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer
from datasets import load_datasetmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = LoraConfig(# init_lora_weights="pissa", # Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model.init_lora_weights="pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few seconds.
)
peft_model = get_peft_model(model, lora_config)peft_model.print_trainable_parameters()dataset = load_dataset("imdb", split="train[:1%]")trainer = SFTTrainer(model=peft_model,train_dataset=dataset,dataset_text_field="text",max_seq_length=128,tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("pissa-llama-2-7b")

pissa加载

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# Performs SVD again to initialize the residual model and loads the state_dict of the fine-tuned PiSSA modules.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b")

将 PiSSA 转换为 LoRA

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# No SVD is performed during this step, and the base model remains unaltered.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")

总结

PiSSA是一种高效的微调方法,它通过奇异值分解提取大型语言模型中的关键参数,并仅对这些参数进行更新,以实现与全参数微调相似的性能,同时显著降低计算成本和参数数量。

参考文献

  • PISSA: PRINCIPAL SINGULAR VALUES AND SINGULAR VECTORS ADAPTATION OF LARGE LANGUAGE MODELS,https://arxiv.org/pdf/2404.02948
  • https://github.com/GraphPKU/PiSSA
  • LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS,https://arxiv.org/pdf/2106.09685
  • https://github.com/microsoft/LoRA

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/32144.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux_内核缓冲区

目录 1、用户缓冲区概念 2、用户缓冲区刷新策略 3、用户缓冲区的好处 4、内核缓冲区 5、验证内核缓冲区 6、用户缓冲区存放的位置 7、全缓冲 结语 前言: Linux下的内核缓冲区存在于系统中,该缓冲区和用户层面的缓冲区不过同一个概念&#x…

数据结构与算法引入(Python)

华子目录 引入第一次尝试第二次尝试 算法的概念算法的五大特性 算法效率衡量执行时间单靠时间值绝对可信吗? 时间复杂度与 "大O记法"如何理解 “大O记法” 最坏时间复杂度时间复杂度的几条基本计算规则 算法分析常见的时间复杂度常见时间复杂度之间的关系…

2024最新版DataGrip安装教程-全网最全教程!!!

1.DataGrip下载安装 1.打开DataGrip官网,选择自己需要的版本下载即可: 2.进行安装: 3.重启打开: 我这个是正版激活码激活的,需要教程可以关注留言

网络爬虫中selenium和requests这两个工具有什么区别呢?

在自动化和网络爬虫的开发过程中,Selenium和Requests是两个常用的工具。尽管它们都可以用于从互联网上获取数据,但它们在用途、功能和工作原理上存在显著的差异。以下将详细探讨Selenium和Requests之间的主要区别。 一、用途和定位 Selenium&#xff1…

互联网的盈利模式

1. 广告收入 展示广告:通过在网站或应用上展示横幅广告、视频广告等,按点击次数(CPC)或展示次数(CPM)收费。搜索广告:通过搜索引擎上的关键词竞价广告,按点击次数收费。社交媒体广告…

conda 常用指令大集合

增删改查 注:所有的环境都保存在路径 ${anaconda安装目录}/env/[环境名]下 查看现有环境 (星号✳️代表当前) conda info --env 创建 conda 环境 conda create -n [环境名] python[版本号] 删除 conda 环境 conda remove -n [环境名] --all 修改…

如何通过京东API优雅地获取商品评论数据

在当今电商领域,用户评论成为了影响消费者购买决策的关键因素之一。京东,作为中国领先的电商平台,深知商品评论的重要性,并通过其开放平台提供了商品评论数据接口,让开发者能够轻松获取商品的用户反馈信息。本文将指导…

[Redis]持久化机制

众所周知,Redis是内存数据库,也就是把数据存在内存上,读写速度很快,但是,内存的数据容易丢失,为了数据的持久性,还得把数据存储到硬盘上 也就是说,内存有一份数据,硬盘也…

如何在 Windows 上安装 Docker Desktop

如何在 Windows 上安装 Docker Desktop Docker 是一个开放平台,用于开发、部署和运行应用程序。Docker Desktop 是 Docker 在 Windows 和 macOS 上的官方客户端,它使得开发者能够轻松地在本地环境中构建、运行和共享容器化应用程序。本文将详细介绍如何…

RuoYi Swagger请求401

问题描述: 提示:这里简述项目相关背景: 使用ruoyi-vue分离版,访问swagger,发现接口都调用失败:401 解决方案: 最终解决问题如下步骤: 1、 调用swagger中的接口,报错&a…

如何在SpringSecurity中配置基于角色的访问控制?

在Spring Security中配置基于角色的访问控制是保护应用程序和资源不被未授权访问的基本策略之一。这里,我们将详细介绍如何在配置中和方法级别上实现基于角色的访问控制。 1. 配置基于角色的访问控制 在Spring Security的配置类中,你可以使用HttpSecur…

揭秘MMAdapt:如何利用AI跨领域战胜新兴健康谣言?

MMAdapt: A Knowledge-Guided Multi-Source Multi-Class Domain Adaptive Framework for Early Health Misinformation Detection 论文地址: MMAdapt: A Knowledge-guided Multi-source Multi-class Domain Adaptive Framework for Early Health Misinformation Detection …

【Mysql】DQL操作单表、创建数据库、排序、聚合函数、分组、limit关键字

DQL操作单表 1.1 创建数据库 •创建一个新的数据库 db2 CREATE DATABASE db2 CHARACTER SET utf8;•将db1数据库中的 emp表 复制到当前 db2数据库 ** 1.2 排序** 通过 ORDER BY 子句,可以将查询出的结果进行排序 (排序只是显示效果,不会影响真实数据) 语法结构:…

算法:渐进记号的含义及时间复杂度计算

渐进记号及时间复杂度计算 渐近符号渐近记号 Ω \Omega Ω渐进记号 Θ \Theta Θ渐进记号小 ο \omicron ο渐进记号小 ω \omega ω渐进记号大 O \Omicron O常见的时间复杂度关系 时间复杂度计算:递归方程代入法迭代法套用公式法 渐近符号 渐近记号 Ω \Omega Ω …

每天写java到期末考试--接口1--基础--6.22

规则: 练习: 抽象类的抽象方法 动物类Animal package 期末复习;public abstract class Animal {private String name;private int age;//1.空构造public Animal(){}public Animal(String name,int age){this.ageage;this.namename;}public String getNa…

【C++提高编程-11】----C++ STL常用集合算法

🎩 欢迎来到技术探索的奇幻世界👨‍💻 📜 个人主页:一伦明悦-CSDN博客 ✍🏻 作者简介: C软件开发、Python机器学习爱好者 🗣️ 互动与支持:💬评论 &…

Nginx 负载均衡实现上游服务健康检查

Nginx 负载均衡实现上游服务健康检查 Author:Arsen Date:2024/06/20 目录 Nginx 负载均衡实现上游服务健康检查 前言一、Nginx 部署并新增模块二、健康检查配置2.1 准备 nodeJS 应用程序2.2 Nginx 配置负载均衡健康检查 小结 前言 如果你使用云负载均衡…

深入理解适配器模式:Java实现与框架应用

适配器模式是一种结构型设计模式,它允许将一个类的接口转换成客户端希望的另一个接口。适配器模式使得原本由于接口不兼容而不能一起工作的类可以协同工作。在本篇博客中,我们将详细介绍适配器模式,并演示如何在Java中实现它。最后&#xff0…

python从入门到精通9:字符串简介

Python中的字符串是一种非常常见且重要的数据类型,用于存储一系列字符(如文本、数字、标点符号等)。Python的字符串处理功能强大且灵活,为开发者提供了丰富的操作方法和工具。下面我们将对Python字符串进行深入的解析。 1. 字符串…

对于大型 Clojure 项目,如何进行有效的代码组织和模块划分以提高可维护性?

在大型 Clojure 项目中,以下是一些有效的代码组织和模块划分的方法,可提高可维护性: 使用命名空间(namespace):将相关函数和数据结构组织到逻辑上相关的命名空间中,以便更好地理解和管理代码。按…