AIGC笔记--基于PEFT库使用LoRA

1--相关讲解

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

PEFT-LoRA

2--基本原理

        固定原始层,通过添加和训练两个低秩矩阵,达到微调模型的效果;

3--简单代码

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, LoraModel
from peft.utils import get_peft_model_state_dict# 创建模型
class Simple_Model(nn.Module):def __init__(self):super().__init__()self.linear1 = nn.Linear(64, 128)self.linear2 = nn.Linear(128, 256)def forward(self, x: torch.Tensor):x = self.linear1(x)x = self.linear2(x)return xif __name__ == "__main__":# 初始化原始模型origin_model = Simple_Model()# 配置lora configmodel_lora_config = LoraConfig(r = 32, lora_alpha = 32, # scaling = lora_alpha / r 一般来说,lora_alpha的参数初始化为与r相同,即scale=1init_lora_weights = "gaussian", # 参数初始化方式target_modules = ["linear1", "linear2"], # 对应层添加lora层lora_dropout = 0.1)# Test datainput_data = torch.rand(2, 64)origin_output = origin_model(input_data)# 原始模型的权重参数origin_state_dict = origin_model.state_dict() # 两种方式生成对应的lora模型,调用后会更改原始的模型new_model1 = get_peft_model(origin_model, model_lora_config)new_model2 = LoraModel(origin_model, model_lora_config, "default")output1 = new_model1(input_data)output2 = new_model2(input_data)# 初始化时,lora_B矩阵会初始化为全0,因此最初 y = WX + (alpha/r) * BA * X == WX# origin_output == output1 == output2# 获取lora权重参数,两者在key_name上会有区别new_model1_lora_state_dict = get_peft_model_state_dict(new_model1)new_model2_lora_state_dict = get_peft_model_state_dict(new_model2)# origin_state_dict['linear1.weight'].shape -> [output_dim, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_A.weight'].shape -> [r, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_B.weight'].shape -> [output_dim, r]print("All Done!")

4--权重保存和合并

核心公式是:new_weights = origin_weights + alpha* (BA)

    # 借助diffuser的save_lora_weights保存模型权重from diffusers import StableDiffusionPipelinesave_path = "./"global_step = 0StableDiffusionPipeline.save_lora_weights(save_directory = save_path,unet_lora_layers = new_model1_lora_state_dict,safe_serialization = True,weight_name = f"checkpoint-{global_step}.safetensors",)# 加载lora模型权重(参考Stable Diffusion),其实可以重写一个简单的版本from safetensors import safe_openalpha = 1. # 参数融合因子lora_path = "./" + f"checkpoint-{global_step}.safetensors"state_dict = {}with safe_open(lora_path, framework="pt", device="cpu") as f:for key in f.keys():state_dict[key] = f.get_tensor(key)all_lora_weights = []for idx,key in enumerate(state_dict):# only process lora down keyif "lora_B." in key: continueup_key    = key.replace(".lora_A.", ".lora_B.") # 通过lora_A直接获取lora_B的键名model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")layer_infos = model_key.split(".")[:-1]curr_layer = new_model1while len(layer_infos) > 0:temp_name = layer_infos.pop(0)curr_layer = curr_layer.__getattr__(temp_name)weight_down = state_dict[key].to(curr_layer.weight.data.device)weight_up   = state_dict[up_key].to(curr_layer.weight.data.device)# 将lora参数合并到原模型参数中 -> new_W = origin_W + alpha*(BA)curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)all_lora_weights.append([model_key, torch.mm(weight_up, weight_down).t()])print('Load Lora Done')

5--完整代码

PEFT_LoRA

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/843371.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AIGC 009-DaLLE2遇见达利!文生图过程中另外一种思路。

AIGC 009-DaLLE2遇见达利!文生图过程中另外一种思路。 0 论文工作 首先,遇见达利是我很喜欢的名字,达利是跟毕加索同等优秀的画家。这个名字就很有意思。 这篇论文提出了一种新颖的分层文本条件图像生成方法,该方法利用 CLIP&…

DolphinScheduler 3.3.0版本更新一览

Apache DolphinScheduler即将迎来3.3.0版本的发布,届时将有一系列重要的更新和改进。在近期的社区5月份用户线上分享会上,项目PMC 阮文俊为大家介绍了3.3.0版本将带来的主要更新和改进,并为大家指出了如何参与社区的方式。 什么是DolphinSch…

四川古力未来科技抖音小店安全靠谱,购物新体验

在数字化浪潮席卷而来的今天,电商行业蓬勃发展,各种线上购物平台如雨后春笋般涌现。其中,抖音小店凭借其独特的短视频直播购物模式,迅速赢得了广大消费者的青睐。而四川古力未来科技抖音小店,更是以其安全靠谱、品质保…

ARM鲲鹏920-oe2309-caffe

参考链接:Caffe | Installation 安装依赖包 dnf install dnf update dnf install leveldb-devel snappy-devel opencv.aarch64 boost-devel hdf5-devel gflags-devel glog-devel lmdb-devel openblas.aarch64 dnf install git wget tar gcc-g unzip automake libtool autoco…

网工内推 | 高校、外企网工,IE认证优先,年薪最高18w

01 上海外国语大学贤达经济人文学院 🔷招聘岗位:高校网络主管 🔷职责描述: 1、负责总机房、网络规划及管理,包括容量规划、成本评估、建设管理等; 2、负责设计、实施及维护全网络架构及规划网络变更计划 3、负责网络功…

VMware ESXi 兼容性查询

官网兼容性查询地址:https://www.vmware.com/resources/compatibility/search.php

优选免单:重塑电商销售模式的新策略

随着电商行业的不断发展,一种名为“优选免单”的新兴销售模式正逐渐崭露头角。该模式以独特的价格策略、创新的奖励机制和巧妙的社交网络应用为核心,成功激发了消费者的购买热情,并实现了销售的高速增长。 一、规范运营,避免潜在风…

【网络协议】应用层协议HTTPS

文章目录 为什么引入HTTPS?基本概念加密的基本过程对称加密非对称加密中间人攻击证书 为什么引入HTTPS? 由于HTTP协议在网络传输中是明文传输的,那么当传输一些机密的文件或着对钱的操作时,就会有泄密的风险,从而引入…

【进程空间】通过页表寻址的过程

文章目录 前言介绍页表、页框、页目录的概念页框页表页目录页表和页目录的分配 一级页表和二级页表一级页表寻址过程 二级页表寻址过程 一级页表和二级页表的对比 前言 我们知道每个进程都有属于自己的虚拟地址空间,且每个进程的虚拟地址都是统一的。要想通过虚拟地…

数据结构(七)递归、快速排序

文章目录 一、递归(一)使用递归实现1~n求和1. 代码实现:2. 调用过程:3. 输出结果: (二)青蛙跳台阶问题1. 问题分析2. 代码实现3. 输出结果4. 代码效率优化5. 优化后的输出结果 二、快速排序&…

Euler 欧拉系统介绍

Euler 欧拉系统介绍 1 简介重要节点与版本EulerOS 特色EulerOS 与 openEuler 区别联系Euler 与 HarmonyOS 区别联系 2 openEuler特色支持 ARM,x86,RISC-V 等全部主流通用计算架构融入 AI 生态嵌入式实时能力提升引入 OpenHarmony 一些突出功能 参考 1 简…

将 KNX 接入 Home Assistant 之二 准备软件

写在前面: 在KNX官网也有关于 Home Assistant 的教程,地址是 Get started with Home Assistant x KNX 需要的东西是 a KNX IP Interface or Routera Raspberry Pian SD Card at least 32 GB 安装 Home Assistant 系统 下载镜像: 地址&…

idea中git检出失败

之前clone好好的,今天突然就拉取不下来了。很多时候是用户凭证的信息没更新的问题。由于window对同一个地址都存储了会话。如果是新的会话,必须要更新window下的凭证。 然后根据你的仓库找到你对应的账户,更新信息即可。

反射器和联邦实验

拓扑 要求 IP配置 [R1-GigabitEthernet0/0/0]ip add 12.0.0.1 24 [R1-LoopBack0]ip add 192.168.1.1 24 [R1-LoopBack1]ip add 10.0.0.1 24[R2-GigabitEthernet0/0/0]ip add 12.0.0.2 24 [R2-GigabitEthernet0/0/1]ip add 172.16.0.1 21 [R2-GigabitEthernet0/0/2]ip add 17…

(IDEA修改Java版本)java: 警告: 源发行版 X 需要目标发行版 X

搜索关键词:一致、发行 错误信息 其他错误: java: 错误: 不支持发行版本 6 java: -source 1.5 中不支持 lambda 表达式 (请使用 -source 8 或更高版本以启用 lambda 表达式) 思路 有两个地方要检查,JDK版本保持一致即可。 比如统一用JDK8或…

FM1800隧道广播插播控制器

隧道广播插播控制器是一款群载波&应急广播插播控制器采用SDR软件无线电技术,产生独立的插播信号与“群载波”信号,本设备可通过软件无线电技术将音频信号调制成调频载波或“群载波”信号,分别送入插播主机,实现隧道广播远端机…

20240528解决飞凌的OK3588-C的核心板可以刷机不能连接ADB的问题

20240528解决飞凌的OK3588-C的核心板可以刷机不能连接ADB的问题 2024/5/28 16:34 OS:Linux R4/Buildroot 硬件接了3条线出来,一直可以刷机,但是链接ADB异常。 【总是链接不上】 Z:\OK3588_Linux_fs\kernel\arch\arm64\boot\dts\rockchip\OK3…

Android11 事件分发流程

在Android 11 输入系统之InputDispatcher和应用窗口建立联系一文中介绍到,当InputDispatcher写入数据后,客户端这边就会调用handleEvent方法接收数据 //frameworks\base\core\jni\android_view_InputEventReceiver.cpp int NativeInputEventReceiver::h…

炒黄金怎么追单?-融知财经网

在黄金投资领域,当市场行情呈现出有利的走势时,许多交易者会选择追加下单以扩大盈利。追单作为一种投资策略,旨在利用市场波动获取额外收益。然而,要想在追单中取得成功,需要掌握一定的技巧和策略。融知财经网给介绍黄金交易中追单的一些关键技巧,帮助投资者理智追单,稳健获利。…

线性插值的频域特性

1、抽取和插值的简单说明 抽取和插值是变采样过程中常用的两种手段,其中抽取的目的是降低数据的采样率,以降低对系统存储深度或计算量的要求。插值的目的是提高数据的采样率,以提高系统的计算精度。 M M M倍抽取通常是通过每隔 M M M…