Diffusion Model: DDPM

本文相关内容只记录看论文过程中一些难点问题,内容间逻辑性不强,甚至有点混乱,因此只作为本人“备忘”,不建议其他人阅读。

Denoising Diffusion Probabilistic Models: https://arxiv.org/abs/2006.11239

DDPM

一、基于 x_0 已知的情况下,x_t 分布的推导过程:推导过程中,直接递归迭代即可。同时,过程中使用了 —— 两个高斯分布的和也满足高斯分布,其中均值为两个高斯分布均值的和,方差为两个高斯分布方差的和。

二、逆向过程中,q(x_{t-1}|x_t, x_0) 分布求解

进一步根据 1 中的结果可得:

公式 9 中的 z_{\theta}(x_t,t) 就是 diffusion model 需要估计的噪声均值,而噪声的方式是由 \alpha_t 或者 \beta_t 直接得到的。

三、具体训练过程:训练过程比较直接,利用 一 中的公式即可。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L274def q_sample(self, x_start, t, noise=None):noise = default(noise, lambda: torch.randn_like(x_start))return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)def get_loss(self, pred, target, mean=True):if self.loss_type == 'l1':loss = (target - pred).abs()if mean:loss = loss.mean()elif self.loss_type == 'l2':if mean:loss = torch.nn.functional.mse_loss(target, pred)else:loss = torch.nn.functional.mse_loss(target, pred, reduction='none')else:raise NotImplementedError("unknown loss type '{loss_type}'")return loss# 输入参数说明:
# x_start:原始图像 x0
# t:当前扩散步数
# noise:噪声,需要注意这里的 noise 与 x_start 维度相同;具体含义是每个位置上元素都服从 0-1 高斯分布
def p_losses(self, x_start, t, noise=None):# 生成第 t 步的高斯噪声noise = default(noise, lambda: torch.randn_like(x_start))# 根据本文 一 中推导的公式得到第 t 步加噪后的图像x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# 模型预测结果,根据具体的设置,好像可以回归加的噪声,也可以直接回归原始图像model_out = self.model(x_noisy, t)loss_dict = {}if self.parameterization == "eps":# 模型估计噪声target = noiseelif self.parameterization == "x0":# 模型直接估计原始图像target = x_startelse:raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")# 使用 L1 或者 L2 Loss 计算误差loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])log_prefix = 'train' if self.training else 'val'loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})loss_simple = loss.mean() * self.l_simple_weightloss_vlb = (self.lvlb_weights[t] * loss).mean()loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})loss = loss_simple + self.original_elbo_weight * loss_vlbloss_dict.update({f'{log_prefix}/loss': loss})return loss, loss_dict

四、具体生成(采样)过程:根据 二 中推导的公式,依次计算前一步图像的分布。

需要注意:

  1. 具体回归的均值的维度与图像维度完全相同,即图像每个位置(包括不同通道)都建模为高斯分布,均值就是无随机时图像应该有的“样子”。PS:具体是对 8 倍下采样的特征图进行采样;因此在最后需要接一个 decoder 将采样出的特征值上采样得到最终的图像。

  2. 因此,在 T=0 步得到的均值就是最终生成的图像;不过在 T> 0 步依据均值和方差进行采样,可能的原因是增加生成图像的多样性。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L222# 根据本文 二 中的公式计算 x_t-1 的均值和方差
def q_posterior(self, x_start, x_t, t):posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)return posterior_mean, posterior_variance, posterior_log_variance_clippeddef p_mean_variance(self, x, t, clip_denoised: bool):model_out = self.model(x, t)if self.parameterization == "eps":x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)elif self.parameterization == "x0":x_recon = model_outif clip_denoised:x_recon.clamp_(-1., 1.)model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)return model_mean, posterior_variance, posterior_log_variance# 基于估计的图像每个位置的均值 model_mean 和方差 model_log_variance 生成对应随机图像
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):b, *_, device = *x.shape, x.devicemodel_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)noise = noise_like(x.shape, device, repeat_noise)# no noise when t == 0nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise# 从 T 步 ——> T-1 步 ——> ... ——> 0 步,依次进行反向估计
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):device = self.betas.deviceb = shape[0]img = torch.randn(shape, device=device)intermediates = [img]for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),clip_denoised=self.clip_denoised)if i % self.log_every_t == 0 or i == self.num_timesteps - 1:intermediates.append(img)if return_intermediates:return img, intermediatesreturn img# 采样入口函数,batch_size 一次生成的图像数量
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):image_size = self.image_sizechannels = self.channelsreturn self.p_sample_loop((batch_size, channels, image_size, image_size),return_intermediates=return_intermediates)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/168685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QFont如何设置斜体|QlineEdit设置只能输入数字|QThread::finished信号发出后worker未调用析构函数

QFont如何设置斜体 要设置 QFont 的斜体,你可以使用 setItalic() 方法。以下是一个示例代码: #include <QApplication> #include <QLabel> #include <QFont> int main(int argc, char *argv

可观测性建设实践之 - 日志分析的权衡取舍

指标、日志、链路是服务可观测性的三大支柱&#xff0c;在服务稳定性保障中&#xff0c;通常指标侧重于发现故障和问题&#xff0c;日志和链路分析侧重于定位和分析问题&#xff0c;其中日志实际上是串联这三大维度的一个良好桥梁。 但日志分析往往面临成本和效果之间的权衡问…

NET 8.0 中新的变化

1性能提升 .NET 8在整个堆栈中带来了数千项性能改进 。默认情况下会启用一种名为动态配置文件引导优化 (PGO) 的新代码生成器&#xff0c;它可以根据实际使用情况优化代码&#xff0c;并且可以将应用程序的性能提高高达 20%。现在支持的 AVX-512 指令集能够对 512 位数据向量执…

mysql union 和 union all区别?

在MySQL中&#xff0c;UNION和UNION ALL都是用于合并两个或多个SELECT语句的结果集。它们之间的主要区别在于如何处理重复记录。 UNION:UNION在合并结果集时会删除重复的记录。这意味着如果两个SELECT语句的输出结果中有相同的记录&#xff0c;那么UNION只会保留其中一个。在执…

您的计算机已被.locked1勒索病毒感染?恢复您的数据的方法在这里!

尊敬的读者&#xff1a; 勒索病毒如.locked1已经成为网络安全的一大威胁。这类病毒通过加密用户文件&#xff0c;并勒索赎金以解密这些文件&#xff0c;给用户和组织带来了巨大的损害。本文将深入介绍.locked1勒索病毒的特点、恶意目的&#xff0c;以及如何恢复被其加密的数据…

Oracle数据库语句大全

一&#xff0e;入门部分 创建表空间 create tablespace schooltbs datafile ‘D:\oracle\datasource\schooltbs.dbf’ size 10M autoextend on;删除表空间 drop tablespace schooltbs[including contents and datafiles];查询表空间基本信息 select *||tablespace_name from D…

PyQt6运行QTDesigner生成的ui文件程序

2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计18条视频&#xff0c;包括&#xff1a;2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~、第2讲 PyQt6库和工具库Q…

linux网络编程之UDP编程

linux网络编程之UDP编程 UDP编程模型服务端客户端 tcp与udp的区别 UDP编程模型 服务端 1.创建socket 2.构建服务器协议地址簇 3.绑定 4. 通信 sendto&#xff08;多了两个参数&#xff09; send connect #include <stdio.h> #include <sys/types.h> /*…

Selenium实战指南:安装、使用技巧和JavaScript注入案例解析

背景 ​ 最近一段时间我会重新开一个关于selenium的专题&#xff0c;由浅入深的给大家讲一下selenium&#xff0c;同时回顾一下之前学的内容&#xff0c;selenium可以实现模拟登录&#xff0c;动态数据获取&#xff0c;获取动态cookie等等&#xff0c;还有可以写一些抢p的脚本…

matlab使用plot画图坐标轴上的导数速度一点和加速度两点如何显示

一、背景 在使用matlab中的plot函数画图时&#xff0c;有时需要在坐标轴上显示一个点的导数项&#xff0c;如横坐标是时间&#xff0c;纵坐标是速度&#xff0c;也就是位置的导数 y ˙ \dot y y˙​&#xff0c;如下图所示&#xff0c;这在matlab如何操作呢&#xff1f; 二…

护士排班问题:Nurse Rostering Problem(NRP)实战并可视化页面

文章目录 护士排班NRP问题问题示例模型求解排班表可视化护士排班NRP问题 基于计算机的自动化排班有助于提高排班的效率和质量,从而使得人力资源得到有效的利用。护士排班问题并不专指对于医院护士的排班,实际上泛指这种限制条件较多的排班问题。护士排班NRP问题是一个典型的…

工厂方法模式 (Factory Method Pattern)

定义&#xff1a; 工厂方法模式&#xff08;Factory Method Pattern&#xff09;是一种创建型设计模式&#xff0c;用于解决对象创建的问题。它定义了一个创建对象的接口&#xff0c;但让子类决定实例化哪一个类。工厂方法使一个类的实例化延迟到其子类。 工厂方法模式的关键…

为什么要升级水经微图到64位?

前段时间&#xff0c;水经微图升级到了64位版。 这里为大家说明一下我们为什么要升级水经微图到64位。 顺便再分享一下我们在PC端产品上的一些调整。 为什么要升级到64位&#xff1f; 水经微图一直以来有一个巨大的问题&#xff0c;那就是矢量加载与绘制功能相当弱。 但凡…

SecureCRT出现Key exchange failed.No compatible key exchange method. 错误解决方法

SecureCRT出现Key exchange failed.No compatible key exchange method. 如下 Key exchange failed. No compatible key exchange method. The server supports these methods: curve25519-sha256,curve25519-sha256libssh.org,diffie-hellman-group-exchange-sha256解决方法&…

【计网 可靠数据传输RDT】 中科大笔记 (十 一)

目录 0 引言1 RDT的原理RDT的原理&#xff1a; 2 RDT的机制与作用2.1 重要协议停等协议&#xff08;Stop-and-Wait&#xff09;:连续ARQ协议: 2.2 机制与作用实现机制&#xff1a;RDT的作用&#xff1a; &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏&#x…

Java8 对象List 排序

目录 1.stream流式排序 1.使用说明: 2.多字段排序 2.Collections.sort(......) 排序 1.stream流式排序 Java8提供了流式操作来简化我们的编程&#xff0c;比如排序、分组、过滤、Map操作等API&#xff0c;配合Lambda表达式给我们编程带来了很大的便利&#xff0c;这篇文章重…

react高阶成分(HOC)

使用React函数式组件写了一个身份验证的一个功能&#xff0c;示例通过高阶组件实现的一个效果展示&#xff1a; import React, { useState, useEffect } from react;// 定义一个高阶组件&#xff0c;它接受一个组件作为输入&#xff0c;并返回一个新的包装组件 const withAuth…

Qt QIODevice介绍

作者:令狐掌门 技术交流QQ群:675120140 csdn博客:https://mingshiqiang.blog.csdn.net/ 文章目录 主要功能用法示例读取数据写入数据使用数据流基于套接字的读写注意事项QIODevice 是 Qt 中所有输入/输出设备的抽象基类。它为派生类提供了一组标准的接口用于读写数据。这些派…

Linux中tar命令的几个高级用法

在Linux世界中&#xff0c;Tar命令是一把解密归档世界的魔法工具。无论是打包、压缩还是解压&#xff0c;Tar命令都能胜任。本文将生动地介绍Tar命令的基本用法&#xff0c;并深入探讨五个常用选项&#xff0c;帮助读者在Linux系统中灵活运用这个强大的工具。 一、命令概述 Ta…

网络安全面试经历

2023-11-22 X亭安全服务实习生面试 一面&#xff1a; 工作方向&#xff1a;偏蓝队 总结&#xff1a;实习蓝队面试没有什么难度&#xff0c;没有什么技术上的细节问题&#xff0c;之前准备的细节问题没有考 最后和面试官聊了聊对网安的认识&#xff0c;聊了聊二进制的知识…