【论文阅读】Improved Denoising Diffusion Probabilistic Models

Improved Denoising Diffusion Probabilistic Models

文章目录

引用: Nichol A Q, Dhariwal P. Improved denoising diffusion probabilistic models[C]//International conference on machine learning. PMLR, 2021: 8162-8171.

论文链接: https://arxiv.org/abs/2102.09672

代码链接: https://github.com/openai/improved-diffusion

概述

去噪扩散概率模型 (DDPM) 是一类生成模型,最近已被证明可以产生出色的样本。实验表明,通过一些简单的修改,DDPM还可以在保持高样品质量的同时实现竞争性的对数似然。为了更紧密地优化变分下界 (VLB),我们使用简单的重新参数化和混合学习目标来学习逆向过程方差,该目标将 VLB 与 Ho 等人[1]的简化目标相结合,允许采样前向传递减少一个数量级,样本质量差异可以忽略不计,这对于这些模型的实际部署非常重要。使用混合目标,模型获得了比直接优化对数似然获得的对数似然更好的对数似然,并发现后一个目标在训练过程中具有更多的梯度噪声。与混合目标相比,一个简单的重要性采样技术可以减少这种噪声,并能够获得更好的对数似然。此外,论文还使用精确度和召回率来比较 DDPM 和 GAN 对目标分布的覆盖程度。最后,我们表明,这些模型的样本质量和可能性可以随着模型容量和训练计算而平滑扩展,使其易于扩展。

Improving the Log-likelihood

虽然Ho等人[1]发现DDPM可以根据FID[2]和Inception Score[3]生成高保真样本,但他们无法通过这些模型实现竞争对数可能性。对数似然是生成建模中广泛使用的指标,人们普遍认为,优化对数似然会迫使生成模型捕获数据分布的所有模式。此外,最近的工作[4]表明,对数似然的微小改进可以对样本质量和学习的特征表示产生巨大影响。因此,重要的是要探讨为什么 DDPM 似乎在这个指标上表现不佳,因为这可能表明一个根本性的缺点,例如模式覆盖率差。

为了研究不同修改的影响,在ImageNet 64×64和CIFAR-10数据集上训练具有固定超参数的固定模型架构。虽然 CIFAR-10 在此类模型中的应用更多,但论文选择研究 ImageNet 64 × 64,因为它在多样性和分辨率之间提供了良好的权衡,能够快速训练模型而不必担心过度拟合。此外,ImageNet 64×64 已在生成建模的背景下进行了广泛研究,能够将 DDPM 直接与许多其他生成模型进行比较。

Ho等人[1]的设置(在设置 σ t 2 = β t σ^2_t = β_t σt2=βt T = 1000 T = 1000 T=1000 的同时优化 L s i m p l e L_{simple} Lsimple )在 200K 训练迭代后,在 ImageNet 64 × 64 64 × 64 64×64 上实现了 3.99 3.99 3.99 b i t s / d i m bits/dim bits/dim) 的对数似然。论文在早期的实验中发现,可以通过将 T T T 1000 1000 1000 增加到 4000 4000 4000 来提高对数似然;通过此更改,对数似然提高到 3.77 3.77 3.77

Learning ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)

在这里插入图片描述

Ho等人[1]将 ∑ θ ( x t , t ) = σ t 2 I \sum_{\theta}(x_{t}, t) = \sigma_{t}^{2}I θ(xt,t)=σt2I,其中 σ t σ_t σt 不是学习的。奇怪的是,他们发现将 σ t 2 σ^2_t σt2 固定到 β t β_t βt 产生的样品质量与将其固定到 β ~ t \tilde { \beta } _ { t } β~t 大致相同。考虑到 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 代表两个相反的极端,有理由问为什么这种选择不会影响样本。图 1 给出了一个线索,**它表明 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 几乎相等(除了接近 t = 0 t = 0 t=0),即模型正在处理难以察觉的细节。此外,随着扩散步骤数量的增加, β t β_t βt和β ̃t似乎在更多的扩散过程中彼此靠近。这表明,在无限扩散步骤的极限下, σ t σ_t σt的选择对样品质量可能完全无关紧要。换句话说,当添加更多的扩散步骤时,模型平均值 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t) ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)更能决定分布。虽然上述论点表明,为了样本质量,固定 σ t σ_t σt 是一个合理的选择,但它并没有说明对数似然性。事实上,图2显示,扩散过程的前几步对变分下限的贡献最大。因此,似乎可以通过使用更好的 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t) 选择来提高对数似然。为了实现这一目标,必须学习 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而不会遇到 Ho 等人遇到的不稳定性。

由于图 1 显示 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)的理想范围非常小,因此神经网络很难直接预测 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),即使在对数域中也是如此。相反,我们发现最好将方差参数化为在log域 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t之间的插值。具体而言,模型输出一个向量 v v v,每个维度包含一个分量,将此输出转换为方差,如下所示:

∑ θ ( x t , t ) = e x p ( v log ⁡ β t + ( 1 − v ) log ⁡ β ~ t ) \sum _ { \theta } ( x _ { t } , t ) = e x p ( v \log \beta _ { t } + ( 1 - v ) \log \tilde { \beta } _ { t } ) θ(xt,t)=exp(vlogβt+(1v)logβ~t)

没有对 v v v 施加任何约束,理论上允许模型预测插值范围之外的方差。由于 Lsimple 不依赖于 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),因此定义了一个新的混合目标:

L h y b r i d = L s i m p l e + λ L v l b L _ { h y b r i d } = L _ { s i m p l e } + \lambda L _ { v l b } Lhybrid=Lsimple+λLvlb

对于实验,设置 λ = 0.001 λ = 0.001 λ=0.001 以防止 L v l b L_{vlb} Lvlb 压倒 L s i m p l e L_{simple} Lsimple。按照同样的推理思路,还对 L v l b L_{vlb} Lvlb项的 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)输出应用了停止梯度。这样,$L_{vlb} $可以引导 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而 L s i m p l e L_{simple} Lsimple 仍然是影响 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)的主要来源。

在这里插入图片描述

Improving the Noise Schedule

虽然Ho等人中使用的线性噪声调度对于高分辨率图像效果良好,但对于分辨率为64×64和32×32的图像来说,它是次优的。特别地,前向噪声处理的末尾噪声太大,因此对样本质量没有太大贡献。这可以在图3中直观地看到。这种影响的结果在图4中进行了研究,当跳过高达20%的反向扩散过程时,用线性时间表训练的模型不会变得更糟(通过FID测量)。为了解决这个问题,根据 α t ˉ \bar { \alpha _ { t } } αtˉ构建了一个不同的噪声表:

α t ˉ = f ( t ) f ( 0 ) , f ( t ) = cos ⁡ ( t / T + s 1 + s ⋅ π 2 ) 2 \bar { \alpha _ { t } } = \frac { f ( t ) } { f ( 0 ) } , f ( t ) = \cos \left( \frac { t / T + s } { 1 + s } \cdot \frac { \pi } { 2 } \right) ^ { 2 } αtˉ=f(0)f(t),f(t)=cos(1+st/T+s2π)2

β t = 1 − α ‾ t α ‾ t − 1 \beta _ { t } = 1 - \frac { \overline { \alpha } _ { t } } { \overline { \alpha } _ { t - 1 } } βt=1αt1αt

在实践中,将 β t \beta_t βt 裁剪为不大于 0.999,以防止在扩散过程结束时接近 $t = T $的奇点。

在这里插入图片描述

提出的余弦时间表被设计为在过程中具有 α t ˉ \bar { \alpha _ { t } } αtˉ的线性下降,同时在$ t = 0 $和 t = T t = T t=T 的极端附近变化很小,以防止噪声水平的突然变化。图 5 显示了两个计划的 α α α进展情况。可以看到,线性时间表以更快的速度趋向于零,破坏信息的速度比必要的要快得多。使用较小的偏移量 s s s 来防止 β t β_t βt 在$ t = 0 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。 ∗ ∗ 特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。**特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。特别是,选择了 s ,使得 ,使得 ,使得\sqrt { \beta _ { 0 } }$略小于像素箱大小 1 / 127.5 1/127.5 1/127.5,因此 s = 0.008 s = 0.008 s=0.008。我们特别选择使用 c o s 2 cos^2 cos2,因为它是一个具有我们正在寻找的形状的通用数学函数。这种选择是任意的,我们预计许多其他具有类似形状的函数也可以使用。**

Reducing Gradient Noise

在这里插入图片描述

在这里插入图片描述

我们希望通过直接优化 L v l b L_{vlb} Lvlb 而不是优化 L h y b r i d L_{hybrid} Lhybrid 来实现最佳的对数似然。然而, L v l b L_{vlb} Lvlb在实践中实际上很难优化,至少在多样化的 ImageNet 64×64 数据集上是这样。图 6 显示了 $L_{vlb} $和 L h y b r i d L{hybrid} Lhybrid 的学习曲线。两条曲线都是嘈杂的,但在训练时间相同的情况下,混合目标显然在训练集上实现了更好的对数似然。通过评估使用两个目标训练的模型的梯度噪声标度证实了 L v l b L_{vlb} Lvlb 的梯度比 L h y b r i d L_{hybrid} Lhybrid 的梯度大得多,如图7所示。因此,我们寻找一种方法来减少 L v l b L_{vlb} Lvlb 的方差,以便直接优化对数似然性。注意到 L v l b L_{vlb} Lvlb的不同项具有很大差异的幅度(图 2),假设采样$ t $在 $L_{vlb} $中均匀地产生不必要的噪声。为了解决这个问题,采用了重要性抽样:

L v l b = E t ∼ p t [ L t p t ] , w h e r e p t ∝ E [ L t 2 ] a n d ∑ p t = 1 L_{vlb} = E_{ t \sim p_{t} } \left[ \frac { L_{t} } { p_{t} } \right] , where p_{t} \propto \sqrt { E \left[ L_{t} ^ {2} \right] } and \sum p_{t} = 1 Lvlb=Etpt[ptLt],whereptE[Lt2] andpt=1

由于 E [ L t 2 ] E \left[ L _ { t } ^ { 2 } \right] E[Lt2] 事先是未知的,并且可能在整个训练过程中发生变化,因此我们维护每个损失项的前 10 个值的历史记录,并在训练期间动态更新。在训练开始时,均匀地采样 t t t,直到为每个 $t ∈ [0, T −1] $抽取 10 个样本。有了这个重要性抽样目标,就能够通过优化 L v l b L_{vlb} Lvlb 来实现最佳的对数似然。如图 6 所示,即 L v l b L_{vlb} Lvlb(重采样)曲线。该图还显示,重要性采样物镜的噪声比原始的均匀采样要小得多。可以发现,在直接优化噪声较小的L_{{hybrid}时,重要性采样技术没有帮助。

Improving Sampling Speed

在这里插入图片描述

为了减少从 T T T K K K 的采样步骤数,使用$ K$ 个介于 1 1 1 T T T(含)之间的均匀分布的实数,然后将每个结果数字四舍五入到最接近的整数。在图 8 中,评估了使用 4000 扩散步骤,使用 25、50、100、200、400、1000 和 4000 个采样步骤训练的$ L_{hybrid}$ 模型和 L s i m p l e L_{simple} Lsimple 模型的 FID。.我们既针对训练有素的检查点,也针对培训中途的检查点。对于 CIFAR-10,使用了 200K 和 500K 的训练迭代,对于 ImageNet 64,使用了 500K 和 1500K 的训练迭代。可以发现,当使用较少的采样步骤时,具有固定sigmas的 L s i m p l e L_{simple} Lsimple 模型在样本质量方面受到的影响要大得多,而学习sigmas的 L h y b r i d L_{hybrid} Lhybrid模型保持了较高的样本质量。使用此模型,100 个采样步骤足以为完全训练的模型实现近乎最佳的 FID。

Scaling Model Size

在这里插入图片描述

为了衡量性能如何通过训练计算进行扩展,我们在 ImageNet 64 × 64 上训练了四个不同的模型,并使用 L h y b r i d L_{hybrid} Lhybrid 目标。为了改变模型容量,在所有层上应用深度乘法器,使得第一层有 64、96、128 或 192 个通道。请注意,之前的实验在第一层中使用了 128 个通道。由于每一层的深度都会影响初始权重的规模,因此将每个模型的Adam学习率按 1 / c h a n n e l m u l t i p l i e r 1 / \sqrt{channel multiplier} 1/channelmultiplier 缩放,因此128通道模型的学习率为0.0001。图 10 显示了 FID 和 NLL 相对于理论训练计算的改进情况。FID 曲线在对数-对数图上看起来近似线性,表明 FID 根据幂律(绘制为黑色虚线)进行缩放。NLL曲线不能完全拟合幂律,这表明验证NLL的扩展方式不如FID。这可能是由多种因素引起的,例如 1) 这种类型的扩散模型出乎意料的高不可约损失,或 2) 模型过度拟合到训练分布。还注意到,这些模型通常无法实现最佳对数似然,因为它们是使用 L h y b r i d L_{hybrid} Lhybrid 而不是直接使用 L v l b L_{vlb} Lvlb 进行训练的,以保持良好的对数似然性和样本质量。

实验

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考文献

[1] Ho, J., Jain, A., and Abbeel, P. Denoising diffusion probabilistic models, 2020.
[2] Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[3] Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved techniques for training gans, 2016.
[4] Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., and Amodei, D. Scaling laws for neural language models, 2020.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/749071.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kotlin 中List,Set,Map的创建与使用

目录 1. List 的使用 1.1 不可变 List 1.2 可变 List 2. Set 的使用 2.1 不可变 Set 2.2 可变 Set 3. Map 的使用 3.1 不可变Map 3.2 可变Map 本篇主要为已经有Java基础的同学展示Kotlin语言中的List,Set,Map的创建和使用,所以Java代…

【小白刷leetcode】第15题

【小白刷leetcode】第15题 动手刷leetcode,正在准备蓝桥,但是本人算法能力一直是硬伤。。。所以做得一直很痛苦。但是不熟练的事情像练吉他一样,就需要慢速,多练。 题目描述 看这个题目,说实在看的不是很懂。索性我们直…

uniapp 对video视频组件嵌套倍速按钮

这次接了需求是要求有倍速功能,去看了文档发现并没有倍速按钮的属性,想着手写一个吧 可最后发现原生层级太高,无论怎么样都迭不上去,就只能去找插件看看咯 找了好多插件发现都不可用,因为我这是app端,有些视…

Vue组件中引入jQuery

两种在vue中引入jQuery的方式 1、普通html中使用jQuery 将jQuer的文件导入到项目中&#xff0c;然后直接使用<script src"jQuery.js"></script>即可。 <script src"jQuery.js"></script> 2、vue组件中使用jQuery 安装依赖 c…

C语言数据结构基础笔记——树、二叉树简介

1.树 树是一种 非线性 的数据结构&#xff0c;它是由 n &#xff08; n>0 &#xff09;个有限结点组成一个具有层次关系的集合。 把它叫做树是因 为它看起来像一棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的。 &#xff08;图片来源于网络&#xff09;…

【OJ】string类题目

个人主页 &#xff1a; zxctscl 如有转载请先通知 题目 1. 415字符串相加1.1 分析1.2 代码 2. 344反转字符串2.1 分析2.2 代码 3. HJ1字符串最后一个单词的长度3.1 分析3.2 代码 4. 387.字符串中的第一个唯一字符4.1 分析4.2 代码 5. 125验证回文串5.1 分析5.2 代码 1. 415字符…

wordpress被恶意搜索攻击(网址/?s=****)解决方法。

源地址&#xff1a;https://www.ctvol.com/seoomethods/1413686.html 什么叫恶意搜索攻击&#xff1f; wordpress恶意搜索攻击并不是像病毒一样的攻击&#xff0c;而是一种seo分支黑帽手段&#xff0c;通过被攻击网站搜索功能中长尾关键词来实现攻击&#xff0c;通过网址不断…

【LeetCode热题100】146. LRU 缓存(链表)

一.题目要求 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;则返回关键字的值&#xff0c…

Jenkins插件Parameterized Scheduler用法

Jenkins定时触发构建的同时设定参数。可以根据不同的定时构建器设置不同参数或环境变量的值。可以设置多个参数。并结合when控制stage流程的执行。结合when和triggeredBy区分定时构建的stage和手动执行的stage。 目录 什么是Parameterized Scheduler&#xff1f;如何配置实现呢…

代码随想录训练营Day24:● 理论基础 ● 77. 组合

理论基础 回溯算法解决的问题 回溯法&#xff0c;一般可以解决如下几种问题&#xff1a; 组合问题&#xff1a;N个数里面按一定规则找出k个数的集合 切割问题&#xff1a;一个字符串按一定规则有几种切割方式 子集问题&#xff1a;一个N个数的集合里有多少符合条件的子集 排列…

yolo项目中如何训练自己的数据集

1.收集自己需要标注的图片 2.打开网站在线标注网站 2.1 点击右下角Get Start 2.2点击这里上传自己的图片 上传成功后有英文的显示 点击左边的Object Detection&#xff0c;表示用于目标检测 2.3选择新建标签还是从本地加载标签 如果是本地加载标签&#xff08;左边&#…

基本常用函数help()

Python内置函数 help()函数&#xff1a;查看对象的帮助信息 print()函数&#xff1a;用于打印输出 input()函数&#xff1a;根据输入内容返回所输入的字符串类型 format()函数&#xff1a;格式化显示 len()函数&#xff1a;返回对象的长度或项目个数 slice()函数&#xf…

26-Java访问者模式 ( Visitor Pattern )

Java访问者模式 摘要实现范例 访问者模式&#xff08;Visitor Pattern&#xff09;使用了一个访问者类&#xff0c;它改变了元素类的执行算法&#xff0c;通过这种方式&#xff0c;元素的执行算法可以随着访问者改变而改变访问者模式中&#xff0c;元素对象已接受访问者对象&a…

TouchGFX之MVP

TouchGFX用户接口遵循Model-View-Presenter&#xff08;MVP&#xff09;架构模式&#xff0c;它是Model-View-Controller&#xff08;MVC&#xff09;模式的派生模式。 两者都广泛用于构建用户接口应用。 MVP模式的主要优势是&#xff1a; 关注点分离&#xff1a;将代码分成不…

mysql 排序底层原理解析

前言 本章详细讲下排序&#xff0c;排序在我们业务开发非常常见&#xff0c;有对时间进行排序&#xff0c;又对城市进行排序的。不合适的排序&#xff0c;将对系统是灾难性的&#xff0c;这个不是危言耸听。可能有些人会想&#xff0c;对于排序mysql 是怎么实现的&#xff0c;…

Android 地图SDK 绘制点 删除 指定

问题 Android 地图SDK 删除指定绘制点 详细问题 笔者进行Android 项目开发&#xff0c;对于已标记的绘制点&#xff0c;提供撤回按钮&#xff0c;即删除绘制点&#xff0c;如何实现。 解决方案 新增绘制点 private List<Marker> markerList new ArrayList<>…

Oracle数据库:使用 bash脚本 + 定时任务 自动备份数据

Oracle数据库&#xff1a;使用 bash脚本 定时任务 自动备份数据 1、前言2、为什么需要自动化备份&#xff1f;3、编写备份脚本4、备份脚本授权5、添加定时任务6、重启 crond / 检查 crond 服务状态7、备份文件检查 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收…

AI赋能写作:AI大模型高效写作一本通

❤️作者主页&#xff1a;小虚竹 ❤️作者简介&#xff1a;大家好,我是小虚竹。2022年度博客之星评选TOP 10&#x1f3c6;&#xff0c;Java领域优质创作者&#x1f3c6;&#xff0c;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;掘金年度人气作…

Java学习笔记(15)

JDK7前时间相关类 Date时间类 Simpledateformat Format 格式化 Parse 解析 默认格式 指定格式 EE&#xff1a;表示周几 Parse&#xff1a;把字符串时间转成date对象 注意&#xff1a;创建对象的格式要和字符串的格式一样 Calendar日历类 不能创建对象 Getinstance 获取当…

【C#】【SAP2000】读取SAP2000中所有Frame对象在指定工况的温度荷载值到Grasshopper中

if (build true) {// 连接到正在运行的 SAP2000// 使用 COM 接口获取 SAP2000 的 API 对象cOAPI mySapObject (cOAPI)System.Runtime.InteropServices.Marshal.GetActiveObject("CSI.SAP2000.API.SapObject");// 获取 SAP2000 模型对象cSapModel mySapModel mySap…