【论文阅读】Improved Denoising Diffusion Probabilistic Models

Improved Denoising Diffusion Probabilistic Models

文章目录

引用: Nichol A Q, Dhariwal P. Improved denoising diffusion probabilistic models[C]//International conference on machine learning. PMLR, 2021: 8162-8171.

论文链接: https://arxiv.org/abs/2102.09672

代码链接: https://github.com/openai/improved-diffusion

概述

去噪扩散概率模型 (DDPM) 是一类生成模型,最近已被证明可以产生出色的样本。实验表明,通过一些简单的修改,DDPM还可以在保持高样品质量的同时实现竞争性的对数似然。为了更紧密地优化变分下界 (VLB),我们使用简单的重新参数化和混合学习目标来学习逆向过程方差,该目标将 VLB 与 Ho 等人[1]的简化目标相结合,允许采样前向传递减少一个数量级,样本质量差异可以忽略不计,这对于这些模型的实际部署非常重要。使用混合目标,模型获得了比直接优化对数似然获得的对数似然更好的对数似然,并发现后一个目标在训练过程中具有更多的梯度噪声。与混合目标相比,一个简单的重要性采样技术可以减少这种噪声,并能够获得更好的对数似然。此外,论文还使用精确度和召回率来比较 DDPM 和 GAN 对目标分布的覆盖程度。最后,我们表明,这些模型的样本质量和可能性可以随着模型容量和训练计算而平滑扩展,使其易于扩展。

Improving the Log-likelihood

虽然Ho等人[1]发现DDPM可以根据FID[2]和Inception Score[3]生成高保真样本,但他们无法通过这些模型实现竞争对数可能性。对数似然是生成建模中广泛使用的指标,人们普遍认为,优化对数似然会迫使生成模型捕获数据分布的所有模式。此外,最近的工作[4]表明,对数似然的微小改进可以对样本质量和学习的特征表示产生巨大影响。因此,重要的是要探讨为什么 DDPM 似乎在这个指标上表现不佳,因为这可能表明一个根本性的缺点,例如模式覆盖率差。

为了研究不同修改的影响,在ImageNet 64×64和CIFAR-10数据集上训练具有固定超参数的固定模型架构。虽然 CIFAR-10 在此类模型中的应用更多,但论文选择研究 ImageNet 64 × 64,因为它在多样性和分辨率之间提供了良好的权衡,能够快速训练模型而不必担心过度拟合。此外,ImageNet 64×64 已在生成建模的背景下进行了广泛研究,能够将 DDPM 直接与许多其他生成模型进行比较。

Ho等人[1]的设置(在设置 σ t 2 = β t σ^2_t = β_t σt2=βt T = 1000 T = 1000 T=1000 的同时优化 L s i m p l e L_{simple} Lsimple )在 200K 训练迭代后,在 ImageNet 64 × 64 64 × 64 64×64 上实现了 3.99 3.99 3.99 b i t s / d i m bits/dim bits/dim) 的对数似然。论文在早期的实验中发现,可以通过将 T T T 1000 1000 1000 增加到 4000 4000 4000 来提高对数似然;通过此更改,对数似然提高到 3.77 3.77 3.77

Learning ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)

在这里插入图片描述

Ho等人[1]将 ∑ θ ( x t , t ) = σ t 2 I \sum_{\theta}(x_{t}, t) = \sigma_{t}^{2}I θ(xt,t)=σt2I,其中 σ t σ_t σt 不是学习的。奇怪的是,他们发现将 σ t 2 σ^2_t σt2 固定到 β t β_t βt 产生的样品质量与将其固定到 β ~ t \tilde { \beta } _ { t } β~t 大致相同。考虑到 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 代表两个相反的极端,有理由问为什么这种选择不会影响样本。图 1 给出了一个线索,**它表明 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 几乎相等(除了接近 t = 0 t = 0 t=0),即模型正在处理难以察觉的细节。此外,随着扩散步骤数量的增加, β t β_t βt和β ̃t似乎在更多的扩散过程中彼此靠近。这表明,在无限扩散步骤的极限下, σ t σ_t σt的选择对样品质量可能完全无关紧要。换句话说,当添加更多的扩散步骤时,模型平均值 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t) ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)更能决定分布。虽然上述论点表明,为了样本质量,固定 σ t σ_t σt 是一个合理的选择,但它并没有说明对数似然性。事实上,图2显示,扩散过程的前几步对变分下限的贡献最大。因此,似乎可以通过使用更好的 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t) 选择来提高对数似然。为了实现这一目标,必须学习 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而不会遇到 Ho 等人遇到的不稳定性。

由于图 1 显示 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)的理想范围非常小,因此神经网络很难直接预测 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),即使在对数域中也是如此。相反,我们发现最好将方差参数化为在log域 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t之间的插值。具体而言,模型输出一个向量 v v v,每个维度包含一个分量,将此输出转换为方差,如下所示:

∑ θ ( x t , t ) = e x p ( v log ⁡ β t + ( 1 − v ) log ⁡ β ~ t ) \sum _ { \theta } ( x _ { t } , t ) = e x p ( v \log \beta _ { t } + ( 1 - v ) \log \tilde { \beta } _ { t } ) θ(xt,t)=exp(vlogβt+(1v)logβ~t)

没有对 v v v 施加任何约束,理论上允许模型预测插值范围之外的方差。由于 Lsimple 不依赖于 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),因此定义了一个新的混合目标:

L h y b r i d = L s i m p l e + λ L v l b L _ { h y b r i d } = L _ { s i m p l e } + \lambda L _ { v l b } Lhybrid=Lsimple+λLvlb

对于实验,设置 λ = 0.001 λ = 0.001 λ=0.001 以防止 L v l b L_{vlb} Lvlb 压倒 L s i m p l e L_{simple} Lsimple。按照同样的推理思路,还对 L v l b L_{vlb} Lvlb项的 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)输出应用了停止梯度。这样,$L_{vlb} $可以引导 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而 L s i m p l e L_{simple} Lsimple 仍然是影响 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)的主要来源。

在这里插入图片描述

Improving the Noise Schedule

虽然Ho等人中使用的线性噪声调度对于高分辨率图像效果良好,但对于分辨率为64×64和32×32的图像来说,它是次优的。特别地,前向噪声处理的末尾噪声太大,因此对样本质量没有太大贡献。这可以在图3中直观地看到。这种影响的结果在图4中进行了研究,当跳过高达20%的反向扩散过程时,用线性时间表训练的模型不会变得更糟(通过FID测量)。为了解决这个问题,根据 α t ˉ \bar { \alpha _ { t } } αtˉ构建了一个不同的噪声表:

α t ˉ = f ( t ) f ( 0 ) , f ( t ) = cos ⁡ ( t / T + s 1 + s ⋅ π 2 ) 2 \bar { \alpha _ { t } } = \frac { f ( t ) } { f ( 0 ) } , f ( t ) = \cos \left( \frac { t / T + s } { 1 + s } \cdot \frac { \pi } { 2 } \right) ^ { 2 } αtˉ=f(0)f(t),f(t)=cos(1+st/T+s2π)2

β t = 1 − α ‾ t α ‾ t − 1 \beta _ { t } = 1 - \frac { \overline { \alpha } _ { t } } { \overline { \alpha } _ { t - 1 } } βt=1αt1αt

在实践中,将 β t \beta_t βt 裁剪为不大于 0.999,以防止在扩散过程结束时接近 $t = T $的奇点。

在这里插入图片描述

提出的余弦时间表被设计为在过程中具有 α t ˉ \bar { \alpha _ { t } } αtˉ的线性下降,同时在$ t = 0 $和 t = T t = T t=T 的极端附近变化很小,以防止噪声水平的突然变化。图 5 显示了两个计划的 α α α进展情况。可以看到,线性时间表以更快的速度趋向于零,破坏信息的速度比必要的要快得多。使用较小的偏移量 s s s 来防止 β t β_t βt 在$ t = 0 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。 ∗ ∗ 特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。**特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。特别是,选择了 s ,使得 ,使得 ,使得\sqrt { \beta _ { 0 } }$略小于像素箱大小 1 / 127.5 1/127.5 1/127.5,因此 s = 0.008 s = 0.008 s=0.008。我们特别选择使用 c o s 2 cos^2 cos2,因为它是一个具有我们正在寻找的形状的通用数学函数。这种选择是任意的,我们预计许多其他具有类似形状的函数也可以使用。**

Reducing Gradient Noise

在这里插入图片描述

在这里插入图片描述

我们希望通过直接优化 L v l b L_{vlb} Lvlb 而不是优化 L h y b r i d L_{hybrid} Lhybrid 来实现最佳的对数似然。然而, L v l b L_{vlb} Lvlb在实践中实际上很难优化,至少在多样化的 ImageNet 64×64 数据集上是这样。图 6 显示了 $L_{vlb} $和 L h y b r i d L{hybrid} Lhybrid 的学习曲线。两条曲线都是嘈杂的,但在训练时间相同的情况下,混合目标显然在训练集上实现了更好的对数似然。通过评估使用两个目标训练的模型的梯度噪声标度证实了 L v l b L_{vlb} Lvlb 的梯度比 L h y b r i d L_{hybrid} Lhybrid 的梯度大得多,如图7所示。因此,我们寻找一种方法来减少 L v l b L_{vlb} Lvlb 的方差,以便直接优化对数似然性。注意到 L v l b L_{vlb} Lvlb的不同项具有很大差异的幅度(图 2),假设采样$ t $在 $L_{vlb} $中均匀地产生不必要的噪声。为了解决这个问题,采用了重要性抽样:

L v l b = E t ∼ p t [ L t p t ] , w h e r e p t ∝ E [ L t 2 ] a n d ∑ p t = 1 L_{vlb} = E_{ t \sim p_{t} } \left[ \frac { L_{t} } { p_{t} } \right] , where p_{t} \propto \sqrt { E \left[ L_{t} ^ {2} \right] } and \sum p_{t} = 1 Lvlb=Etpt[ptLt],whereptE[Lt2] andpt=1

由于 E [ L t 2 ] E \left[ L _ { t } ^ { 2 } \right] E[Lt2] 事先是未知的,并且可能在整个训练过程中发生变化,因此我们维护每个损失项的前 10 个值的历史记录,并在训练期间动态更新。在训练开始时,均匀地采样 t t t,直到为每个 $t ∈ [0, T −1] $抽取 10 个样本。有了这个重要性抽样目标,就能够通过优化 L v l b L_{vlb} Lvlb 来实现最佳的对数似然。如图 6 所示,即 L v l b L_{vlb} Lvlb(重采样)曲线。该图还显示,重要性采样物镜的噪声比原始的均匀采样要小得多。可以发现,在直接优化噪声较小的L_{{hybrid}时,重要性采样技术没有帮助。

Improving Sampling Speed

在这里插入图片描述

为了减少从 T T T K K K 的采样步骤数,使用$ K$ 个介于 1 1 1 T T T(含)之间的均匀分布的实数,然后将每个结果数字四舍五入到最接近的整数。在图 8 中,评估了使用 4000 扩散步骤,使用 25、50、100、200、400、1000 和 4000 个采样步骤训练的$ L_{hybrid}$ 模型和 L s i m p l e L_{simple} Lsimple 模型的 FID。.我们既针对训练有素的检查点,也针对培训中途的检查点。对于 CIFAR-10,使用了 200K 和 500K 的训练迭代,对于 ImageNet 64,使用了 500K 和 1500K 的训练迭代。可以发现,当使用较少的采样步骤时,具有固定sigmas的 L s i m p l e L_{simple} Lsimple 模型在样本质量方面受到的影响要大得多,而学习sigmas的 L h y b r i d L_{hybrid} Lhybrid模型保持了较高的样本质量。使用此模型,100 个采样步骤足以为完全训练的模型实现近乎最佳的 FID。

Scaling Model Size

在这里插入图片描述

为了衡量性能如何通过训练计算进行扩展,我们在 ImageNet 64 × 64 上训练了四个不同的模型,并使用 L h y b r i d L_{hybrid} Lhybrid 目标。为了改变模型容量,在所有层上应用深度乘法器,使得第一层有 64、96、128 或 192 个通道。请注意,之前的实验在第一层中使用了 128 个通道。由于每一层的深度都会影响初始权重的规模,因此将每个模型的Adam学习率按 1 / c h a n n e l m u l t i p l i e r 1 / \sqrt{channel multiplier} 1/channelmultiplier 缩放,因此128通道模型的学习率为0.0001。图 10 显示了 FID 和 NLL 相对于理论训练计算的改进情况。FID 曲线在对数-对数图上看起来近似线性,表明 FID 根据幂律(绘制为黑色虚线)进行缩放。NLL曲线不能完全拟合幂律,这表明验证NLL的扩展方式不如FID。这可能是由多种因素引起的,例如 1) 这种类型的扩散模型出乎意料的高不可约损失,或 2) 模型过度拟合到训练分布。还注意到,这些模型通常无法实现最佳对数似然,因为它们是使用 L h y b r i d L_{hybrid} Lhybrid 而不是直接使用 L v l b L_{vlb} Lvlb 进行训练的,以保持良好的对数似然性和样本质量。

实验

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考文献

[1] Ho, J., Jain, A., and Abbeel, P. Denoising diffusion probabilistic models, 2020.
[2] Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[3] Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved techniques for training gans, 2016.
[4] Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., and Amodei, D. Scaling laws for neural language models, 2020.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/749071.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kotlin 中List,Set,Map的创建与使用

目录 1. List 的使用 1.1 不可变 List 1.2 可变 List 2. Set 的使用 2.1 不可变 Set 2.2 可变 Set 3. Map 的使用 3.1 不可变Map 3.2 可变Map 本篇主要为已经有Java基础的同学展示Kotlin语言中的List,Set,Map的创建和使用,所以Java代…

CentOS安装MySQL详细教程

1.下载 MySQL yum包 wget http://repo.mysql.com/mysql57-community-release-el7-10.noarch.rpm 2.安装MySQL源 rpm -Uvh mysql57-community-release-el7-10.noarch.rpm 3.安装MySQL服务端 yum install -y mysql-community-server 4.启动MySQL systemctl start mysqld.service …

flink重温笔记(十八): flinkSQL 顶层 API ——时态表实现表数据动态变化(涵盖全面实用的 API )

Flink学习笔记 前言:今天是学习 flink 的第 18 天啦!很多小伙伴私信说,自己只会SQL语法来编写flinkSQL,如何使用代码来操作呢?因为工作中都是要用到代码编写的。还有小伙伴说,想要实现表是动态变化的&#…

【小白刷leetcode】第15题

【小白刷leetcode】第15题 动手刷leetcode,正在准备蓝桥,但是本人算法能力一直是硬伤。。。所以做得一直很痛苦。但是不熟练的事情像练吉他一样,就需要慢速,多练。 题目描述 看这个题目,说实在看的不是很懂。索性我们直…

uniapp 对video视频组件嵌套倍速按钮

这次接了需求是要求有倍速功能,去看了文档发现并没有倍速按钮的属性,想着手写一个吧 可最后发现原生层级太高,无论怎么样都迭不上去,就只能去找插件看看咯 找了好多插件发现都不可用,因为我这是app端,有些视…

mysql笔记:15. 事务和锁

文章目录 一、事务概述二、事务基本操作三、事务保存点四、事务的隔离级别1. READ UNCOMMITTED设置事务的隔离级别 2. READ COMMITTED3. REPEATABLE READ4. SERIALIZABLE 五、MySQL的锁InnoDB的锁类型1. InnoDB的行级锁2. InnoDB的表级锁 死锁 在开发过程中,我们经常…

配置服务器SSH

在终端中,运行以下命令以检查SSH服务器的状态: sudo service ssh status安装SSH服务器。您可以运行以下命令来安装OpenSSH服务器,这是SSH服务的一个流行实现: sudo apt install openssh-server如果SSH服务器正在运行&#xff0c…

Acwing100 --- 增减序列(差分)

给定一个长度为 n 的数列 a1,a2,…,an,每次可以选择一个区间 [l,r],使下标在这个区间内的数都加一或者都减一。 求至少需要多少次操作才能使数列中的所有数都一样,并求出在保证最少次数的前提下,最终得到的数列可能有多少种。 输入…

记录些实际应用开发过程中的prompt

Text2SQL 假设你是{dbType}的专家,需要通过问题描述和指令语句两部分内容帮忙生成对应查询SQL语句。第一部分问题说明: {queryContent} 第二部分指令内容: 1,不能幻觉出现新的字段,schema字段、表名称、表字段名称必须…

Vue组件中引入jQuery

两种在vue中引入jQuery的方式 1、普通html中使用jQuery 将jQuer的文件导入到项目中&#xff0c;然后直接使用<script src"jQuery.js"></script>即可。 <script src"jQuery.js"></script> 2、vue组件中使用jQuery 安装依赖 c…

KGCN---pytorch代码(2)---aggregator

代码&#xff1a; import torch import torch.nn.functional as Fclass Aggregator(torch.nn.Module):Aggregator classMode in [sum, concat, neighbor]#最后一个 neighbor 的聚合器直接就是利用邻域表示来代替 v 结点的表示def __init__(self, batch_size, dim, aggregator)…

vue组件基础及注册

1、组件的命名 kebab-case&#xff08;短横线&#xff09;命名法&#xff1a;字母全小写且必须包含一个连字符&#xff1b;例&#xff1a;my-component-namePascalCase&#xff08;帕斯卡&#xff09;命名法&#xff1a;首字符大写&#xff1b;例&#xff1a;MyComponentName …

C语言数据结构基础笔记——树、二叉树简介

1.树 树是一种 非线性 的数据结构&#xff0c;它是由 n &#xff08; n>0 &#xff09;个有限结点组成一个具有层次关系的集合。 把它叫做树是因 为它看起来像一棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的。 &#xff08;图片来源于网络&#xff09;…

【OJ】string类题目

个人主页 &#xff1a; zxctscl 如有转载请先通知 题目 1. 415字符串相加1.1 分析1.2 代码 2. 344反转字符串2.1 分析2.2 代码 3. HJ1字符串最后一个单词的长度3.1 分析3.2 代码 4. 387.字符串中的第一个唯一字符4.1 分析4.2 代码 5. 125验证回文串5.1 分析5.2 代码 1. 415字符…

【python小技能】使用Python发送电子邮件的完整指南(适合零基础)

前言 在现代通信中&#xff0c;电子邮件是一种不可或缺的工具。使用Python编程语言&#xff0c;我们可以轻松地编写代码来发送电子邮件。本文将为零基础的读者提供一个完整的指南&#xff0c;教你如何使用Python发送电子邮件 安装库 首先&#xff0c;我们需要安装smtplib库。…

wordpress被恶意搜索攻击(网址/?s=****)解决方法。

源地址&#xff1a;https://www.ctvol.com/seoomethods/1413686.html 什么叫恶意搜索攻击&#xff1f; wordpress恶意搜索攻击并不是像病毒一样的攻击&#xff0c;而是一种seo分支黑帽手段&#xff0c;通过被攻击网站搜索功能中长尾关键词来实现攻击&#xff0c;通过网址不断…

Clickhouse MergeTree原理(二)—— 表和分区的维护

作者&#xff1a;俊达 引言 MergeTree是Clickhouse中最核心的存储引擎。上一篇文章中&#xff0c;我们介绍了MergeTree的基本结构。 1、MergeTree由分区&#xff08;partiton&#xff09;和part组成。 2、Part是MergeTree可操作的基本数据单元。 当插入数据时&#xff0c;会…

MySQL 中的“两阶段提交”机制

在MySQL数据库中&#xff0c;为了确保redo log&#xff08;重做日志&#xff09;和binlog&#xff08;二进制日志&#xff09;之间的数据安全性和一致性&#xff0c;引入了“两阶段提交”这一重要概念。MySQL将redo log的写入过程细分为“prepare”和“commit”两个步骤&#x…

【LeetCode热题100】146. LRU 缓存(链表)

一.题目要求 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;则返回关键字的值&#xff0c…

Jenkins插件Parameterized Scheduler用法

Jenkins定时触发构建的同时设定参数。可以根据不同的定时构建器设置不同参数或环境变量的值。可以设置多个参数。并结合when控制stage流程的执行。结合when和triggeredBy区分定时构建的stage和手动执行的stage。 目录 什么是Parameterized Scheduler&#xff1f;如何配置实现呢…