图像生成:SD lora加载代码详解与实现

文章目录

  • 前言
  • 一、SD模型介绍
  • 二、模型加载
    • 1. 模型架构加载
    • 2. safetensors权重加载
    • 3. lora权重加载
  • 三、Name匹配
  • 四、权重融合
    • 1、构建net类
    • 2、匹配lora weight和model weight
    • 3、基于lora权重创建lora模块
    • 4、权重融合
  • 五、整体pipeline
  • 总结


前言

SD中lora的加载相信都不陌生,但是大家大多数都是利用SD webUI加载lora,本文主要梳理一下SD webUI中lora加载的代码逻辑。关于lora的原理,可以参考我之前的博客——图像生成:SD LoRA模型详解


一、SD模型介绍

SD model结构一般分为几个部分,如下:
在这里插入图片描述

SD webui使用pytorch lightning搭建,了解pl的同学可能知道,模型的相关配置一般都写在yaml文件中,因此其实可以根据yaml文件来判断模型的基本结构,类似如下:

      unet_config:target: ldm.modules.diffusionmodules.openaimodel.UNetModelparams:image_size: 32 # unusedin_channels: 4out_channels: 4model_channels: 320attention_resolutions: [ 4, 2, 1 ]num_res_blocks: 2channel_mult: [ 1, 2, 4, 4 ]num_heads: 8use_spatial_transformer: Truetransformer_depth: 1context_dim: 768use_checkpoint: Truelegacy: Falsefirst_stage_config:target: ldm.models.autoencoder.AutoencoderKLparams:embed_dim: 4monitor: val/rec_lossddconfig:double_z: truez_channels: 4resolution: 256in_channels: 3out_ch: 3ch: 128ch_mult:- 1- 2- 4- 4num_res_blocks: 2attn_resolutions: []dropout: 0.0lossconfig:target: torch.nn.Identitycond_stage_config:target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

二、模型加载

因为整个模型分为VAE,CLIP,Unet等多个部分,lora模型一般是对这几个部分进行权重修改,有的可能只修改了Unet,有的可能多个部分都修改了,这个是由当时lora的训练师冻结的模块决定的。本文接下来主要以Unet部分的加载和lora修改为例来介绍,其它模块类似。

1. 模型架构加载

首先根据yaml文件加载模型架构(这里我只保留了Unet的配置文件,然后进行加载)

unet_config_path = '/data/wangyx/工程处理/sd/unet.yaml'
diff_model_config = OmegaConf.load(unet_config_path)  
unet_config = diff_model_config.model.unet_config
diffusion_model = instantiate_from_config(unet_config)  

2. safetensors权重加载

这里以比较火的chilloutmix作为example 。( ⁼̴̀ .̫ ⁼̴ )✧

#可视化权重结构
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
model_path = '/stable-diffusion-webui/stable-diffusion-webui/models/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors'
tensors = {}
with safe_open(model_path, framework="pt", device='cpu') as f:for k in f.keys():tensors[k] = f.get_tensor(k)

获取权重后即可把对应权重文件加载到模型结构中。

3. lora权重加载

lora的权重加载直接用safetensor加载即可

lora_path = '**/sdxl_lcm_lora.safetensors'
pl_sd = safetensors.torch.load_file(lora_path) 

三、Name匹配

打印出lora权重名字和模型每一层的名字后其实可以发现,其实都是不对应的,因此需要手动将他们匹配起来,在SD webUI中的Lora中部分使用下面这样一个函数完成权重匹配

def convert_diffusers_name_to_compvis(key, is_sd2):def match(match_list, regex_text):regex = re_compiled.get(regex_text)if regex is None:regex = re.compile(regex_text)re_compiled[regex_text] = regexr = re.match(regex, key)if not r:return Falsematch_list.clear()match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])return Truem = []if match(m, r"lora_unet_conv_in(.*)"):return f'diffusion_model_input_blocks_0_0{m[0]}'if match(m, r"lora_unet_conv_out(.*)"):return f'diffusion_model_out_2{m[0]}'if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):if is_sd2:if 'mlp_fc1' in m[1]:return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"elif 'mlp_fc2' in m[1]:return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"else:return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):if 'mlp_fc1' in m[1]:return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"elif 'mlp_fc2' in m[1]:return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"else:return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"return keykey1 = 'lora_unet_down_blocks_0_downsamplers_0_conv'  #.alpha
new_key1 = convert_diffusers_name_to_compvis(key1, is_sd2=True)
print(new_key1)

通过这样一个函数即可把lora权重的名字替换成和模型名字一样的。


另外同时对SD原模型的层名字进行修改,并存在一个字典中:

def assign_network_names_to_compvis_modules(sd_model):network_layer_mapping = {}for name, module in sd_model.named_modules():network_name = name.replace(".", "_")network_layer_mapping[network_name] = modulemodule.network_layer_name = network_namesd_model.network_layer_mapping = network_layer_mapping

这样子lora和模型的名字就完成对应了。

四、权重融合

完成权重匹配后,就可以进行权重融合了,这里我将SD wenUI中的代码摘了一部分出来进行实现,从而更好理解原理。

1、构建net类

这里我调用了webUI中的net类,用于后续赋予相关属性

network_on_disk = NetworkOnDisk('name', '.pth')
net = Network('name', network_on_disk)
#这部分是建立一个空壳,用于后续操作

2、匹配lora weight和model weight

#创建nametuple保存权重
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
matched_networks = {}for key_network, weight in pl_sd.items():  #循环lora的每一项key_network_without_network_parts, network_part = key_network.split(".", 1)#如果是SDXL,那么isSD2需要选择为Truefkey = convert_diffusers_name_to_compvis(key_network_without_network_parts, True) key = fkey[16:]'''正常模型架构是model.diffusion_model, model.first_stage_model但是现在我们只加载了unet部分所以先去掉前半部分'''sd_module = diffusion_model.network_layer_mapping.get(key, None)#获取修改名字后对应的moduleif key not in matched_networks and sd_module is not None:  matched_networks[key] = NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)if sd_module is not None:matched_networks[key].w[network_part] = weight

通过上述代码就可以将匹配的Unet层权重和lora权重放在一个NetworkWeights组了,用于后续融合。代码中w存的是lora的权重,sd module存的是对应unet中的结构和权重。

3、基于lora权重创建lora模块

上面构建了matched_networks字典,每个key下对应了一组匹配好的lora权重和模型模块,接下来就是基于lora权重创建lora模块了,我们第一步创建了一个net空壳类,在这里赋予net.modules属性,并将创建好的lora模块赋予该属性。

for key, weights in matched_networks.items():print(key)print(weights)net_module = Nonefor nettype in module_types:net_module = nettype.create_module(net, weights)if net_module is not None:breaknet.modules[key] = net_module

这里的create_module来自于源码中的Lora/network_lora.py, 该函数主要是创建结构,并赋予权重,最后构建一个完整的lora模块。

4、权重融合

接下来就可以完成权重融合了,这里我以某一层为例

network_layer_name = 'output_blocks_11_1_transformer_blocks_0_attn1_to_k'
module = net.modules.get(network_layer_name, None)
print('get模块', module)
fb = matched_networks[network_layer_name].sd_module  
fb_weight = fb.weight
with torch.no_grad():updown, ex_bias = module.calc_updown(fb_weight) if len(fb_weight) == 4 and fb_weight.shape[1] == 9:# inpainting model. zero pad updown to make channel[1]  4 to 9updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))fb_weight += updown

updown计算好的lora权重,fb_weight是原模块的权重,相加即可完成融合。
若要实现所有的权重融合,循环matched_networks中每一个key,然后执行上述操作,最后进行权重替换即可。

五、整体pipeline

最后来整体梳理一下:
1)加载sd模型结构,加载权重
2)加载lora模型
3)进行name匹配,找到互相对应的层
4)将互相对应的sd_module和lora权重放在一组
5)基于lora权重创建lora模块
6)完成lora的计算并融合

总结

整体lora融合的流程其实并不复杂,主要在于匹配和计算融合,SD webUI可能由于集成度较高所以看起来代码稍显复杂,感兴趣其实可以照着这个流程进一步简化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/787624.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VSCode美化

今天有空收拾了一下VSCode,页面如下,个人觉得还是挺好看的~~ 1. 主题 Noctis 色彩较多,有种繁杂美。 我使用的是浅色主题的一款Noctis Hibernus 2. 字体 Maple Mono 官网:Maple-Font 我只安装了下图两个字体,使…

堆积排序python实现

堆积排序(Heap Sort)是一种基于二叉堆的排序算法。Python语言实现堆积排序算法可以如下所示: def heapify(arr, n, i):largest = i # Initialize largest as rootl = 2 * i + 1 # left = 2*i + 1r = 2 * i + 2

Java常用API之Encoders类解读

写在开头:本文用于作者学习Java常用API 我将官方文档中Collections类中所有API全测了一遍并打印了结果,日拱一卒,常看常新 在Spark中,Encoders类提供了一些静态方法用于创建不同数据类型的编码器。 首先,我遇到这样…

区块链面试题总结

1、什么是区块链? 区块链是不间断的经济交易数字分类账,可以进行编程,以记录不止金融交易,还可以记录其他有价值的东西,简单来说,区块链是一个不可变记录的分布式数据库,该数据库由计算机集群来…

今后的推进计划方针

今后的推进计划方针 信息数学物理 信息 线段树,其它的随缘。 数学 三角函数(必修3)-> 对数函数和指数函数 物理 随缘

Docker部署Nexus Maven私服并且实现远程访问Nexus界面

目录 ⛳️推荐 1. Docker安装Nexus 2. 本地访问Nexus 3. Linux安装Cpolar 4. 配置Nexus界面公网地址 5. 远程访问 Nexus界面 6. 固定Nexus公网地址 7. 固定地址访问Nexus ⛳️推荐 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默&am…

使用混沌加密图像(MATLAB)

logistic图像加密是一种基于混沌理论的加密算法,它利用混沌系统的特性,如敏感依赖于初始条件和参数的非线性动态行为,来生成密钥和进行加密解密操作。以下是逻辑加密的基本原理和过程: 基本原理: 混沌系统特性利用:逻辑加密基于混沌系统的特性,利用混沌系统的非线性、不…

Vue3:使用Pinia的$subscribe+localStorage实现数据存储

一、情景说明 我们学习Vue的时候,知道可以用watch来监听数据的变化 那么,Pinia的store中的数据发生变化,如何监听了? 这里就用到$subscribe来实现效果 这一篇,$subscribelocalStorage实现数据存储的案例&#xff0c…

Nginx开发实战三:替换请求资源中的固定数据

文章目录 1.效果预览2.下载Nginx解压并初始化3.字符串替换模块安装4.修改nginx配置文件并重启 1.效果预览 页面初始效果 页面替换后效果 说明:页面是内网的一个地址,我们通过nginx可以很便捷的将其改为外网访问,但是在外网访问这个地址后&#xff0c…

windows 使用 wsl 安装 linux 子系统

windows 使用 wsl 安装 linux 子系统 介绍使用如何启动设置基本配置安装和配置 SSH 服务 介绍 WSL(Windows Subsystem for Linux)是微软为Windows 10和Windows 11操作系统提供的一个功能,它允许用户在Windows上直接运行GNU/Linux环境。WSL提…

剑指offer--数组中重复的数字

一.题目描述 在一个长度为 n 的数组 nums 里的所有数字都在 0~n-1 的范围内。数组中某些数字是重复的,但不知道有几个数字重复了,也不知道每个数字重复了几次。请找出数组中任意一个重复的数字。 算法1.排序,然后遍历,时间复杂度O(nlogn),空…

私域流量黄金时代:SCRM赋能企业深度运营

随着互联网的不断发展,私域流量已成为企业运营中不可或缺的一部分。私域流量不仅代表着企业的忠实用户群体,更是企业进行精细化运营、提升品牌价值的重要阵地。在这样的背景下,社会化客户关系管理(SCRM)应运而生&#…

SpringBoot + Redis 实现接口限流,一个注解的事

Redis 除了做缓存,还能干很多很多事情:分布式锁、限流、处理请求接口幂等性。。。太多太多了~今天想和小伙伴们聊聊用 Redis 处理接口限流。 1. 准备工作 首先我们创建一个 Spring Boot 工程,引入 Web 和 Redis 依赖,同时考虑到接口限流一般是通过注解来标记,而注解是通…

【力扣一刷】代码随想录day28(93.复原IP地址、78.子集、90.子集II )

目录 【93.复原IP地址】中等题(偏难,坑很多) 【78.子集】中等题(偏简单) 【90.子集II】中等题 【93.复原IP地址】中等题(偏难,坑很多) 思路:以101023为例子 1、将题目…

输出100~200之间的素数(C语言)

一、运行结果&#xff1b; 二、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h>//实现素数判断函数&#xff1b; int Prime(int number) {//初始化变量值&#xff1b;int divided 2;int JudgementCondition 0;//循环判断素数&#xff1b;wh…

(C)1007 素数对猜想

1007 素数对猜想 问题描述 输入样例&#xff1a; 20 输出样例&#xff1a; 4 解决方案&#xff1a; #include<stdio.h> #include<string.h> #include<math.h> int main(){int n,d;int a[100000];int flag,jishu0;scanf("%d",&n);memset(a,-1,…

基于51单片机甲醛浓度检测设计

基于51单片机甲醛浓度检测设计 &#xff08;仿真&#xff0b;程序&#xff0b;原理图&#xff0b;PCB&#xff0b;设计报告&#xff09; 功能介绍 具体功能&#xff1a; 1.甲醛浓度数据经过单片机处理&#xff0c;由LCD1602实时显示。 2.可通过按键设置甲醛报警阈值&#xff…

RK3568驱动指南|第十四篇 单总线-第158章DS18B20编写字符设备驱动框架

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…

鸿蒙原生应用开发-网络管理HTTP数据请求

一、场景介绍 应用通过HTTP发起一个数据请求&#xff0c;支持常见的GET、POST、OPTIONS、HEAD、PUT、DELETE、TRACE、CONNECT方法。 二、接口说明 HTTP数据请求功能主要由http模块提供。 使用该功能需要申请ohos.permission.INTERNET权限。 涉及的接口如下表&#xff0c;具体的…

【pytest】fixture机制

目录 概念fixture 的主要特点测试场景1. 准备和清理测试数据2. 模拟外部依赖3. 共享资源&#xff08;如数据库连接&#xff09;4. 使用内置 fixture5. 自动使用 fixture 用途 概念 fixture机制是pytest测试框架中的一个核心概念&#xff0c;它提供了一种用于处理测试所需资源的…