从零学习大模型(七)-----LoRA(中)

自注意力层中的 LoRA 应用

Transformer 的自注意力机制是模型理解输入序列之间复杂关系的核心部分。自注意力层通常包含多个线性变换,包括键(Key)查询(Query)值(Value) 三个权重矩阵的线性映射,这些权重矩阵需要被训练来适应不同的任务。

LoRA 在自注意力层中的步骤:

冻结原始权重:将自注意力层中的键、查询、值矩阵的权重参数冻结,不进行训练。这些原始权重是通过预训练获得的,用于保留模型在大规模数据上学习到的通用知识。

添加低秩适配矩阵:对于键、查询、值矩阵的权重分别添加低秩适配矩阵 A A A B B B。例如,如果键矩阵 W k ∈ R d k × d model W_k \in \mathbb{R}^{d_k \times d_{\text{model}}} WkRdk×dmodel,则引入两个适配矩阵:
A k ∈ R d k × r , B k ∈ R r × d model A_k \in \mathbb{R}^{d_k \times r}, \quad B_k \in \mathbb{R}^{r \times d_{\text{model}}} AkRdk×r,BkRr×dmodel
其中 r r r 是低秩值,通常比 d k d_k dk d model d_{\text{model}} dmodel 小很多。

权重变换:在执行前向传播时,键、查询、值矩阵的线性变换将使用 W k ′ = W k + A k B k W_k' = W_k + A_k B_k Wk=Wk+AkBk,作为新的权重矩阵,其中 W k W_k Wk 是冻结的原始权重, A k B k A_k B_k AkBk 是任务特定的微调部分。

训练适配矩阵:在特定任务的微调过程中,仅更新新增的适配矩阵 A k A_k Ak B k B_k Bk 的参数,而不更新原始的键矩阵权重。这使得微调过程变得高效和轻量化。

前馈网络中的 LoRA 应用

Transformer 的前馈网络(FFN) 通常由两个线性层和一个非线性激活函数(如 ReLU)组成。在前馈网络中,参数数量非常庞大,因为它在每个位置上独立地对输入进行两次线性变换。

LoRA 在前馈网络中的步骤:

冻结全连接层权重:前馈网络中的两个全连接层的权重也被冻结。这些层用于对每个输入位置进行独立的变换,其预训练参数保持不变,用于保留预训练期间获得的通用知识。

添加适配矩阵:对于每个全连接层,添加低秩适配矩阵。例如,假设前馈网络中的第一层权重矩阵为 W f f ∈ R d f f × d model W_{ff} \in \mathbb{R}^{d_{ff} \times d_{\text{model}}} WffRdff×dmodel,则 LoRA 在这个矩阵上引入两个适配矩阵: A f f ∈ R d f f × r , B f f ∈ R r × d model A_{ff} \in \mathbb{R}^{d_{ff} \times r}, \quad B_{ff} \in \mathbb{R}^{r \times d_{\text{model}}} AffRdff×r,BffRr×dmodel

权重变换与前向传播:在前向传播中,使用权重矩阵 W f f ′ = W f f + A f f B f f W_{ff}' = W_{ff} + A_{ff} B_{ff} Wff=Wff+AffBff来代替原始的全连接层权重矩阵。这样,前馈网络的线性变换就不仅包含原始的预训练权重,还包含适配矩阵的贡献,用于更好地适应特定任务。

训练适配矩阵:在训练过程中,仅更新适配矩阵 A f f A_{ff} Aff B f f B_{ff} Bff,而冻结的原始权重 W f f W_{ff} Wff 不变,这大大减少了训练中需要更新的参数数量。

LoRA 在 Transformer 中的整体优势

参数效率:在自注意力层和前馈网络中引入低秩适配矩阵,可以大大减少需要训练的参数数量,同时保留模型在大规模数据上学到的知识。

任务适应性:LoRA 的方法允许对每个特定任务引入单独的低秩适配矩阵,而保持预训练模型的核心不变,这使得同一个预训练模型可以高效地适应多种任务。

降低计算与存储开销:由于适配矩阵的秩 rrr 通常远小于模型的维度,LoRA 显著降低了训练和存储的开销,使得大模型在计算资源受限的情况下能够得到有效应用。

代码实现LoRA微调Bert的过程

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel# 假设我们有一个预训练好的 Transformer 模型,例如 BERT
pretrained_model = BertModel.from_pretrained('bert-base-uncased')# 定义 LoRA 适配层
class LoRAAdapter(nn.Module):def __init__(self, input_dim, low_rank_dim):super(LoRAAdapter, self).__init__()self.A = nn.Linear(input_dim, low_rank_dim, bias=False)  # 低秩矩阵 Aself.B = nn.Linear(low_rank_dim, input_dim, bias=False)  # 低秩矩阵 Bdef forward(self, x):return self.B(self.A(x))  # 输出 B(A(x))# 定义一个新的 BERT 模型,包含 LoRA 适配器
class LoRABertModel(nn.Module):def __init__(self, pretrained_model, low_rank_dim):super(LoRABertModel, self).__init__()self.bert = pretrained_modelself.low_rank_dim = low_rank_dim# 为 BERT 的每一层添加 LoRA 适配器for layer in self.bert.encoder.layer:layer.attention.self.query_adapter = LoRAAdapter(layer.attention.self.query.in_features, low_rank_dim)layer.attention.self.key_adapter = LoRAAdapter(layer.attention.self.key.in_features, low_rank_dim)layer.attention.self.value_adapter = LoRAAdapter(layer.attention.self.value.in_features, low_rank_dim)# 冻结原始 BERT 模型的所有参数for param in self.bert.parameters():param.requires_grad = False# 只训练 LoRA 适配层for layer in self.bert.encoder.layer:for param in layer.attention.self.query_adapter.parameters():param.requires_grad = Truefor param in layer.attention.self.key_adapter.parameters():param.requires_grad = Truefor param in layer.attention.self.value_adapter.parameters():param.requires_grad = Truedef forward(self, input_ids, attention_mask=None, token_type_ids=None):# 对每一层应用 LoRA 适配器for layer in self.bert.encoder.layer:# 使用适配后的 Query、Key、Value 进行注意力计算query = layer.attention.self.query(input_ids) + layer.attention.self.query_adapter(input_ids)key = layer.attention.self.key(input_ids) + layer.attention.self.key_adapter(input_ids)value = layer.attention.self.value(input_ids) + layer.attention.self.value_adapter(input_ids)# 注意力计算attn_output, _ = layer.attention.self.attn(query, key, value)input_ids = attn_outputreturn self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)# 创建包含 LoRA 适配器的 BERT 模型
low_rank_dim = 16
lora_bert = LoRABertModel(pretrained_model, low_rank_dim)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, lora_bert.parameters()), lr=1e-4)# 生成一些随机数据用于训练
data = torch.randint(0, 30522, (8, 64))  # 随机输入数据,批量大小为 8,序列长度为 64
labels = torch.randint(0, 2, (8, 64))  # 随机标签# 训练循环
num_epochs = 5
for epoch in range(num_epochs):# 前向传播outputs = lora_bert(data).last_hidden_statelogits = outputs.view(-1, outputs.size(-1))labels = labels.view(-1)# 计算损失loss = criterion(logits, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 打印损失值print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

利用LoRA对LLAMA2进行微调的代码

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset# 1. 加载预训练的 LLaMA 2 7B 模型
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)# 2. 手动添加 LoRA 适配器
class LoRAAdapter(nn.Module):def __init__(self, input_dim, low_rank_dim):super(LoRAAdapter, self).__init__()self.A = nn.Linear(input_dim, low_rank_dim, bias=False)  # 低秩矩阵 Aself.B = nn.Linear(low_rank_dim, input_dim, bias=False)  # 低秩矩阵 Bdef forward(self, x):return self.B(self.A(x))  # 输出 B(A(x))# 添加 LoRA 适配器到模型的注意力层
low_rank_dim = 16
for name, module in model.named_modules():if "attn" in name.lower() and hasattr(module, 'q_proj') and hasattr(module, 'k_proj') and hasattr(module, 'v_proj'):# 为 Query, Key, Value 添加 LoRA 适配器input_dim = module.q_proj.in_featuresmodule.q_proj_lora = LoRAAdapter(input_dim, low_rank_dim)module.k_proj_lora = LoRAAdapter(input_dim, low_rank_dim)module.v_proj_lora = LoRAAdapter(input_dim, low_rank_dim)# 使用 hooks 修改前向传播逻辑def hook_fn_forward_q(module, input, output):return output + module.q_proj_lora(input[0])def hook_fn_forward_k(module, input, output):return output + module.k_proj_lora(input[0])def hook_fn_forward_v(module, input, output):return output + module.v_proj_lora(input[0])module.q_proj.register_forward_hook(hook_fn_forward_q)module.k_proj.register_forward_hook(hook_fn_forward_k)module.v_proj.register_forward_hook(hook_fn_forward_v)# 3. 定义训练参数和优化器
training_args = TrainingArguments(output_dir="./llama2_7b_lora_finetuned",num_train_epochs=3,per_device_train_batch_size=4,gradient_accumulation_steps=8,learning_rate=1e-4,logging_dir="./logs",logging_steps=10,save_total_limit=2,fp16=True  # 采用混合精度加速训练
)optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)# 4. 加载训练数据集
train_dataset = load_dataset("daily_dialog", split="train")# 5. 定义 Trainer 对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,tokenizer=tokenizer,optimizers=(optimizer, None)
)# 6. 开始训练
trainer.train()# 7. 验证和保存模型
trainer.evaluate()
model.save_pretrained("./llama2_7b_lora_finetuned")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/56595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue开发

新建 Vue 项目 vue create project_name按照自己的需求模块进行安装 选择安装 Router、Vuex 插件 选择 Vue3 版本 是否使用 history 模式的路由,按需选 Y 或者 n 后面的选项都默认即可 是否记住上面的选择项:否 总体配置 启动项目 cd demo…

solr安装ik分词器

环境 系统 windows docker v4.34.3 solr:8.11.2 ik:ik-analyzer-solr7-7.x 1.安装步骤 1.1启动solr 1.1启动(为了方便编辑配置文件,挂载了文件目录) docker run -d -p 8983:8983 -v C:\docker\solr\classes:/opt/solr/server/solr-webap…

Windows解决localhost拒绝了连接请求

最近,在开发前端Vue项目时,Vue项目启动成功,没有任何报错,服务控制台已出现APP访问地址,如下图所示。 览器打开后页面先是空白,然后过了一会儿显示无法访问此网站,localhost拒绝了我们的连接请…

【前端】Next.js的安装及配置

Next.js介绍 Next.js 是一个流行的 React 框架,它具有以下优点: 服务器端渲染(SSR):Next.js 支持服务器端渲染,这意味着页面可以在服务器上预渲染,然后发送给用户,这可以加快首屏加…

关于写更新接口的一些理解

“更新”接口的思路 在上篇文章中,我们讲了如何编写删除接口。这篇文章将讲解如何编写更新接口。 其实,更新接口和新增接口非常相似。整体思路都是传入form参数,然后在service层将form转换成entity,最后调用updateById方法&…

idea删除git历史提交记录

前言:此文章是我在实际工作中有效解决问题的方法,做记录的同时也供大家参考! 一、 首先,通过idea的终端或系统的cmd控制台,进入到你的项目文件根目录,idea终端默认就是项目根目录。 二、确保你当前处于要删…

ECMAScript与JavaScript的区别:深入解析与代码示例

目录 引言 ECMAScript与JavaScript的定义 ECMAScript JavaScript ECMAScript与JavaScript的关系 区别详解 定义上的区别 功能上的区别 实现上的区别 代码示例 ECMAScript 6 (ES6) 特性示例 箭头函数 模板字面量 JavaScript 特有的扩展 在Web开发中的应用 ECMAS…

【数据结构与算法】之栈 vs 队列

栈和队列是计算机科学中最基础也是最常用的两种线性数据结构,它们提供了一种组织和管理数据的方式。它们的主要区别在于元素的添加和删除顺序。理解它们的特点、差异以及底层实现对于选择合适的结构解决特定问题至关重要。本文将更详细地比较栈和队列,并…

Java多线程详解①(全程干货!!!)

这里是Themberfue 今天,我们将正式进入多线程章节的讲解,希望我的讲解能够让你理解😎 进程 在进入多线程的讲解中,我们先引入进程的概念及其解释 操作系统都是大家耳熟能详的名词,常见的操作系统主要有:Li…

opencv - py_ml - py_knn k-最近邻算法

文章目录 1.理解 k-最近邻算法目标理论OpenCV 中的 kNN其他资源 2.使用 kNN 对手写数据进行 OCR目标手写数字的 OCR英文字母的 OCR其他资源 1.理解 k-最近邻算法 目标 在本章中,我们将理解 k-最近邻算法 (kNN) 的概念。 理论 kNN 是监督学习中最简单的分类算法之…

从0到1学习node.js(path模块以及HTTP协议)

文章目录 一、path模块二、HTTP协议1、常见状态码分类2、IP地址3、端口 一、path模块 // 引入path模块 const path require(path)// 拼接地址 const resolveData path.resolve(__dirname, ./index) console.log(__dirname, __dirname) console.log(resolveData, resolveData…

【js逆向专题】12.RPC技术

目录 一. websocket1. 什么是websocket2. websocket的原理3. websocket实现方式1. 客户端2.服务端3. 实际案例1. 案例目标2. 解析思路 二. RPC1. RPC 简介2.Sekiro-RPC1. 使用方法1. 执行方式2.客户端环境3.使用参数说明 2. 测试使用1. 前端代码2. SK API3.python调用代码 三.项…

C++,STL 042(24.10.21)

内容 一道练习题。 (涉及list,sort) 题目(大致) 将Person自定义类型进行排序(Person中属性有姓名、年龄、身高),按照年龄进行升序,如果年龄相同则按照身高进行降序。 …

openpnp - 解决“底部相机高级校正成功后, 开机归零时,吸嘴自动校验失败的问题“

文章目录 openpnp - 解决"底部相机高级校正成功后, 开机归零时,吸嘴自动校验失败的问题"概述笔记问题现象1问题现象2原因分析现在底部相机和吸嘴的位置偏差记录修正底部相机位置现在再看看NT1在底部相机中的位置开机归零,看看是否能通过所有校…

【分布式微服务云原生】《Redis 分布式锁的挑战与解决方案及 RedLock 的强大魅力》

《Redis 分布式锁的挑战与解决方案及 RedLock 的强大魅力》 摘要: 本文深入探讨了使用 Redis 做分布式锁时可能遇到的各种问题,并详细阐述了相应的解决方案。同时,深入剖析了 RedLock 作为分布式锁的原因及原理,包括其多节点部署…

HarmonyOS鸿蒙- 一行代码自动换行技巧

DevEco Studio 编辑器设置 一行代码自动换行显示。 一、代码自动换行设置方式路径:File > Editor > General 如图: 二、找到标题:Soft Wraps 勾选《Soft-wrap these files:》,然后在后面添加*.ets 然后保存即可。添加后&#xff0c…

【TIMM库】是一个专门为PyTorch用户设计的图像模型库 python库

TIMM库 1、引言:遇见TIMM2、初识TIMM:安装与基本结构3、实战案例一:图像分类4、实战案例二:迁移学习5、实战案例三:模型可视化6、结语:TIMM的无限可能 1、引言:遇见TIMM 大家好,我是…

LangSplat和3D language fields简略介绍

LangSplat: 3D Language Gaussian Splatting 相关技术拆分解释: 3dgs:伟大无需多言SAM:The Segment Anything Model,是图像分割领域的foundational model,已经用在很多视觉任务上(如图像修复、物体追踪、图…

支持国密算法的数字证书-国密SSL证书详解

在互联网中,数字证书作为标志通讯各方身份信息的数字认证而存在,常见的数字证书大都采用国际算法,比如RSA算法、ECC算法、SHA2算法等。随着我国加强网络安全技术自主可控的大趋势,也出现了支持国密算法的数字证书-国密SSL证书。那…

OpenCV高级图形用户界面(21)暂停程序执行并等待用户按键输入函数waitKey()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 等待按键 该函数 waitKey 在 delay≤0 时无限等待按键事件,或者在 delay 为正数时等待 delay 毫秒。由于操作系统在切换线程时有最小…