机器学习扩散模型简介

一、说明  

        扩散模型的迅速崛起是过去几年机器学习领域最大的发展之一。在这本易于理解的指南中了解您需要了解的有关扩散模型的所有信息。

        扩散模型是生成模型,在过去几年中越来越受欢迎,这是有充分理由的。在 2020 年代发布的几篇开创性论文就向世界展示了 Diffusion 模型的能力,例如在图像合成方面击败 GAN [ 6 ]。最近,从业者将看到DALL-E 2(OpenAI 上个月发布的图像生成模型)中使用的扩散模型。

DALL-E 2 生成的各种图像(来源)。

        鉴于扩散模型最近的成功浪潮,许多机器学习从业者肯定对其内部工作原理感兴趣。在本文中,我们将研究扩散模型的理论基础,然后演示如何在 PyTorch 中使用扩散模型生成图像。如需对扩散模型进行技术性较低、

        更直观的解释,请随时查看关于物理学如何推动生成式 AI 的发展文章。让我们深入了解吧!

二、扩散模型 - 简介

        扩散模型是生成模型,这意味着它们用于生成与训练数据相似的数据。 从根本上讲,扩散模型的工作原理是通过连续添加高斯噪声来破坏训练数据,然后学习通过逆转该噪声过程来恢复数据。训练后,我们可以使用扩散模型来生成数据,只需将随机采样的噪声传递到学习的去噪过程中即可

扩散模型可用于从噪声生成图像(改编自源)

        更具体地说,扩散模型是使用固定马尔可夫链映射到潜在空间的潜在变量模型。该链逐渐向数据添加噪声以获得近似后验q(\textbf{x}_{1:T}|\textbf{x}_0) , 在这里\textbf{x}_1, ... , \textbf{x}_T 时间是具有相同维度的潜在变量X_0。在下图中,我们看到了图像数据的马尔可夫链。

(根据源码修改)

        最终,图像渐近变换为纯高斯噪声。训练扩散模型的目标是学习相反的过程 - 即训练p_\theta(x_{t-1}|x_t)。通过沿着这条链向后遍历,我们可以生成新的数据。

(根据源码修改)

2.1 扩散模型的好处

        如上所述,近年来对扩散模型的研究呈爆炸式增长。受非平衡热力学[ 1 ]的启发,扩散模型目前可产生最先进的图像质量,其示例如下:

(改编自来源)

        除了尖端的图像质量之外,扩散模型还具有许多其他优点,包括不需要对抗性训练。对抗性训练的困难是有据可查的;而且,如果存在具有可比性能和训练效率的非对抗性替代方案,通常最好利用它们。在训练效率方面,扩散模型还具有可扩展性和并行性的额外优势。

        虽然扩散模型几乎似乎是凭空产生结果,但有许多仔细且有趣的数学选择和细节为这些结果提供了基础,并且最佳实践仍在文献中不断发展。现在让我们更详细地了解一下支持扩散模型的数学理论。

2.2 扩散模型 - 深入探讨

        如上所述,扩散模型由前向过程(或扩散过程)和反向过程(或反向扩散过程)组成,其中数据(通常是图像)逐渐被噪声化,其中噪声被转换回来自目标分布的样本。

        当噪声水平足够低时,前向过程中的采样链转换可以设置为条件高斯。将此事实与马尔可夫假设相结合,得出前向过程的简单参数化:

        数学笔记

        在这里\beta_1, ..., \beta_T时间是一个方差表(学习的或固定的),如果表现良好,可以确保 x_T时间 对于足够大的 T ,几乎是各向同性高斯分布

给定马尔可夫假设,潜在变量的联合分布是高斯条件链转换的乘积(从源修改)。

        如前所述,扩散模型的“魔力”来自于相反的过程。在训练过程中,模型学习扭转这种扩散过程以生成新数据。从纯高斯噪声开始p(\textbf{x}_{T}) := \mathcal{N}(\textbf{x}_T, \textbf{0}, \textbf{I}),模型学习联合分布p_\theta(\textbf{x}_{0:T})作为

        其中学习高斯跃迁的时间相关参数。特别注意,马尔可夫公式断言给定的反向扩散转移分布仅取决于前一个时间步(或后一个时间步,取决于您如何看待它):

(根据源码修改)

想要了解如何在 PyTorch 中构建扩散模型?

        查看我们的 MinImagen 项目,我们在其中构建了文本到图像模型 Imagen 的最小实现!

2.3 训练

        通过寻找使训练数据的可能性最大化的逆马尔可夫转移来训练扩散模型。实际上,训练等效地包括最小化负对数似然的变分上限。

符号详细信息

我们寻求重写L_{vlb}我乙就Kullback-Leibler (KL) 散度而言。KL 散度是一种不对称统计距离度量,衡量一个概率分布P与参考分布Q的差异程度。我们有兴趣制定L_{vlb}  就 KL 散度而言,因为我们的马尔可夫链中的转移分布是高斯分布,并且高斯分布之间的 KL 散度具有闭合形式

2.4 什么是 KL 散度?

连续分布的 KL 散度的数学形式为

双条表示该函数相对于其参数不对称。

        下面您可以看到变化分布P(蓝色)与参考分布Q(红色)的 KL 散度。绿色曲线表示上述 KL 散度定义中积分内的函数,曲线下的总面积表示任意给定时刻  PQ的 KL 散度值,该值也以数字形式显示。

        铸件 L_{vlb}依照 KL 散度而言

        如前所述, [ 1 ]几乎完全可以重写L_{vlb},就KL 散度而言:

        在这里

        推导详情

        调节后向过程x_0 在L_{t-1} 结果是一种易于处理的形式,导致所有 KL 散度都是高斯分布之间的比较。这意味着可以使用封闭式表达式而不是蒙特卡罗估计来精确计算散度[ 3 ]。

2.5 型号选择

        建立了目标函数的数学基础后,我们现在需要就如何实现扩散模型做出一些选择。对于正向过程,唯一需要的选择是定义方差表,其值通常在正向过程中增加。

        对于相反的过程,我们更多地选择高斯分布参数化/模型架构。请注意扩散模型提供的高度灵活性- 我们架构的唯一要求是其输入和输出具有相同的维度。

        我们将在下面更详细地探讨这些选择的细节。

        转发过程和 L_T

        如上所述,关于前向过程,我们必须定义方差表。特别是,我们将它们设置为与时间相关的常数,忽略了它们是可以学习的事实。例如[ 3 ],一个线性时间表\beta_1=10^{-4}\beta_T=0.2可能会使用,或者可能是几何级数。

        无论选择什么特定值,方差表是固定的这一事实会导致L_{T}就我们的一组可学习参数而言,它成为一个常数,使我们能够在训练时忽略它。

逆向过程和L_{1:T-1}

现在我们讨论定义逆过程所需的选择。回想一下上面我们将逆马尔可夫转移定义为高斯:

        我们现在必须定义函数形式\pmb{\mu}_\theta或者\pmb{\Sigma}_\theta。虽然有更复杂的参数化方法\pmb{\Sigma}_\theta,我们简单设置

        也就是说,我们假设多元高斯是具有相同方差的独立高斯的乘积,方差值可以随时间变化。我们将这些方差设置为等于我们的前向过程方差表

        鉴于这个新的配方\pmb{\Sigma}_\theta, 我们有

这使我们能够转变

        其中差值的第一项是以下项的线性组合x_tx_0这取决于差异表\beta_t。该函数的确切形式与我们的目的无关,但可以在[ 3 ]中找到。

        上述比例的意义在于,最直接的参数化\mu_\theta简单地预测扩散后验平均值。重要的是,[ 3 ]的作者实际上发现训练\mu_\theta预测任何给定时间步长的噪声分量会产生更好的结果。特别地,让

        在这里

        这导致了以下替代损失函数,[ 3 ]的作者发现它可以带来更稳定的训练和更好的结果:

[ 3 ]的作者还注意到扩散模型的这种表述与基于 Langevin 动力学的分数匹配生成模型的联系。事实上,扩散模型和基于分数的模型似乎可能是同一枚硬币的两面,类似于基于波的量子力学和基于矩阵的量子力学的独立和同时发展,揭示了相同现象的两种等效公式[ 2 ] ]。

2.6 网络架构

        虽然我们的简化损失函数旨在训练模型\pmb{\epsilon}_\theta ,我们还没有定义这个模型的架构。请注意,模型的唯一要求是其输入和输出维度相同。

        考虑到这一限制,图像扩散模型通常采用类似 U-Net 的架构来实现,这也许并不奇怪。

U-Net的架构(来源)

2.7 逆向过程解码器和L_0 

        逆过程的路径由连续条件高斯分布下的许多变换组成。在反向过程结束时,回想一下我们正在尝试生成一个由整数像素值组成的图像。因此,我们必须设计一种方法来获取所有像素上每个可能像素值的离散(对数)似然。

        完成此操作的方法是将反向扩散链中的最后一个转换设置为独立的离散解码器。确定给定图像的可能性x_0给定x_1 ,我们首先在数据维度之间施加独立性:

        其中D是数据的维数,上标i表示提取一个坐标。现在的目标是确定给定像素的每个整数值的可能性有多大,定时间点轻微噪声图像中相应像素的可能值的分布t=1 :

        其中像素分布t=1 源自以下多元高斯,其对角协方差矩阵允许我们将分布拆分为单变量高斯的乘积,一个对应于数据的每个维度:

        我们假设图像由整数组成{0, 1, ..., 255} (与标准 RGB 图像一样)已线性缩放至[-1,1]。然后,我们将实际线分解为小“桶”,其中,对于给定的缩放像素值x,该范围的桶是[x-1/255, x+1/255]。给定相应像素的单变量高斯分布,像素值 x的概率x_1,是以x为中心的桶内单变量高斯分布下的面积

        下面您可以看到每个桶的面积及其均值为 0 高斯的概率,在这种情况下,对应于平均像素值为255/2(一半亮度)。红色曲线表示t=1图像中特定像素的分布,面积给出了t=0图像中对应像素值的概率。

        技术说明

        给定每个像素的t=0像素值,则p_\theta(x_0 | x_1) 只是他们的产品。该过程由以下等式简洁地概括:

        在这里

        和

给定这个方程p_\theta(x_0 | x_1) ,我们可以计算出最后一项 L_{vlb}  它没有被表述为 KL 散度:

2.8 最终目标

正如上一节中提到的,[ 3 ]的作者发现在给定时间步长预测图像的噪声分量会产生最佳结果。最终,他们使用以下目标:

因此,我们的扩散模型的训练和采样算法可以在下图中简洁地描述:

(来源)

三、扩散模型理论总结

        在本节中,我们详细介绍了扩散模型的理论。我们很容易陷入数学细节中,因此我们在下面的本节中记下最重要的几点,以便从鸟瞰的角度保持方向:

  1. 我们的扩散模型被参数化为马尔可夫链,这意味着我们的潜在变量x_1, ... , x_T 时间仅取决于前一个(或后一个)时间步长。
  2. 马尔可夫链中的转移分布是高斯分布其中前向过程需要方差调度,而反向过程参数是学习的。
  3. 扩散过程确保x_T时间对于足够大的 T,渐近分布为各向同性高斯分布。
  4. 在我们的例子中,方差表是固定的,但它也是可以学习的。对于固定的时间表,遵循几何级数可能比线性级数提供更好的结果。在任何一种情况下,方差通常都会随着系列中的时间而增加(\beta_i < \beta_j 随着i < j)。
  5. 扩散模型非常灵活,允许使用输入和输出维度相同的任何架构。许多实现都使用类似 U-Net 的架构。
  6. 训练目标是最大化训练数据的 可能性。 这表现为调整模型参数以最小化数据的负对数似然的变分上限
  7. 由于我们的马尔可夫假设,目标函数中的几乎所有项都可以转换为KL 散度。鉴于我们使用的是高斯分布,这些值变得可以计算,因此省略了执行蒙特卡洛近似的需要。
  8. 最终,使用简化的训练目标来训练预测给定潜在变量的噪声分量的函数会产生最佳且最稳定的结果。
  9. 作为反向扩散过程的最后一步,离散解码器用于获取像素值的对数似然。

        了解了扩散模型的高级概述后,让我们继续了解如何在 PyTorch 中使用扩散模型。

四、PyTorch 中的扩散模型

        虽然扩散模型尚未像机器学习中的其他旧架构/方法那样民主化,但仍然有可用的实现。在 PyTorch 中使用扩散模型的最简单方法是使用该denoising-diffusion-pytorch包,它实现了像本文中讨论的图像扩散模型。要安装该软件包,只需在终端中输入以下命令:

pip install denoising_diffusion_pytorch

4.1 最小的例子

为了训练模型并生成图像,我们首先导入必要的包:

import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

        接下来,我们定义网络架构,在本例中是 U-Net。该dim参数指定第一次下采样之前的特征图数量,该dim_mults参数提供该值和后续下采样的被乘数:

model = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
)

        现在我们的网络架构已经定义,我们需要定义扩散模型本身。我们传入刚刚定义的 U-Net 模型以及几个参数 - 要生成的图像的大小、扩散过程中的时间步数以及 L1 和 L2 范数之间的选择。

diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000,   # number of stepsloss_type = 'l1'    # L1 or L2
)

        现在已经定义了扩散模型,是时候进行训练了。我们生成随机数据进行训练,然后以通常的方式训练扩散模型:

training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()

        一旦模型训练完成,我们最终就可以使用对象sample()的方法生成图像diffusion。这里我们生成 4 张图像,考虑到我们的训练数据是随机的,这些图像只是噪声:

sampled_images = diffusion.sample(batch_size = 4)

4.2 自定义数据培训

        该denoising-diffusion-pytorch包还允许您在特定数据集上训练扩散模型。只需将'path/to/your/images'字符串替换为下面对象中的数据集目录路径Trainer(),然后更改image_size为适当的值即可。之后,只需运行代码来训练模型,然后像以前一样进行采样。请注意,PyTorch 必须在启用 CUDA 的情况下进行编译才能使用该类Trainer

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainermodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000,   # number of stepsloss_type = 'l1'    # L1 or L2
).cuda()trainer = Trainer(diffusion,'path/to/your/images',train_batch_size = 32,train_lr = 2e-5,train_num_steps = 700000,         # total training stepsgradient_accumulate_every = 2,    # gradient accumulation stepsema_decay = 0.995,                # exponential moving average decayamp = True                        # turn on mixed precision
)trainer.train()

下面您可以看到从多元高斯噪声到 MNIST 数字的渐进式去噪,类似于反向扩散:

五、最后的话

        扩散模型是一种概念上简单而优雅的方法来解决数据生成问题。他们最先进的成果与非对抗性训练相结合,将他们推向了很高的高度,鉴于其刚刚起步的地位,预计未来几年将取得进一步的进步。特别是,扩散模型被发现对于DALL-E 2等尖端模型的性能至关重要。

#参考

[1]使用非平衡热力学的深度无监督学习

[2]通过估计数据分布的梯度进行生成建模

[3]去噪扩散概率模型

[4]训练基于分数的生成模型的改进技术

[5]改进的去噪扩散概率模型

[6]扩散模型在图像合成方面击败了 GAN

[7] GLIDE:使用文本引导扩散模型实现真实感图像生成和编辑

[8]使用 CLIP Latents 生成分层文本条件图像

【9】Introduction to Diffusion Models for Machine Learning (assemblyai.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/628083.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL系列之数据导入导出

前言 大数据与云计算作为当今时代&#xff0c;数据要素发展的“动力引擎”&#xff0c;已经走进了社会生活的方方方面。而背后承载的云服务或数据服务的高效运转&#xff0c;起了决定作用。 作为数据存储的重要工具&#xff0c;数据库的品类和特性也日新月异。从树型、网络型…

MySQL/Oracle 的 字符串拼接

目录 MySQL、Oracle 的 字符串拼接1、MySQL 的字符串拼接1.1 CONCAT(str1,str2,...) : 可以拼接多个字符串1.2 CONCAT_WS(separator,str1,str2,...) : 指定分隔符拼接多个字符串1.3 GROUP_CONCAT(expr) : 聚合函数&#xff0c;用于将多行的值连接成一个字符串。 2、Oracle 的字…

C#灵活控制多线程的状态(开始暂停继续取消)

ManualResetEvent类 ManualResetEvent是一个同步基元&#xff0c;用于在多线程环境中协调线程的执行。它提供了两种状态&#xff1a;终止状态和非终止状态。 在终止状态下&#xff0c;ManualResetEvent允许线程继续执行。而在非终止状态下&#xff0c;ManualResetEvent会阻塞线…

Python画球面投影图

天文学研究中&#xff0c;有时候需要画的并不是传统的XYZ坐标系&#xff0c;而是需要画一个形如这样子的球面投影图&#xff1a; 下面讲一下这种图怎么画 1. 首先要安装healpy包 pip install healpy 2. 然后导入包 如果之前安装过healpy&#xff0c;有的会提示不存在healpy…

【蓝桥杯日记】第一篇——如何搭建系统环境

目录 前言 环境相关文件 学生机环境-Web应用开发环境&#xff08;第十五届大赛&#xff09; 学生机环境-Java编程环境&#xff08;第十五届大赛&#xff09; 学生机环境-C/C编程环境&#xff08;第十五届大赛&#xff09; 学生机环境-Python编程环境 &#xff08;第十五届…

20240112让移远mini-PCIE接口的4G模块EC20在Firefly的AIO-3399J开发板的Android11下跑通【DTS部分】

20240112让移远mini-PCIE接口的4G模块EC20在Firefly的AIO-3399J开发板的Android11下跑通【DTS部分】 2024/1/12 16:20 https://blog.csdn.net/u010164190/article/details/79096345 [Android6.0][RK3399] PCIe 接口 4G模块 EC20 调试记录 https://blog.csdn.net/hnjztyx/artic…

【Linux】线程池实现

&#x1f4d7;线程池实现&#xff08;单例模式&#xff09; 1️⃣线程池概念2️⃣线程池代码样例3️⃣部分问题与细节&#x1f538;类成员函数参数列表中隐含的this指针&#x1f538;单例模式&#x1f538;一个失误导致的bug 4️⃣调用线程池完成任务 1️⃣线程池概念 线程池是…

【Linux驱动】设备树中指定中断 | 驱动中获得中断 | 按键中断实验

&#x1f431;作者&#xff1a;一只大喵咪1201 &#x1f431;专栏&#xff1a;《Linux驱动》 &#x1f525;格言&#xff1a;你只管努力&#xff0c;剩下的交给时间&#xff01; 目录 &#x1f3c0;在设备树中指定中断&#x1f3c0;代码中获得中断&#x1f3c0;按键中断⚽驱动…

闪存剩下内容

1&#xff1a;通过Arduino IDE向闪存文件系统上传文件 1. 下载 Arduino-ESP8266闪存文件插件程序 2&#xff1a;使用闪存文件系统建立功能更加丰富的网络服务器 1&#xff1a;在网页中加载闪存文件系统中的图片、CSS和JavaScript index.html&#xff1a;ESP8266开发板建立的网…

SpringBoot+SSM项目实战 苍穹外卖(12) Apache POI

继续上一节的内容&#xff0c;本节是苍穹外卖后端开发的最后一节&#xff0c;本节学习Apache POI&#xff0c;完成工作台、数据导出功能。 目录 工作台Apache POI入门案例 导出运营数据Excel报表 工作台 工作台是系统运营的数据看板&#xff0c;并提供快捷操作入口&#xff0c…

初识OpenCV

首先你得保证你的虚拟机Ubuntu能上网 可看 http://t.csdnimg.cn/bZs6c 打开终端输入 sudo apt-get install libopencv-dev 回车 输入密码 回车 遇到Y/N 回车 OpenCV在线文档 opencv 文档链接 点zip可以下载&#xff0c;点前面的直接在线浏览&#xff0c;但是很慢 https…

单元测试:Testing leads to failure, and failure leads to understanding

单元测试的概念可能多数读者都有接触过。作为开发人员&#xff0c;我们编写一个个测试用例&#xff0c;测试框架发现这些测试用例&#xff0c;将它们组装成测试 suite 并运行&#xff0c;收集测试报告&#xff0c;并且提供测试基础设施&#xff08;断言、mock、setup 和 teardo…

JAVAEE初阶 文件IO(一)

这里写目录标题 一. 计算机中存储数据的设备1.1 CPU1.2 内存1.3 硬盘1.4 三种存储的区别 二.文件系统2.1 相对路径2.2 绝对路径2.3 .和..的含义2.4 例子2.5 everything工具 三.文件3.1 文本文件3.2 二进制文件 四. JAVA对于文件的API4.1 getParent getName getPath getAbsolute…

Jest单元测试:玩转代码的小捉迷藏!

Jest Jest 是什么&#xff1f; Jest 是一个流行的 JavaScript 测试框架&#xff0c;专注于简化和改进代码的测试流程。它由 Facebook 开发并维护&#xff0c;具有以下特点&#xff1a; 1、易用性&#xff1a;Jest 提供了一个简单而强大的测试框架&#xff0c;使得编写和运行测…

uniapp h5 发行后 微信第二次打开网址 页面白屏

发行后把网址给客户&#xff0c;第一次可以正常登录打开&#xff0c;第二次打开白屏 原因&#xff1a;第一次打开时没有token&#xff0c;所以跳转登录页&#xff0c;可以正常访问 第二次打开时有token&#xff0c;但是网址根目录没有配置默认页面&#xff0c;所以白屏 解决…

Windows Server调整策略实现999999个远程用户用时登录

正文共&#xff1a;1234 字 23 图&#xff0c;预估阅读时间&#xff1a;2 分钟 上篇文章中&#xff08;Windows Server 2019配置多用户远程桌面登录服务器&#xff09;&#xff0c;我们主要介绍了Windows Server 2019在配置远程桌面时&#xff0c;如何通过3种方式创建本地用户账…

使用Qt连接scrcpy-server控制手机

Qt连接scrcpy-server 测试环境如何启动scrcpy-server1. 连接设备2. 推送scrcpy-server到手机上3. 建立Adb隧道连接4. 启动服务5. 关闭服务 使用QTcpServer与scrcpy-server建立连接建立连接并视频推流完整流程1. 开启视频推流过程2. 关闭视频推流过程 视频流的解码1. 数据包协议…

NVMe系统内存结构 - Meta Data

NVMe系统内存结构 - Meta Data 1 为什么需要数据保护2 Meta Data定义3 Meta Data传输方式4 常见Meta Data使用场景4.1 不带数据保护信息4.2 带数据保护信息“数据写”流程4.3 带数据保护信息“数据读”流程4.4 SSD内部加入数据保护信息4.5 SSD内部根据数据保护信息验证数据 本文…

如何在你的网站接入QQ登录?

文章目录 准备阶段申请QQ登录的权限创建应用最后上传qqlogin.php代码 准备阶段 国内服务器和备案域名需要你有张独一无二本人的身份证你正面手持身份证的图片一张100px*100px的网站图标 申请QQ登录的权限 首先访问qq互联&#xff0c;点击我直接访问 登陆完成后我们点击面的…

bash shell基础命令(一)

1.shell启动 shell提供了对Linux系统的交互式访问&#xff0c;通常在用户登录终端时启动。系统启动的shell程序取决于用户账户的配置。 /etc/passwd/文件包含了所有用户的基本信息配置&#xff0c; $ cat /etc/passwd root:x:0:0:root:/root:/bin/bash ...例如上述root账户信…