AIGC笔记--基于PEFT库使用LoRA

1--相关讲解

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

PEFT-LoRA

2--基本原理

        固定原始层,通过添加和训练两个低秩矩阵,达到微调模型的效果;

3--简单代码

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, LoraModel
from peft.utils import get_peft_model_state_dict# 创建模型
class Simple_Model(nn.Module):def __init__(self):super().__init__()self.linear1 = nn.Linear(64, 128)self.linear2 = nn.Linear(128, 256)def forward(self, x: torch.Tensor):x = self.linear1(x)x = self.linear2(x)return xif __name__ == "__main__":# 初始化原始模型origin_model = Simple_Model()# 配置lora configmodel_lora_config = LoraConfig(r = 32, lora_alpha = 32, # scaling = lora_alpha / r 一般来说,lora_alpha的参数初始化为与r相同,即scale=1init_lora_weights = "gaussian", # 参数初始化方式target_modules = ["linear1", "linear2"], # 对应层添加lora层lora_dropout = 0.1)# Test datainput_data = torch.rand(2, 64)origin_output = origin_model(input_data)# 原始模型的权重参数origin_state_dict = origin_model.state_dict() # 两种方式生成对应的lora模型,调用后会更改原始的模型new_model1 = get_peft_model(origin_model, model_lora_config)new_model2 = LoraModel(origin_model, model_lora_config, "default")output1 = new_model1(input_data)output2 = new_model2(input_data)# 初始化时,lora_B矩阵会初始化为全0,因此最初 y = WX + (alpha/r) * BA * X == WX# origin_output == output1 == output2# 获取lora权重参数,两者在key_name上会有区别new_model1_lora_state_dict = get_peft_model_state_dict(new_model1)new_model2_lora_state_dict = get_peft_model_state_dict(new_model2)# origin_state_dict['linear1.weight'].shape -> [output_dim, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_A.weight'].shape -> [r, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_B.weight'].shape -> [output_dim, r]print("All Done!")

4--权重保存和合并

核心公式是:new_weights = origin_weights + alpha* (BA)

    # 借助diffuser的save_lora_weights保存模型权重from diffusers import StableDiffusionPipelinesave_path = "./"global_step = 0StableDiffusionPipeline.save_lora_weights(save_directory = save_path,unet_lora_layers = new_model1_lora_state_dict,safe_serialization = True,weight_name = f"checkpoint-{global_step}.safetensors",)# 加载lora模型权重(参考Stable Diffusion),其实可以重写一个简单的版本from safetensors import safe_openalpha = 1. # 参数融合因子lora_path = "./" + f"checkpoint-{global_step}.safetensors"state_dict = {}with safe_open(lora_path, framework="pt", device="cpu") as f:for key in f.keys():state_dict[key] = f.get_tensor(key)all_lora_weights = []for idx,key in enumerate(state_dict):# only process lora down keyif "lora_B." in key: continueup_key    = key.replace(".lora_A.", ".lora_B.") # 通过lora_A直接获取lora_B的键名model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")layer_infos = model_key.split(".")[:-1]curr_layer = new_model1while len(layer_infos) > 0:temp_name = layer_infos.pop(0)curr_layer = curr_layer.__getattr__(temp_name)weight_down = state_dict[key].to(curr_layer.weight.data.device)weight_up   = state_dict[up_key].to(curr_layer.weight.data.device)# 将lora参数合并到原模型参数中 -> new_W = origin_W + alpha*(BA)curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)all_lora_weights.append([model_key, torch.mm(weight_up, weight_down).t()])print('Load Lora Done')

5--完整代码

PEFT_LoRA

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/843371.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AIGC 009-DaLLE2遇见达利!文生图过程中另外一种思路。

AIGC 009-DaLLE2遇见达利!文生图过程中另外一种思路。 0 论文工作 首先,遇见达利是我很喜欢的名字,达利是跟毕加索同等优秀的画家。这个名字就很有意思。 这篇论文提出了一种新颖的分层文本条件图像生成方法,该方法利用 CLIP&…

代码随想录算法训练营Day45 | 70. 爬楼梯 (进阶) 322. 零钱兑换 279.完全平方数

代码随想录算法训练营Day45 | 70. 爬楼梯 &#xff08;进阶&#xff09; 322. 零钱兑换 279.完全平方数 LeetCode 70. 爬楼梯&#xff08;进阶&#xff09; 题目链接&#xff1a;LeetCode 70. 爬楼梯&#xff08;进阶&#xff09; 思路&#xff1a; #include <iostream&…

RustGUI学习(iced/iced_aw)之扩展小部件(二十六):如何是drop_down部件来构建下拉菜单?

前言 本专栏是学习Rust的GUI库iced的合集,将介绍iced涉及的各个小部件分别介绍,最后会汇总为一个总的程序。 iced是RustGUI中比较强大的一个,目前处于发展中(即版本可能会改变),本专栏基于版本0.12.1. 概述 这是本专栏的第二十六篇,主要讲述drop_down部件的使用,会结合…

DolphinScheduler 3.3.0版本更新一览

Apache DolphinScheduler即将迎来3.3.0版本的发布&#xff0c;届时将有一系列重要的更新和改进。在近期的社区5月份用户线上分享会上&#xff0c;项目PMC 阮文俊为大家介绍了3.3.0版本将带来的主要更新和改进&#xff0c;并为大家指出了如何参与社区的方式。 什么是DolphinSch…

四川古力未来科技抖音小店安全靠谱,购物新体验

在数字化浪潮席卷而来的今天&#xff0c;电商行业蓬勃发展&#xff0c;各种线上购物平台如雨后春笋般涌现。其中&#xff0c;抖音小店凭借其独特的短视频直播购物模式&#xff0c;迅速赢得了广大消费者的青睐。而四川古力未来科技抖音小店&#xff0c;更是以其安全靠谱、品质保…

LeetCode hot100-48-G

437. 路径总和 III给定一个二叉树的根节点 root &#xff0c;和一个整数 targetSum &#xff0c;求该二叉树里节点值之和等于 targetSum 的 路径 的数目。路径 不需要从根节点开始&#xff0c;也不需要在叶子节点结束&#xff0c;但是路径方向必须是向下的&#xff08;只能从父…

ARM鲲鹏920-oe2309-caffe

参考链接:Caffe | Installation 安装依赖包 dnf install dnf update dnf install leveldb-devel snappy-devel opencv.aarch64 boost-devel hdf5-devel gflags-devel glog-devel lmdb-devel openblas.aarch64 dnf install git wget tar gcc-g unzip automake libtool autoco…

阻止el-popover的冒泡事件

在 Vue.js 中使用 Element UI 或 Element Plus 组件库时&#xff0c;如果你想要阻止 el-popover 的冒泡事件&#xff0c;你可以在触发该事件的处理函数中调用 event.stopPropagation() 方法。这个方法会阻止事件进一步向上冒泡到 DOM 树中的父元素。 以下是一个如何在 el-popo…

网工内推 | 高校、外企网工,IE认证优先,年薪最高18w

01 上海外国语大学贤达经济人文学院 &#x1f537;招聘岗位&#xff1a;高校网络主管 &#x1f537;职责描述&#xff1a; 1、负责总机房、网络规划及管理&#xff0c;包括容量规划、成本评估、建设管理等; 2、负责设计、实施及维护全网络架构及规划网络变更计划 3、负责网络功…

VMware ESXi 兼容性查询

官网兼容性查询地址&#xff1a;https://www.vmware.com/resources/compatibility/search.php

优选免单:重塑电商销售模式的新策略

随着电商行业的不断发展&#xff0c;一种名为“优选免单”的新兴销售模式正逐渐崭露头角。该模式以独特的价格策略、创新的奖励机制和巧妙的社交网络应用为核心&#xff0c;成功激发了消费者的购买热情&#xff0c;并实现了销售的高速增长。 一、规范运营&#xff0c;避免潜在风…

OpenHarmony鸿蒙软总线使用mbedtls数据加密详解

OpenHarmony鸿蒙软总线子系统中使用了多种的加密技术,本篇介绍调用mbedtls的数据加密。 调用mbedtls加密的源码位于: foundation/communication/dsoftbus/adapter/common/mbedtls/softbus_adapter_crypto.c 这个源码单元,调用mbedTLS库实现了各种加密功能,包括AES-GCM加密…

【MongoDB】配置Secondary(从节点) 的 Sync Target(复制源)

一 概述 从节点 从 主节点捕获数据以保持副本集数据的最新副本。然而&#xff0c;默认情况下&#xff0c;从节点可能会根据成员之间的ping时间变化和其他成员的复制状态自动更改其同步目标。请参阅“副本集数据同步”和“管理链式复制”以获取更多信息。 对于某些部署&#x…

【网络协议】应用层协议HTTPS

文章目录 为什么引入HTTPS&#xff1f;基本概念加密的基本过程对称加密非对称加密中间人攻击证书 为什么引入HTTPS&#xff1f; 由于HTTP协议在网络传输中是明文传输的&#xff0c;那么当传输一些机密的文件或着对钱的操作时&#xff0c;就会有泄密的风险&#xff0c;从而引入…

【进程空间】通过页表寻址的过程

文章目录 前言介绍页表、页框、页目录的概念页框页表页目录页表和页目录的分配 一级页表和二级页表一级页表寻址过程 二级页表寻址过程 一级页表和二级页表的对比 前言 我们知道每个进程都有属于自己的虚拟地址空间&#xff0c;且每个进程的虚拟地址都是统一的。要想通过虚拟地…

数据结构(七)递归、快速排序

文章目录 一、递归&#xff08;一&#xff09;使用递归实现1~n求和1. 代码实现&#xff1a;2. 调用过程&#xff1a;3. 输出结果&#xff1a; &#xff08;二&#xff09;青蛙跳台阶问题1. 问题分析2. 代码实现3. 输出结果4. 代码效率优化5. 优化后的输出结果 二、快速排序&…

Euler 欧拉系统介绍

Euler 欧拉系统介绍 1 简介重要节点与版本EulerOS 特色EulerOS 与 openEuler 区别联系Euler 与 HarmonyOS 区别联系 2 openEuler特色支持 ARM&#xff0c;x86&#xff0c;RISC-V 等全部主流通用计算架构融入 AI 生态嵌入式实时能力提升引入 OpenHarmony 一些突出功能 参考 1 简…

将 KNX 接入 Home Assistant 之二 准备软件

写在前面&#xff1a; 在KNX官网也有关于 Home Assistant 的教程&#xff0c;地址是 Get started with Home Assistant x KNX 需要的东西是 a KNX IP Interface or Routera Raspberry Pian SD Card at least 32 GB 安装 Home Assistant 系统 下载镜像&#xff1a; 地址&…

EfficientNet结构的特点

EfficientNet是一种高效的卷积神经网络架构&#xff0c;它通过系统化的方法来提升模型的性能和效率。由Google AI提出&#xff0c;EfficientNet的设计理念是通过网络的复合缩放&#xff08;compound scaling&#xff09;来均衡地扩展网络的深度&#xff08;depth&#xff09;、…

idea中git检出失败

之前clone好好的&#xff0c;今天突然就拉取不下来了。很多时候是用户凭证的信息没更新的问题。由于window对同一个地址都存储了会话。如果是新的会话&#xff0c;必须要更新window下的凭证。 然后根据你的仓库找到你对应的账户&#xff0c;更新信息即可。