【原理+使用】DeepCache: Accelerating Diffusion Models for Free

论文:arxiv.org/pdf/2312.00858

代码:horseee/DeepCache: [CVPR 2024] DeepCache: Accelerating Diffusion Models for Free (github.com)

介绍


DeepCache是一种新颖的无训练且几乎无损的范式,从模型架构的角度加速了扩散模型。DeepCache利用 扩散模型顺序去噪步骤中观察到的固有时间冗余,缓存和检索相邻去噪阶段的特征,从而减少冗余计算。利用U-Net的特性,重用高级特征,同时以低成本的方式更新低级特征。将 Stable Diffusion v1.5 加速了 2.3 倍,CLIP 分数仅下降了 0.05 倍,LDM-4-G(ImageNet) 加速了 4.1 倍,FID 降低了 0.22。

动机:

由于顺序去噪过程和繁琐的模型尺寸,训练扩散模型会产生大量的计算成本。本文希望在没有额外训练的情况下,减少每个去噪步骤的计算开销,从而实现对扩散模型的无成本压缩。

背景:

反向扩散过程的加速。反向扩散过程的固有性质减慢了推理速度。目前的研究主要集中在两种加速扩散模型推理的方法上:

  1. 优化采样效率。侧重于减少采样步骤的数量。DDIM、一致性模型将随机噪声转换为初始图像,只需要进行一次模型评估。
  2. 优化结构效率。减少每个采样步骤的推理时间。

U-Net的高级和低级特征。由于跳跃式连接,UNet具有很强的合并低级和高级特征的能力。U-Net构建在堆叠的下采样和上采样块上,将输入图像编码为高级表示,然后对其进行解码,用于下游任务。表示为的块对,通过额外的跳过路径连接,直接将低级的信息从Di转发到Ui。在U-Net体系结构的前向传播过程中,数据通过两条路径并发地遍历:主分支和跳过分支。这些分支汇聚在一个连接模块,主分支提供处理过的高级特征,这些特征来自前面的上采样块Ui+1,而跳过分支提供来自对称块Di的相应特征。因此,U-Net模型的核心是来自跳过分支的低级特征和来自主分支的高级特征的连接:

原理

序列去噪中的特征冗余

去噪过程中的相邻步骤在高级特征上表现出显著的时间相似性。

图2实验揭示了两个主要观点:

  1. 在去噪过程中,相邻步骤之间,存在明显的时间特征相似性,表明连续步骤之间的变化通常较小。
  2. 无论使用哪种扩散模型,如稳定扩散、LDM和DDPM,对于每个时间步长,至少有10%的相邻时间步长与当前步长表现出高度相似(>0.95),这表明某些高级特征以渐进的速度变化。

每次计算,得到的特征都与前一步相似,存在大量冗余计算,产生边际效益。本文目标是利用这一特性来加速去噪过程。

扩散模型的深度缓存

DeepCache利用反向扩散过程中步骤之间的时间冗余来加速推理。从计算机系统中的缓存机制中获得灵感,结合了为随时间变化最小的元素设计的存储组件。应用于扩散模型,通过缓存那些变化缓慢的特征,来消除冗余计算,从而无需在后续步骤中重复计算

实现重点为U-Net中的跳过连接,它本质上提供了双路径优势:主分支需要大量的计算来遍行整个网络,而跳过分支只需要通过一些浅层,从而产生非常小的计算负载。主要分支中突出的特征相似性允许重用已经计算的结果,而不是为所有时间步重复计算。

去噪中的可缓存特性。

在两个连续时间步长 𝑡 和 𝑡−1 之间,根据反向过程,𝑥𝑡−1 将基于先前的结果 𝑥𝑡​ 进行条件生成。实验:首先生成 𝑥𝑡​,计算跨整个U-Net进行。为了获得下一个输出 𝑥𝑡−1,我们检索在先前时间步长 𝑡 中生成的高层次特征。即,考虑U-Net中的一个跳跃分支 𝑚,它连接 𝐷𝑚​ 和 𝑈𝑚​,在时间 𝑡 从先前的上采样块缓存特征图:

这是时间步长 𝑡 的主分支中的特征。这些缓存的特征将在后续推理中使用。

在下一个时间步长 𝑡−1 中,推理并不在整个网络上进行,只计算 m-th 跳跃分支中所需的部分,并用缓存中的特征替代主分支的计算。因此,时间步长 𝑡−1 中 𝑈𝑡−1𝑚​ 的输入可表示为:

𝐷𝑡−1𝑚​ 代表 m-th 下采样块的输出,如果选择一个较小的 𝑚,则只包含几层。例如,如果我们在第一层执行 DeepCache 并选择 𝑚=1,则只需要执行一个下采样块以获得。至于第二个特征 ​,由于可以简单地从缓存中检索,因此不需要额外的计算成本。过程如图3.

在第t - 1步,通过重用第t步缓存的特征,来生成xt - 1,并且为了更有效的推理,不执行D2, D3, U2, U3块。

扩展到1:N推理。缓存的特征计算一次,可以在后续的N−1步中重用,以取代原始的。对于所有去噪的T步,执行完全推理的时间步长序列为:

非均匀1:N推理。

基于1:N策略,在假定高级特征在连续N步中不变的前提下,成功地加速了扩散推理。然而,并非总是如此,特别是对于N,如图2(c)所示,特征的相似性并不是在所有步骤中都保持不变。对于像LDM这样的模型,特征的时间相似性会在去噪过程中显著降低40%左右。

因此,对于非均匀的1:N推理,我们倾向于对那些与相邻步骤相似度相对较小的步骤进行更多采样。在这里,执行完整推理的时间步长序列变为:

使用

import torch
from diffusers import StableDiffusionPipeline
from DeepCache import DeepCacheSDHelper# 加载 Stable Diffusion 模型
pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda:0")# 创建 DeepCacheSDHelper 对象
helper = DeepCacheSDHelper(pipe=pipe)# 设置缓存参数
helper.set_params(cache_interval=3,cache_branch_id=0,
)# 启用缓存机制
helper.enable()# 定义输入提示词
prompt = "a beautiful landscape with mountains and rivers"# 生成图像
deepcache_image = pipe(prompt,output_type='pt'
).images[0]# 禁用缓存机制
helper.disable()

 库:

diffusers==0.24.0
transformer

仅需要用DeepCache提供的Pipeline替换Diffusers库的Pipeline,即可实现扩散模型加速。目前支持 StableDiffusionPipeline 可以加载的模型。可以通过参数指定模型名称。

尝试1:将 DeepCacheSDHelper 应用于整个 pipeline,并确保缓存机制只启用一次

 pipe = Pose2VideoPipeline(vae=vae,image_encoder=image_enc,reference_unet=reference_unet,denoising_unet=denoising_unet,pose_guider=pose_guider,scheduler=scheduler,)pipe = pipe.to("cuda", dtype=weight_dtype)# 初始化 DeepCacheSDHelperhelper = DeepCacheSDHelper(pipe=pipe)# 设置缓存参数helper.set_params(cache_interval=3,cache_branch_id=0,)# 启用缓存机制helper.enable()

报错:

AttributeError: 'Pose2VideoPipeline' object has no attribute 'unet'

报错信息显示 Pose2VideoPipeline 对象没有 unet 属性,这说明 DeepCacheSDHelper 无法找到所需的 UNet 模型。要解决这个问题,必须确保传递给 DeepCacheSDHelper 的 pipeline 具有 unet 属性,并且该属性指向实际的 UNet 模型。

而 Pose2VideoPipeline 包含多个 UNet 模型( reference_unetdenoising_unet),需要对 DeepCacheSDHelper 进行修改,使其能够处理这种情况。一种解决方法是扩展 DeepCacheSDHelper 以接受多个 UNet 模型。解决方案:修改DeepCacheSDHelper类,pipe 和包含所有 UNet 模型的列表传递给 DeepCacheSDHelper:

    pipe = Pose2VideoPipeline(vae=vae,image_encoder=image_enc,reference_unet=reference_unet,denoising_unet=denoising_unet,pose_guider=pose_guider,scheduler=scheduler,)pipe = pipe.to("cuda", dtype=weight_dtype)# Initialize DeepCacheSDHelper with both UNet modelshelper = DeepCacheSDHelper(pipe=pipe, unets=[reference_unet, denoising_unet])helper.set_params(cache_interval=3,cache_branch_id=0,)helper.enable()

尝试2:分别对 reference_unetdenoising_unet 初始化并启用 DeepCacheSDHelper

 reference_unet = UNet2DConditionModel.from_pretrained(config.pretrained_base_model_path,subfolder="unet",).to(dtype=weight_dtype, device="cuda")# Import the DeepCacheSDHelperhelper = DeepCacheSDHelper(reference_unet=reference_unet)helper.set_params(cache_interval=3,cache_branch_id=0,)helper.enable()inference_config_path = config.inference_configinfer_config = OmegaConf.load(inference_config_path)denoising_unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_base_model_path,config.motion_module_path,subfolder="unet",unet_additional_kwargs=infer_config.unet_additional_kwargs,).to(dtype=weight_dtype, device="cuda")helper = DeepCacheSDHelper(denoising_unet=denoising_unet)helper.set_params(cache_interval=3,cache_branch_id=0,   # 指定缓存的分支 ID,上下两个unet是否需要不同分支?)helper.enable()

TypeError: DeepCacheSDHelper.__init__() got an unexpected keyword argument 'reference_unet', DeepCacheSDHelper 需要对 pipeline 中所有相关的 UNet 模型进行统一处理,而不是分别处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/41528.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【因果推断】优惠券政策对不同店铺的影响

这次依然是用之前rossmann店铺竞赛的数据集。 之前的数据集探索处理在这里已经做过了,此处就不再赘述了CSDN链接 数据集地址:竞赛链接 这里探讨数据集中Promo2对于每家店铺销售额的影响。其中,Promo2是一个基于优惠券的邮寄活动,发…

SQL Server 2022 中的 Tempdb 性能改进非常显著

无论是在我的会话中还是在我写的博客中,Tempdb 始终是我的话题。然而,当谈到 SQL Server 2022 中引入的重大性能变化时,我从未如此兴奋过。他们解决了我们最大的性能瓶颈之一,即系统页面闩锁并发。 在 SQL Server 2019 中&#x…

Go语言如何入门,有哪些书推荐?

Go 语言之所以如此受欢迎,其编译器功不可没。Go 语言的发展也得益于其编译速度够快。 对开发者来说,更快的编译速度意味着更短的反馈周期。大型的 Go 应用程序总是能在几秒钟之 内完成编译。而当使用 go run编译和执行小型的 Go 应用程序时,其…

如何利用Github Action实现自动Merge PR

我是蚂蚁背大象(Apache EventMesh PMC&Committer),文章对你有帮助给项目rocketmq-rust star,关注我GitHub:mxsm,文章有不正确的地方请您斧正,创建ISSUE提交PR~谢谢! Emal:mxsmapache.com 1. 引言 GitHub Actions 是 GitHub 提供的一种强大而灵活的自…

SSM中小学生信息管理系统 -计算机毕业设计源码02677

摘要 随着社会的发展和教育的进步,中小学生信息管理系统成为学校管理的重要工具。本论文旨在基于SSM框架,采用Java编程语言和MySQL数据库,设计和开发一套高效、可靠的中小学生信息管理系统。中小学生信息管理系统以学生为中心,通过…

赤壁之战的烽火台 - 观察者模式

“当烽火连三月,家书抵万金;设计模式得其法,千军如一心。” 在波澜壮阔的三国历史长河中,赤壁之战无疑是一场改变乾坤的重要战役。而在这场战役中,一个看似简单却至关重要的系统发挥了巨大作用——烽火台。这个古老的…

OpenAI的崛起:从梦想到现实

OpenAI的崛起不仅是人工智能领域的重大事件,也是科技史上一个引人注目的篇章。本文将深入探讨OpenAI从创立到如今的演变过程,分析其成功的关键因素,以及未来的发展方向。 一、OpenAI的初创期:理想主义与混乱并存 OpenAI成立于20…

插入排序——C语言

假设我们现在有一个数组,对它进行排序,插入排序的算法如同它的名字一样,就是将元素一个一个插入到合适的位置,那么,该如何做呢? 如果我们要从小到大进行排序的话,步骤如下: 1.对于…

区间最值问题-RQM(ST表,线段树)

1.ST表求解 ST表的实质其实是动态规划&#xff0c;下面是区间最小的递归公式&#xff0c;最大只需将min改成max即可 f[i][j] min(f[i][j - 1], f[i (1 << j - 1)][j - 1]); 二维数组的f[i][j]表示从i开始连续2*j个数的最小/大值。 例如&#xff1a;我们给出一个数组…

uniapp启动安卓模拟器mumu

mumu模拟器下载 ADB&#xff1a; android debug bridge &#xff0c; 安卓调试桥&#xff0c;是一个多功能的命令行工具&#xff0c;他使你能够与连接的安卓设备进行交互 # adb连接安卓模拟器 adb connect 127.0.0.1:port # 查看adb设备 adb deviceshubuilderx 有内置的adb&a…

MSPM0G3507——滴答定时器和普通定时

滴答定时器定时&#xff1a;&#xff08;放在主函数即可&#xff09; volatile unsigned int delay_times 0;//搭配滴答定时器实现的精确ms延时 void delay_ms(unsigned int ms) {delay_times ms;while( delay_times ! 0 ); } //滴答定时器中断 void SysTick_Handler(…

Python28-7.4 独立成分分析ICA分离混合音频

独立成分分析&#xff08;Independent Component Analysis&#xff0c;ICA&#xff09;是一种统计与计算技术&#xff0c;主要用于信号分离&#xff0c;即从多种混合信号中提取出独立的信号源。ICA在处理盲源分离&#xff08;Blind Source Separation&#xff0c;BSS&#xff0…

【机器学习】(基础篇一) —— 什么是机器学习

什么是机器学习 本系列博客为你从机器学习的介绍开始&#xff0c;使用大量的代码实战和验证&#xff0c;最终帮助你完全掌握什么是机器学习 人工智能、机器学习和深度学习的关系 人工智能&#xff08;Artificial Intelligence&#xff0c;AI&#xff09;&#xff1a;是一门研…

Java多线程不会?一文解决——

方法一 新建类如MyThread继承Thread类重写run()方法再通过new MyThread类来新建线程通过start方法启动新线程 案例&#xff1a; class MyThread extends Thread {public MyThread(String name) {super(name);}Overridepublic void run() {for(int i0;i<10;i){System.out.…

react dangerouslySetInnerHTML将html字符串以变量方式插入页面,点击后出现编辑状态

1.插入变量 出现以下编辑状态 2.解决 给展示富文本的标签添加css样式 pointerEvents: none

那些年背过的面试题——MySQL篇

本文是技术人面试系列 MySQL 篇&#xff0c;面试中关于 MySQL 都需要了解哪些基础&#xff1f;一文带你详细了解&#xff0c;欢迎收藏&#xff01; WhyMysql&#xff1f; NoSQL 数据库四大家族 列存储 Hbase K-V 存储 Redis 图像存储 Neo4j 文档存储 MongoDB 云存储 OSS …

AI大模型的智能心脏:向量数据库的崛起

在人工智能的飞速发展中,一个关键技术正悄然成为AI大模型的智能心脏——向量数据库。它不仅是数据存储和管理的革命性工具,更是AI技术突破的核心。随着AI大模型在各个领域的广泛应用,向量数据库的重要性日益凸显。 01 技术突破:向量数据库的内在力量 向量数据库以其快速检索…

RNN、LSTM与GRU循环神经网络的深度探索与实战

循环神经网络RNN、LSTM、GRU 一、引言1.1 序列数据的迷宫探索者&#xff1a;循环神经网络&#xff08;RNN&#xff09;概览1.2 深度探索的阶梯&#xff1a;LSTM与GRU的崛起1.3 撰写本博客的目的与意义 二、循环神经网络&#xff08;RNN&#xff09;基础2.1 定义与原理2.1.1 RNN…

【Python】组合数据类型:序列,列表,元组,字典,集合

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️Python】 文章目录 前言组合数据类型序列类型序列常见的操作符列表列表操作len()append()insert()remove()index()sort()reverse()count() 元组三种序列类型的区别 集合类型四种操作符集合setfrozens…