stable diffusion代码学习笔记

前言:本文没有太多公式推理,只有一些简单的公式,以及公式和代码的对应关系。本文仅做个人学习笔记,如有理解错误的地方,请指出。

本文包含stable diffusion入门文献和不同版本的代码。

文献资源

  1. 本文学习的代码;
  2. 相关文献:
  • Denoising Diffusion Probabilistic Models : DDPM,这个是必看的,推推公式
  • Denoising Diffusion Implicit Models :DDIM,对 DDPM 的改进
  • Pseudo Numerical Methods for Diffusion Models on Manifolds :PNMD/PLMS,对 DDPM 的改进
  • High-Resolution Image Synthesis with Latent Diffusion Models :Latent-Diffusion,必看
  • Neural Discrete Representation Learning : VQVAE,简单翻了翻,示意图非常形象,很容易了解其做法

代码资源

  1. stable diffusion v1.1-v1.4, https://github.com/CompVis/stable-diffusion
  2. stable diffusion v1.5,https://github.com/runwayml/stable-diffusion
  3. stable diffusion v2,https://github.com/Stability-AI/stablediffusion
  4. stable diffusion XL,https://github.com/Stability-AI/generative-models

前向过程(训练)

  • 输入一张图片+随机噪声,训练unet,网络预测图片加上的噪声

反向过程(推理)

  • 给个随机噪声,不断迭代去噪,输出一张图片

总体流程

  • 输入的prompt经过clip encoder编码成(3+3,77,768)特征,正负prompt各3个,默认negative prompt为空‘’,解码时正的和负的latent图片用公式计算一下才是最终结果;time step通过linear层得到(3+3,1280)特征;把prompt和time ebedding和随机生成的图片放入unet,得到的就是我们要的图片。

采样流程 text2img

  • 该函数在PLMSSampler中,输入x(噪声,(3,4,64,64))-----c(输入的prompt,(3,77,768)----t (输入的time step,第几次去噪(3,)。把这三个东西输入unet,得到预测的噪声e_t。
 def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):b, *_, device = *x.shape, x.devicedef get_model_output(x, t):if unconditional_conditioning is None or unconditional_guidance_scale == 1.:e_t = self.model.apply_model(x, t, c)else:x_in = torch.cat([x] * 2)t_in = torch.cat([t] * 2)c_in = torch.cat([unconditional_conditioning, c]) # 积极消极的prompt,解码时按照公式减去消极prompt的图像e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)if score_corrector is not None:assert self.model.parameterization == "eps"e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)return e_talphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphasalphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prevsqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphassigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmasdef get_x_prev_and_pred_x0(e_t, index):# select parameters corresponding to the currently considered timestepa_t = torch.full((b, 1, 1, 1), alphas[index], device=device)a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)# current prediction for x_0pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()if quantize_denoised:pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)# direction pointing to x_tdir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_tnoise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperatureif noise_dropout > 0.:noise = torch.nn.functional.dropout(noise, p=noise_dropout)x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noisereturn x_prev, pred_x0e_t = get_model_output(x, t) # 模型预测的噪声if len(old_eps) == 0:# Pseudo Improved Euler (2nd order)x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) # 输入噪声减去预测噪声得到新的噪声,当前预测的latent图片e_t_next = get_model_output(x_prev, t_next)e_t_prime = (e_t + e_t_next) / 2 # 两次噪声的均值?elif len(old_eps) == 1:# 2nd order Pseudo Linear Multistep (Adams-Bashforth)e_t_prime = (3 * e_t - old_eps[-1]) / 2elif len(old_eps) == 2:# 3nd order Pseudo Linear Multistep (Adams-Bashforth)e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12elif len(old_eps) >= 3:# 4nd order Pseudo Linear Multistep (Adams-Bashforth)e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)return x_prev, pred_x0, e_t
  • 接下来看公式:
    在这里插入图片描述
  • 网络得到e_t后,进入到get_x_prev_and_pred_x0函数,可以看到pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()就是上述公式,也就是说网络的预测结果通过公式计算,我们可以得到预测的pred_x0原始图片和前一刻的噪声图像x_prev
        def get_x_prev_and_pred_x0(e_t, index):# select parameters corresponding to the currently considered timestepa_t = torch.full((b, 1, 1, 1), alphas[index], device=device)a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)# current prediction for x_0pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()if quantize_denoised:pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)# direction pointing to x_tdir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_tnoise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperatureif noise_dropout > 0.:noise = torch.nn.functional.dropout(noise, p=noise_dropout)x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noisereturn x_prev, pred_x0
  • 前一刻的噪声图像的推理公式如图:
    在这里插入图片描述
    在这里插入图片描述

  • 得到了上一刻的噪声图片x_prev后(也就是函数返回的img),继续迭代,最终生成需要的图片。
    在这里插入图片描述

额外说明

这部分代码应该就是PLMS加速采样用的,论文中有公式推理
在这里插入图片描述
另外,还有一些参数是训练时候保存的,betas逐渐增大,用来控制噪声的强度。变量名解析 log_one_minus_alphas_cumprod其实就是log(1-alpha(右下角t)(头上直线)),没有带prev的都是当前时刻t,带prev的是前一时刻t-1。

在这里插入图片描述

参考文献:

https://blog.csdn.net/Eric_1993/article/details/129600524?spm=1001.2014.3001.5502
https://zhuanlan.zhihu.com/p/630354327

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/617133.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android基于Matrix绘制PaintDrawable设置BitmapShader,以手指触点为中心显示原图像圆图,Kotlin(2)

Android基于Matrix绘制PaintDrawable设置BitmapShader,以手指触点为中心显示原图像圆图,Kotlin(2) 在 https://zhangphil.blog.csdn.net/article/details/135374279 基础上,增加一个功能,当手指在上面的图片…

【DevOps-08-3】Jenkins容器内部使用Docker

一、简要描述 构建镜像和发布镜像到harbor都需要使用到docker命令。而在Jenkins容器内部安装Docker官方推荐直接采用宿主机带的Docker即可。 设置Jenkins容器使用宿主机Docker。 二、配置和操作步骤 1、修改宿主机docker.sock权限 # 修改docker.sock 用户和用户组都为root $ …

并发,并行,线程与UI操作

并行和并发是计算机领域中两个相关但不同的概念。 并行(Parallel)指的是同时执行多个任务或操作,它依赖于具有多个处理单元的系统。在并行计算中,任务被分成多个子任务,并且这些子任务可以同时在不同的处理单元上执行…

工智能基础知识总结--聚类算法

什么是聚类算法 聚类是一种机器学习技术,它涉及到数据点的分组。给定一组数据点,我们可以使用聚类算法将每个数据点划分为一个特定的组。理论上,同一组中的数据点应该具有相似的属性和/或特征,而不同组中的数据点应该具有高度不同的属性和/或特征。聚类是一种无监督学习的方…

DEJA_VU3D - Cesium功能集 之 112-获取圆节点(1)

前言 编写这个专栏主要目的是对工作之中基于Cesium实现过的功能进行整合,有自己琢磨实现的,也有参考其他大神后整理实现的,初步算了算现在有差不多实现小140个左右的功能,后续也会不断的追加,所以暂时打算一周2-3更的样子来更新本专栏(每篇博文都会奉上完整demo的源代码…

2024年甘肃省职业院校技能大赛信息安全管理与评估 样题二 模块二

竞赛需要完成三个阶段的任务,分别完成三个模块,总分共计 1000分。三个模块内容和分值分别是: 1.第一阶段:模块一 网络平台搭建与设备安全防护(180 分钟,300 分)。 2.第二阶段:模块二…

探索“城堡世界”APP:你的城堡,你的故事

在快节奏的现代生活中,我们常常渴望有一个属于自己的世界,可以随心所欲地创造和讲述故事。今天,我们要为大家介绍的是一款名为“城堡世界”的APP,它将带给你实现这个梦想的机会。 “城堡世界”是一款独特的APP,它允许用…

vue-virtual-scroll-list(可单选、多选、搜索查询、创建条目)

element-ui-解决下拉框数据量过多问题(vue-virtual-scroll-list)_element-ui下拉框数据太多如何优化-CSDN博客 的升级版 参考链接:封装el-select,实现虚拟滚动,可单选、多选、搜索查询、创建条目-CSDN博客 1.封装组件 select.v…

redis获取过期时间

03,redisTemplate_redistemplate 获取剩余时间-CSDN博客 11.返回当前key所对应的剩余过期时间 redisTemplate.getExpire(key);1 12.返回剩余过期时间并且指定时间单位 redisTemplate.getExpire(key, unit);

深入理解区间合并:让数字之间的故事更加有序

嗨,亲爱的读者朋友们!在今天的这篇博客中,我们将深入探讨一个在编程和算法中常见但又很有趣的话题——区间合并。这个话题可能让一些初学者感到头疼,但我会尽力通过生动的例子和简单的解释来让你对它有一个清晰的认识。 引子&…

HTTP 常见协议:选择正确的协议,提升用户体验(下)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

Vulnhub-HACKSUDO: PROXIMACENTAURI渗透

文章目录 一、前言1、靶机ip配置2、渗透目标3、渗透概括 开始实战一、信息获取二、端口敲门三、web密码爆破四、getShell五、获取新用户六、提权 一、前言 由于在做靶机的时候,涉及到的渗透思路是非常的广泛,所以在写文章的时候都是挑重点来写&#xff0…

也谈人工智能——AI科普入门

文章目录 1. 科普入门人工智能的定义人工智能的类型 - 弱 AI 与强 AI人工智能、深度学习与机器学习人工智能的应用和使用场景语音识别计算机视觉客户服务建议引擎数据分析网络安全 行业应用人工智能发展史![img](https://img-blog.csdnimg.cn/img_convert/66aeaaeac6870f432fc4…

error: undefined reference to ‘cv::imread(std::__ndk1::basic_string<char

使用android studio编译项目时,由于用到了 cv::imread()函数,编译时却报错找不到该函数的定义。 cv::imread一般是在highgui.hpp中定义,因此我加上了该头文件: #include “opencv2/highgui/highgui.hpp” 但…

webtim开源即时通讯平台第三版发布

webtim是Web开源通讯平台。服务器是 Tim 。前端使用tim的js客户端 timjs 调用tim服务器接口渲染页面。 webtim开发目的是通过界面来显式表达tim接口功能。tim是去中心化的分布式IM引擎。支持多种基础通讯模式,对端到端的数据流传输支持非常全面,几乎涵…

【信息安全】hydra爆破工具的使用方法

hydra简介 hydra又名九头蛇,与burp常规的爆破模块不同,hydra爆破的范围更加广泛,可以爆破远程桌面连接,数据库这类的密码。他在kali系统中自带。 参数说明 -l 指定用户名 -L 指定用户名字典文件 -p 指定密码 -P 指…

Java十大经典算法—KMP

字符串匹配问题: 1.暴力匹配 public class ViolenceMatch {public static void main(String[] args) {String str1 "硅硅谷 尚硅谷你尚硅 尚硅谷你尚硅谷你尚硅你好";String str2 "尚硅谷你尚硅你好";int index violenceMatch(str1, str2);S…

力扣_数组29—根据前序与中序遍历序列构建二叉树、根据中序与后序遍历序列构建二叉树

题目 给定两个整数数组 p r e o r d e r preorder preorder 和 i n o r d e r inorder inorder ,其中 p r e o r d e r preorder preorder 是二叉树的先序遍历, i n o r d e r inorder inorder 是同一棵树的中序遍历,请构造二叉树并返回…

数模学习day11-系统聚类法

本文参考辽宁石油化工大学于晶贤教授的演示文档聚类分析之系统聚类法及其SPSS实现。 目录 1.样品与样品间的距离 2.指标和指标间的“距离” 相关系数 夹角余弦 3.类与类间的距离 (1)类间距离 (2)类间距离定义方式 1.最短…

数据科学竞赛平台推荐

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心&…