第五章 深度学习
十、生成对抗网络(GAN)
1. 图像生成技术概述
1.1 什么是图像生成技术
图像生成技术是指利用机器学习或深度学习等人工智能技术,通过训练模型来生成逼真的图像。这些技术可以根据给定的输入,生成与真实图像相似的、具有一定创造性的图像。近几年,深度学习图像生成技术取得了重大进展,引发了一股AIGC热潮。
1.2 生成模型的原理
在概率统计理论中,生成模型是指能够在给定某些隐含参数的条件下,随机生成观测数据的模型,它给观测值和标注数据序列指定一个联合概率分布。对于生成模型来说,可以分为两个类型:第一种类型的生成模型可以完全表示出数据确切的分布函数;第二种类型的生成模型只能做到新数据的生成,而数据分布函数则是模糊的,深度学习生成模型大多为后者。所以,从概率角度来讲,生成模型就是生成尽可能接近真实数据分布的假数据。
1.3 生成模型的应用
图像生成技术主要用于以下几个方面:
1)艺术创作:图像生成技术可以用于艺术创作,例如生成艺术作品、绘画、插图等。通过训练模型,可以生成具有艺术风格的图像,帮助艺术家创作出独特的作品。
2)视觉效果:图像生成技术在电影、游戏和虚拟现实等领域有广泛应用。通过生成逼真的图像,可以创建出想象力丰富的虚拟场景、特效和角色,提升视觉效果的真实感和沉浸感。
3)数据增强:在机器学习和深度学习中,数据的质量和数量对模型的性能至关重要。图像生成技术可以用于生成合成数据,扩充训练集,提高模型的泛化能力和鲁棒性。合成数据有以下优点:
- 合成训练数据比获取真实世界的样本更容易、更快、更便宜
- 某种情况下,合成数据增强可以提高AI系统的性能
- 可以在医学成像或医疗记录等敏感应用中保护隐私
- 最重要一点,随着深度学习模型参数越来越庞大,现几乎没有真实数据可用了
4)图像修复和增强:图像生成技术可以用于修复和增强图像的质量。例如,可以通过生成缺失的图像部分、去除噪声、增强细节等方式,改善图像的视觉效果。
5)虚拟现实和增强现实:图像生成技术可以用于创建虚拟现实和增强现实的场景和对象。通过生成逼真的虚拟图像,可以提供更加沉浸式和交互式的虚拟体验。
1.4 生成模型主要技术路线
目前,深度学习领域图像生成主要有以下几种技术路线:GAN(生成对抗网络)、VAE(变分自动编码器)、Flow-based、Diffusion(扩散)模型。
- GAN:包括一个生成器(Generator)和判别器(Discriminator),生成器负责生成数据,判别器负责对数据进行判断真假。生成器尽可能生成接近真实分布的数据,骗过判别器;判别器尽可能把帧数据和生成的数据识别出来。这样,生成器、判别器就形成一个对抗关系,从而提升两个模型的能力;
- VAE:通过编码器(Encoder)生成一个隐含编码(latent code),解码器(Decoder)根据隐含编码生成数据,生成的数据尽可能把原数据重建出来。VAE在普通的自编码器上加入了一些限制,要求产生的隐含编码能够遵循高斯分布,这个限制帮助自编码器真正读懂训练数据的潜在规律,让自编码器能够学习到输入数据的隐含变量模型。
- Flow-based模型:Flow-based模型采用一种比较独特的方法,它选择直接直面生成模型的概率计算,把分布转换为积分式 P G ( x ) = ∫ z p ( x ∣ z ) p ( z ) d z P_G(x)=\int_z p(x|z)p(z)dz PG(x)=∫zp(x∣z)p(z)dz计算出来。
- Diffusion模型:灵感来自物理学的非平衡热力学理论,输入原始图像,逐步随机噪声添加到数据中,然后学习逆向扩散过程以从噪声中构造所需的数据样本。
1.5 图像生成模型的评估指标
1.5.1 Inception Score
Inception Score(IS)用于评估生成图像的多样性和真实性。它基于Inception网络的输出,通过计算生成图像的条件分布和边际分布之间的KL散度来度量生成图像的多样性。较高的Inception分数表示生成图像具有更好的多样性和真实性。其表达式为:
I S ( G ) = e x p ( 1 N ∑ i = 1 N K L ( p ( y ∣ x i ) ∣ ∣ p ( y ) ) ) IS(G) = exp(\frac{1}{N} \sum_{i=1} ^ N KL(p(y|x_i) || p(y))) IS(G)=exp(N1i=1∑NKL(p(y∣xi)∣∣p(y)))
其中,KL表示KL散度,用来衡量两个分布之间的差异; p ( y ∣ x ) p(y|x) p(y∣x)表示对于图片 x x x,属于所有类别的概率分布. 如果是在ImageNet上预训练得到Inception,对于图像x,表示为一个1000维的向量。
1.5.2 FID
FID是Frechet Inception Distance score的简写,用于评估生成图像的质量和多样性。它基于Frechet 距离,通过比较生成图像的特征分布和真实图像的特征分布之间的距离来度量生成图像的质量。FID分数越低,表示生成图像与真实图像的分布越接近,生成图像的质量越好。表达式:
F I D = ∣ ∣ μ r − μ g ∣ ∣ 2 + T r ( Σ r + Σ g − 2 Σ r Σ g ) FID = ||\mu_r - \mu_g|| ^2 + T_r(\Sigma_r + \Sigma_g - 2 \sqrt{\Sigma_r \Sigma_g}) FID=∣∣μr−μg∣∣2+Tr(Σr+Σg−2ΣrΣg)
其中, μ r \mu_r μr表示真实图片的特征均值; μ g \mu_g μg表示生成图片的特征均值; Σ r \Sigma_r Σr为真实图像的协方差矩阵; Σ g \Sigma_g Σg为生成图片的协方差矩阵。
2. GAN模型概述
生成对抗网络(Generative Adversarial Networks)是一种无监督深度学习模型,用来通过计算机生成数据,由Ian J. Goodfellow等人于2014年提出。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。生成对抗网络被认为是当前最具前景、最具活跃度的模型之一,目前主要应用于样本数据生成、图像生成、图像修复、图像转换、文本生成等方向。
GAN这种全新的技术在生成方向上带给了人工智能领域全新的突破。在之后的几年中生GAN成为深度学习领域中的研究热点,近几年与GAN有关的论文数量也急速上升,目前数量仍然在持续增加中。
2018年,对抗式神经网络的思想被《麻省理工科技评论》评选为2018年“全球十大突破性技术”(10 Breakthrough Technologies)之一。 Yann LeCun(“深度学习三巨头”之一,纽约大学教授,前Facebook首席人工智能科学家)称赞生成对抗网络是“过去20年中深度学习领域最酷的思想”,而在国内被大家熟知的前百度首席科学家Andrew Ng也把生成对抗网络看作“深度学习领域中一项非常重大的进步”。
3. GAN基本原理
3.1 构成
GAN由两个重要的部分构成:生成器(Generator,简写作G)和判别器(Discriminator,简写作D)。
- 生成器:通过机器生成数据,目的是尽可能“骗过”判别器,生成的数据记做G(z);
- 判别器:判断数据是真实数据还是「生成器」生成的数据,目的是尽可能找出「生成器」造的“假数据”。它的输入参数是x,x代表数据,输出D(x)代表x为真实数据的概率,如果为1,就代表100%是真实的数据,而输出为0,就代表不可能是真实的数据。
这样,G和D构成了一个动态对抗(或博弈过程),随着训练(对抗)的进行,G生成的数据越来越接近真实数据,D鉴别数据的水平越来越高。在理想的状态下,G可以生成足以“以假乱真”的数据;而对于D来说,它难以判定生成器生成的数据究竟是不是真实的,因此D(G(z)) = 0.5。训练完成后,我们得到了一个生成模型G,它可以用来生成以假乱真的数据。
3.2 训练过程
- 第一阶段:固定「判别器D」,训练「生成器G」。使用一个性能不错的判别器,G不断生成“假数据”,然后给这个D去判断。开始时候,G还很弱,所以很容易被判别出来。但随着训练不断进行,G技能不断提升,最终骗过了D。这个时候,D基本属于“瞎猜”的状态,判断是否为假数据的概率为50%。
- 第二阶段:固定「生成器G」,训练「判别器D」。当通过了第一阶段,继续训练G就没有意义了。这时候我们固定G,然后开始训练D。通过不断训练,D提高了自己的鉴别能力,最终他可以准确判断出假数据。
- 重复第一阶段、第二阶段。通过不断的循环,「生成器G」和「判别器D」的能力都越来越强。最终我们得到了一个效果非常好的「生成器G」,就可以用它来生成数据。
3.3 GAN的优缺点
3.3.1 优点
- 能更好建模数据分布(图像更锐利、清晰);
- 理论上,GANs 能训练任何一种生成器网络。其他的框架需要生成器网络有一些特定的函数形式,比如输出层是高斯的;
- 无需利用马尔科夫链反复采样,无需在学习过程中进行推断,没有复杂的变分下界,避开近似计算棘手的概率的难题。
3.3.2 缺点
- 模型难以收敛,不稳定。生成器和判别器之间需要很好的同步,但是在实际训练中很容易D收敛,G发散。D/G 的训练需要精心的设计。
- 模式缺失(Mode Collapse)问题。GANs的学习过程可能出现模式缺失,生成器开始退化,总是生成同样的样本点,无法继续学习。
3.4 GAN的应用
3.4.1 生成数据集
人工智能的训练是需要大量的数据集,可以通过GAN自动生成低成本的数据集。