Diffusion Model: DDPM

本文相关内容只记录看论文过程中一些难点问题,内容间逻辑性不强,甚至有点混乱,因此只作为本人“备忘”,不建议其他人阅读。

Denoising Diffusion Probabilistic Models: https://arxiv.org/abs/2006.11239

DDPM

一、基于 x_0 已知的情况下,x_t 分布的推导过程:推导过程中,直接递归迭代即可。同时,过程中使用了 —— 两个高斯分布的和也满足高斯分布,其中均值为两个高斯分布均值的和,方差为两个高斯分布方差的和。

二、逆向过程中,q(x_{t-1}|x_t, x_0) 分布求解

进一步根据 1 中的结果可得:

公式 9 中的 z_{\theta}(x_t,t) 就是 diffusion model 需要估计的噪声均值,而噪声的方式是由 \alpha_t 或者 \beta_t 直接得到的。

三、具体训练过程:训练过程比较直接,利用 一 中的公式即可。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L274def q_sample(self, x_start, t, noise=None):noise = default(noise, lambda: torch.randn_like(x_start))return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)def get_loss(self, pred, target, mean=True):if self.loss_type == 'l1':loss = (target - pred).abs()if mean:loss = loss.mean()elif self.loss_type == 'l2':if mean:loss = torch.nn.functional.mse_loss(target, pred)else:loss = torch.nn.functional.mse_loss(target, pred, reduction='none')else:raise NotImplementedError("unknown loss type '{loss_type}'")return loss# 输入参数说明:
# x_start:原始图像 x0
# t:当前扩散步数
# noise:噪声,需要注意这里的 noise 与 x_start 维度相同;具体含义是每个位置上元素都服从 0-1 高斯分布
def p_losses(self, x_start, t, noise=None):# 生成第 t 步的高斯噪声noise = default(noise, lambda: torch.randn_like(x_start))# 根据本文 一 中推导的公式得到第 t 步加噪后的图像x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# 模型预测结果,根据具体的设置,好像可以回归加的噪声,也可以直接回归原始图像model_out = self.model(x_noisy, t)loss_dict = {}if self.parameterization == "eps":# 模型估计噪声target = noiseelif self.parameterization == "x0":# 模型直接估计原始图像target = x_startelse:raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")# 使用 L1 或者 L2 Loss 计算误差loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])log_prefix = 'train' if self.training else 'val'loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})loss_simple = loss.mean() * self.l_simple_weightloss_vlb = (self.lvlb_weights[t] * loss).mean()loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})loss = loss_simple + self.original_elbo_weight * loss_vlbloss_dict.update({f'{log_prefix}/loss': loss})return loss, loss_dict

四、具体生成(采样)过程:根据 二 中推导的公式,依次计算前一步图像的分布。

需要注意:

  1. 具体回归的均值的维度与图像维度完全相同,即图像每个位置(包括不同通道)都建模为高斯分布,均值就是无随机时图像应该有的“样子”。PS:具体是对 8 倍下采样的特征图进行采样;因此在最后需要接一个 decoder 将采样出的特征值上采样得到最终的图像。

  2. 因此,在 T=0 步得到的均值就是最终生成的图像;不过在 T> 0 步依据均值和方差进行采样,可能的原因是增加生成图像的多样性。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L222# 根据本文 二 中的公式计算 x_t-1 的均值和方差
def q_posterior(self, x_start, x_t, t):posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)return posterior_mean, posterior_variance, posterior_log_variance_clippeddef p_mean_variance(self, x, t, clip_denoised: bool):model_out = self.model(x, t)if self.parameterization == "eps":x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)elif self.parameterization == "x0":x_recon = model_outif clip_denoised:x_recon.clamp_(-1., 1.)model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)return model_mean, posterior_variance, posterior_log_variance# 基于估计的图像每个位置的均值 model_mean 和方差 model_log_variance 生成对应随机图像
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):b, *_, device = *x.shape, x.devicemodel_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)noise = noise_like(x.shape, device, repeat_noise)# no noise when t == 0nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise# 从 T 步 ——> T-1 步 ——> ... ——> 0 步,依次进行反向估计
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):device = self.betas.deviceb = shape[0]img = torch.randn(shape, device=device)intermediates = [img]for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),clip_denoised=self.clip_denoised)if i % self.log_every_t == 0 or i == self.num_timesteps - 1:intermediates.append(img)if return_intermediates:return img, intermediatesreturn img# 采样入口函数,batch_size 一次生成的图像数量
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):image_size = self.image_sizechannels = self.channelsreturn self.p_sample_loop((batch_size, channels, image_size, image_size),return_intermediates=return_intermediates)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/168685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

可观测性建设实践之 - 日志分析的权衡取舍

指标、日志、链路是服务可观测性的三大支柱,在服务稳定性保障中,通常指标侧重于发现故障和问题,日志和链路分析侧重于定位和分析问题,其中日志实际上是串联这三大维度的一个良好桥梁。 但日志分析往往面临成本和效果之间的权衡问…

NET 8.0 中新的变化

1性能提升 .NET 8在整个堆栈中带来了数千项性能改进 。默认情况下会启用一种名为动态配置文件引导优化 (PGO) 的新代码生成器,它可以根据实际使用情况优化代码,并且可以将应用程序的性能提高高达 20%。现在支持的 AVX-512 指令集能够对 512 位数据向量执…

您的计算机已被.locked1勒索病毒感染?恢复您的数据的方法在这里!

尊敬的读者: 勒索病毒如.locked1已经成为网络安全的一大威胁。这类病毒通过加密用户文件,并勒索赎金以解密这些文件,给用户和组织带来了巨大的损害。本文将深入介绍.locked1勒索病毒的特点、恶意目的,以及如何恢复被其加密的数据…

PyQt6运行QTDesigner生成的ui文件程序

2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计18条视频,包括:2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~、第2讲 PyQt6库和工具库Q…

Selenium实战指南:安装、使用技巧和JavaScript注入案例解析

背景 ​ 最近一段时间我会重新开一个关于selenium的专题,由浅入深的给大家讲一下selenium,同时回顾一下之前学的内容,selenium可以实现模拟登录,动态数据获取,获取动态cookie等等,还有可以写一些抢p的脚本…

matlab使用plot画图坐标轴上的导数速度一点和加速度两点如何显示

一、背景 在使用matlab中的plot函数画图时,有时需要在坐标轴上显示一个点的导数项,如横坐标是时间,纵坐标是速度,也就是位置的导数 y ˙ \dot y y˙​,如下图所示,这在matlab如何操作呢? 二…

【计网 可靠数据传输RDT】 中科大笔记 (十 一)

目录 0 引言1 RDT的原理RDT的原理: 2 RDT的机制与作用2.1 重要协议停等协议(Stop-and-Wait):连续ARQ协议: 2.2 机制与作用实现机制:RDT的作用: 🙋‍♂️ 作者:海码007📜 专栏&#x…

Linux中tar命令的几个高级用法

在Linux世界中,Tar命令是一把解密归档世界的魔法工具。无论是打包、压缩还是解压,Tar命令都能胜任。本文将生动地介绍Tar命令的基本用法,并深入探讨五个常用选项,帮助读者在Linux系统中灵活运用这个强大的工具。 一、命令概述 Ta…

日本服务器访问速度和带宽有没有直接关系?

​  对于许多网站和应用程序来说,服务器的访问速度是至关重要的。用户希望能够快速加载页面、上传和下载文件,而这些都与服务器的带宽有关。那么,日本服务器的访问速度和带宽之间是否存在直接关系呢? 我们需要了解什么是带宽。带宽是指网络…

MySQL的体系结构与SQL的执行流程

文章目录 前言体系结构SQL语句的执行流程1、连接MySQL2、查询缓存3、解析SQL语句4、优化SQL语句5、执行SQL语句 总结 前言 如果你在使用MySQL时只会写sql语句的,那么你应该看一下《MySQL优化的底层逻辑》。如果你只了解到sql是如何优化的,那么你应该通过…

【数据结构】什么是栈?

🦄个人主页:修修修也 🎏所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 目录 📌栈的定义 📌元素进栈出栈的顺序 📌栈的抽象数据类型 📌栈的顺序存储结构 📌栈的链式存储结构 链栈的进…

java集合,ArrayList、LinkedList和Vector,多线程场景下如何使用 ArrayList

文章目录 Java集合1.2 流程图关系1.3 底层实现1.4 集合与数组的区别1.4.1 元素类型1.4.2 元素个数 1.5 集合的好处1.6 List集合我们以ArrayList集合为例1.7 迭代器的常用方法1.8 ArrayList、LinkedList和Vector的区别1.8.1 说出ArrayList,Vector, LinkedList的存储性能和特性1.…

汽车电子 -- 根据DBC解析CAN报文

采集的CAN报文,怎么通过DBC解析呢?有一下几种方法。 首先需要确认是CAN2.0 还是CAN FD报文。 还有是 实时解析 和 采集数据 进行解析。 一、CAN2.0报文实时解析: 1、CANTest工具 使用CAN分析仪 CANalyst-II,采集CAN报文。 使用…

JSP EL 通过 三元运算符 控制界面 标签 标签属性内容

然后 我们来说说 EL配合三元运算符的妙用 我们先这样写 <% page contentType"text/html; charsetUTF-8" pageEncoding"UTF-8" %> <%request.setCharacterEncoding("UTF-8");%> <!DOCTYPE html> <html> <head>&l…

智慧城市运营管理平台解决方案:PPT全文61页,附下载

关键词&#xff1a;智慧城市建设方案&#xff0c;智慧城市解决方案&#xff0c;智慧城市的发展前景和趋势&#xff0c;智慧城市建设内容&#xff0c;智慧城市运营管理平台 一、智慧城市运营平台建设背景 随着城市化进程的加速&#xff0c;城市面临着诸多挑战&#xff0c;如环…

概率论与数理统计中常见的随机变量分布律、数学期望、方差及其介绍

1 离散型随机变量 1.1 0-1分布 设随机变量X的所有可能取值为0与1两个值&#xff0c;其分布律为 若分布律如上所示&#xff0c;则称X服从以P为参数的(0-1)分布或两点分布。记作X~ B(1&#xff0c;p) 0-1分布的分布律利用表格法表示为: X01P1-PP 0-1分布的数学期望E(X) 0 *…

人工智能基础部分22-几种卷积神经网络结构的介绍,并用pytorch框架搭建模型

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下人工智能基础部分22-几种卷积神经网络结构的介绍&#xff0c;本篇文章我将给大家详细介绍VGG16、VGG19、ResNet、SENet、MobileNet这几个卷积神经网络结构&#xff0c;以及pytorch搭建代码&#xff0c;利用通用数据…

网站监控有什么用,什么是网站监控?

网站内容监控是指采用数据采集、人工智能、云计算、机器学习、语义分析等技术&#xff0c;结合网站内容监管指标&#xff0c;针对网站内容安全、信息发布、办事服务、互动交流、功能设计、创新发展等指标进行实时监测&#xff0c;以防止网站页面内容被篡改&#xff0c;出现黄、…

5G智慧工地整体解决方案:文件全文115页,附下载

关键词&#xff1a;5G智慧工地&#xff0c;智慧工地建设方案&#xff0c;智慧工地管理平台系统&#xff0c;智慧工地建设调研报告&#xff0c;智慧工地云平台建设 一、5G智慧工地建设背景 5G智慧工地是利用5G技术、物联网、大数据、云计算、AI等信息技术&#xff0c;围绕“人…

使用git下载远程所有分支到本地

使用git下载远程所有分支到本地&#xff1a; 打开gitbash 输入以下命令即可&#xff1a; git clone git地址 cd git文件夹 git branch -r | grep -v \-> | while read remote; do git branch --track "${remote#origin/}" "$remote"; done git fetch -…