1、前言
本篇文章,我们讲NCSN,也就是噪声条件分数网络。这是宋飏老师在2019年提出的模型,思路与传统的生成模型大不相同,令人拍案叫绝!!!
参考论文:
①Generative Modeling by Estimating Gradients of the Data Distribution (arxiv.org)
②Tutorial on Diffusion Models for Imaging and Vision (arxiv.org)
参考代码:GitHub - Lingyu-Kong/ncsn: Handwritten Score-Based Generative Model
视频:[噪声条件得分(分数)网络——NCSN原理解析-哔哩哔哩]
Ps:这篇文章我简单讲一下思路就算了,过程并不严谨,因为这个内容并不是很重要
2、引入
回忆一下梯度下降,假设我们有一个二次函数
f ( x ) = ( 0.5 x − 3 ) 2 f(x)=(0.5x-3)^2 f(x)=(0.5x−3)2
导数为 f ′ ( x ) = ( 0.5 x − 3 ) f'(x)=(0.5x-3) f′(x)=(0.5x−3),使用梯度下降
x t + 1 = x t − 0.1 f ′ ( x t ) (1) x_{t+1}=x_t-0.1f'(x_t)\tag{1} xt+1=xt−0.1f′(xt)(1)
其中 x t 、 x t + 1 x_t、x_{t+1} xt、xt+1表示优化前和优化后的x对应的值, 0.1 0.1 0.1是步长。初始化蓝色点 x t = − 6 x_t=-6 xt=−6,迭代100轮梯度下降,就可以得到下面的图(可以看到蓝色点逐渐向着函数最低点靠近)
为什么会这样?因为梯度总是指向函数值上升的方向。而Eq.(1),是减去梯度,相当于对梯度取反方向。于是x的值就沿着函数值下降的方向走了。如果换成梯度上升,则Eq.(1)改为
x t + 1 = x t + 0.1 f ′ ( x t ) (2) x_{t+1}=x_t+0.1f'(x_t)\tag{2} xt+1=xt+0.1f′(xt)(2)
对应图像为
再回忆一下一维高斯分布的概率密度的图像
当y值(密度值)取到最高点,其对应样本点在均值处
此时我们注意到,高斯分布的图像,与Eq.(2)何其相像,那我们把Eq.(2)里面的 f ( x ) f(x) f(x)当作是高斯分布的密度函数,而 x x x则对应高斯分布的样本点
x t + 1 = x t + 0.1 f ′ ( x t ) x_{t+1}=x_t+0.1f'(x_t) xt+1=xt+0.1f′(xt)
那么这个梯度上升的意思就变成了,对于一个样本 x t x_t xt,不断往概率密度函数 f ′ ( x t ) f'(x_t) f′(xt)密度值高的地方靠近。如果优化到最优点,那么图像就会变成这样
也就是说,样本点 x t x_t xt,最终会走到概率值最高对应的点,那么此时的样本点 x t x_t xt,就可以认为是从高斯分布中采样出来的一个概率最高的样本。我们写成概率分布的一般形式
x t + 1 = x t + α ∇ x P ( x t ) x_{t+1}=x_t+\alpha \nabla_xP(x_t) xt+1=xt+α∇xP(xt)
α \alpha α表示步长,比如之前的0.1, ∇ x \nabla_x ∇x是对x求梯度。
我们在 P ( x t ) P(x_t) P(xt)前面取一个log对数,不改变单调性,仍然会使 x t x_t xt收敛到最优值
x t + 1 = x t + α ∇ x log P ( x t ) x_{t+1}=x_t+\alpha \nabla_x\log P(x_t) xt+1=xt+α∇xlogP(xt)
更一般的,从一个概率分布中采样,我们往往会存在一些偏差项,于是我们加上一个随机噪声
x t + 1 = x t + α ∇ x log P ( x t ) + 2 α z t (3) x_{t+1}=x_t+\alpha \nabla_x\log P(x_t)+\sqrt{2\alpha}z_t\tag{3} xt+1=xt+α∇xlogP(xt)+2αzt(3)
2 α \sqrt{2\alpha} 2α是缩放系数,而 z t z_t zt是标准高斯分布,加上一个噪声后, x t x_t xt的收敛值会在概率最高点处不断徘徊
图像表示为
现在,我们更进一步,我们把 x t x_t xt当作是一个随机初始化的图像,然后 P ( x ) P(x) P(x)是我们训练图像的所对应的分布,通过不断执行Eq.(3),便可以让随机初始化的图像,不断往 P ( x ) P(x) P(x)概率最高点周围靠近,那么就间接说明,经过了大T步Eq.(3),得到的 x t x_t xt,可以认为是从 P ( x ) P(x) P(x)中采样出来的。
仔细看一下,这不就是一个生成图像的过程吗?
这种方式,又被称为郎之万动力采样。emmmmm,不懂,物理学的东西。。。
我们看一个可视化的过程(图像来自参考①)
3、目标函数
既然Eq.(3)能够通过迭代的方式,生成图像,那自然只需要求解Eq.(3)就可以了。不幸的是,我们没办法求解
我们的训练图像,它们所服从的概率分布往往及其复杂,也就是说 P ( x ) P(x) P(x)是难以求解的,好在我们的目标并不是求出 P ( x ) P(x) P(x),而是对应的梯度(也称为分数)
L = 1 2 E P d a t a ( x ) [ ∣ ∣ s θ ( x ) − ∇ x log P d a t a ( x ) ∣ ∣ 2 2 ] (4) L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x)}\left[||s_\theta(x)-\nabla_x\log P_{data}(x)||_2^2\right]\tag{4} L=21EPdata(x)[∣∣sθ(x)−∇xlogPdata(x)∣∣22](4)
P d a t a P_{data} Pdata表示训练数据所服从的分布
也就是通过最小化上式,便可得到 s θ ( x ) ≈ ∇ x log P d a t a ( x ) s_\theta(x)\approx \nabla_x\log P_{data}(x) sθ(x)≈∇xlogPdata(x)。
4、问题
理论上,我们直接求解Eq.(4)就可以了,但是,我们样本所服从的分布往往是服从,概率分布中往往存在一些低密度区域,那么对应的样本就很少。
而样本少,意味着对应为止的梯度分数,得不到很好的训练,那么神经网络在那些样本点就很容易估不准。作者博客给出了一张很形象的图像(图像来自参考①)
可以看到,数据的密度分别都在左下角和右上角,那么这些区域就能够用神经网络得到很好的拟合,对应Accurate区域。相反,低密度区域,没有得到很好的拟合,对应Inaccurate区域。
当我们使用郎之万动力采样的时候,随机初始化一个 x 0 x_0 x0,它落在低密度区域的概率非常之高。而低密度的区域没有经过很好的训练,所以郎之万动力采样在短时间内很难得到较好的结果。
那么,该如何解决这个问题呢?一个很好的方法就是——加噪声
我们通过对图像加入随机扰动噪声,会填充原本的低密度区域,从而让整个区域看起来较为的均匀(图像来自参考①)
也就是这样,让原本的密度点扩张开来。
加噪的过程我们可以表示为 x ~ = x + σ z \tilde x=x+\sigma z x~=x+σz。 x x x表示原始图像, x ~ \tilde x x~表示加噪后的图像。
我们用 q ( x ~ ∣ x ) ∼ N ( x , σ 2 I ) q(\tilde x|x)\sim N(x,\sigma^2I) q(x~∣x)∼N(x,σ2I)去表示这个加噪过程
于是Eq.(3)就可以变成
L = 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ s θ ( x + σ z ) − ∇ x ~ log q ( x ~ ∣ x ) ∣ ∣ 2 2 ] (5) L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)-\nabla_{\tilde x}\log q(\tilde x|x)||_2^2\right]\tag{5} L=21EPdata(x),x~∼N(x,σ2I)[∣∣sθ(x+σz)−∇x~logq(x~∣x)∣∣22](5)
emmmm,我感觉这样讲貌似挺合理的,但是它是需要证明的,也就是证明Eq.(4)、Eq.(5)的优化等价性。我就不证明了,证明过程在参考论文②,并不难,读者自己看一下就知道了
除此之外,真正导致需要加噪的,其实有其他原因,我只讲了其中一个。其他原因请看参考②,里面讲的非常之详细。我也懒得写了
现在,我们预测的是加噪后的梯度分数,通过加噪的过程,也避免了直接求解 P ( x ) P(x) P(x)的问题。那我们来看一下这个等式可以变成什么吧
如果我们加的噪声足够小,那么 P d a t a ( x ) ≈ q ( x ~ ∣ x ) P_{data}(x)\approx q(\tilde x|x) Pdata(x)≈q(x~∣x)
因为 q ( x ~ ∣ x ) q(\tilde x|x) q(x~∣x)是服从高斯分布的,是完全可以求出来的,所以梯度为
∇ x ~ log q ( x ~ ∣ x ) = ∇ x ~ log 1 2 π σ 2 d exp { − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 } = ∇ x ~ ( log 1 2 π σ 2 d − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 ) = − 2 ( x ~ − x ) 2 σ 2 = − x ~ − x σ 2 = − z σ \begin{aligned}\nabla_{\tilde x}\log q(\tilde x|x)=&\nabla_{\tilde x}\log \frac{1}{\sqrt{2\pi\sigma^2}^d}\exp \left\{-\frac{||\tilde x-x||^2}{2\sigma^2}\right\}\\=&\nabla_{\tilde x}\left(\log \frac{1}{\sqrt{2\pi\sigma^2}^d}-\frac{||\tilde x-x||^2}{2\sigma^2}\right)\\=&-\frac{2(\tilde x-x)}{2\sigma^2}\\=&-\frac{\tilde x -x}{\sigma^2}\\=&-\frac{z}{\sigma}\end{aligned} ∇x~logq(x~∣x)=====∇x~log2πσ2d1exp{−2σ2∣∣x~−x∣∣2}∇x~(log2πσ2d1−2σ2∣∣x~−x∣∣2)−2σ22(x~−x)−σ2x~−x−σz
所以损失函数就可以变成
L = 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ s θ ( x + σ z ) + x ~ − x σ 2 ∣ ∣ 2 2 ] L=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)+\frac{\tilde x -x}{\sigma^2}||_2^2\right] L=21EPdata(x),x~∼N(x,σ2I)[∣∣sθ(x+σz)+σ2x~−x∣∣22]
按理说,我们只需要最优化这个目标函数即可。
可问题又来了
我们该如何加入噪声呢?加多少?加的小了,低密度区域没有得到很好的填充。加多了,直接改变原本的数据分布了,这显然也不行。
我们干脆一不做二不休,我们加多个量级噪声,不同量级都进行训练。
当训练完成之后,就得到了不同噪声强度的噪声条件得分网络。
假设不同强度等级的噪声有S个, { σ i } i = 1 S \{\sigma_i\}_{i=1}^S {σi}i=1S,我们看一张图(里面显示了三个噪声强度的情况,图像来自参考①)
那么进行采样的时候,就可以从高强度的噪声,进行郎之万动力采样,然后慢慢降低噪声的强度。总而言之,就是每个噪声强度,都进行一轮郎之万动力采样,比如下图(图像来自参考①)(Gif图像太大,上传不了…看视频里面吧)
假设有S个噪声强度,那么就可以变成
L = 1 S ∑ i = 1 S λ i 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ i 2 I ) [ ∣ ∣ s θ ( x + σ i z , σ i ) + x ~ i − x σ i 2 ∣ ∣ 2 2 ] L=\frac{1}{S}\sum\limits_{i=1}^S\lambda_i\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma_i^2I)}\left[||s_\theta(x+\sigma_i z,\sigma_i)+\frac{\tilde x_i -x}{\sigma_i^2}||_2^2\right] L=S1i=1∑Sλi21EPdata(x),x~∼N(x,σi2I)[∣∣sθ(x+σiz,σi)+σi2x~i−x∣∣22]
x ~ i \tilde x_i x~i表示在噪声强度为 σ i \sigma_i σi的加噪图像。 λ i \lambda_i λi代表的是一个加权系数.一般情况下,我们取 λ i = σ i 2 \lambda_i=\sigma^2_i λi=σi2。
对于噪声强度数量,一般是数百到数千;噪声强度选择一般采用几何级数。
采样的时候正如前面所说,先在高强度噪声量级进行郎之万动力采样,而后慢慢降低,所以采样方法为
5、结束
好了,本篇文章到此为止,如有问题,还望指出,阿里嘎多!!!
6、参考
①Generative Modeling by Estimating Gradients of the Data Distribution | Yang Song (yang-song.net)
②基于分数的生成模型(Score-based generative models) — 张振虎的博客 张振虎 文档 (zhangzhenhu.com)