1 论文简介
论文题目: U-gat-it: Unsupervised generative attentional networks with adaptive layer-instance normalization for image-to-image translation
论文代码:https://github.com/taki0112/UGATIT
论文数据集:https://github.com/znxlwm/UGATIT-pytorch
本文以倒序的方式来介绍这篇论文,首先看效果,然后分析其原理。
2 效果
Figure 2: Visualization of the attention maps and their effects shown in the ablation experiments: (a) Source images, (b) Attention map of the generator, (c-d) Local and global attention maps of the discriminator, respectively. (e) Our results with CAM, (f) Results without CAM.
Figure 3: Comparison of the results using each normalization function: (a) Source images, (b) Our results, © Results only using IN in decoder with CAM, (d) Results only using LN in decoder with CAM, (e) Results only using AdaIN in decoder with CAM, (f) Results only using GN in decoder with CAM.
3 基本框架
本文提出了一种新的无监督图像到图像转换方法,以端到端的方式结合新的注意力模块和新的可学习归一化函数。
- 注意力模块根据辅助分类器获得的注意力图,引导模型专注于区分源域和目标域的更重要的区域(帮助模型知道在哪里进行密集转换)。 与之前无法处理域之间几何变化的基于注意力的方法不同,本文的模型可以转换需要整体变化的图像和需要大形状变化的图像。
- AdaLIN 函数帮助注意力模型灵活控制形状和纹理的变化量,而无需修改模型架构或超参数。
- 实验结果表明,与具有固定网络架构和超参数的现有最先进模型相比,所提出的方法具有优越性。
模型分为生成器和判别器两部分,结构几乎一致。生成器比判别器多了AdaLIN算法实现的Decoder模块。
图1描述了网络结构,以生成器为例,输入图像通过Encoder编码阶段(下采样+残差模块)得到特征图,然后添加一个辅助分类引入Attention机制通过特征图的最大池化,经过全连接层输出一个节点的预测,然后将这个全连接层的参数和特征图相乘从而得到Attention的特征图。最后经过Decoder模块得到输出图像。
Figure 1: The model architecture of U-GAT-IT. The detailed notations are described in Section Model
本文的目标是训练一个函数 Gs→tG_{s \rightarrow t}Gs→t,该函数使用从每个域中抽取未配对的样本将图像从源域XsX_sXs 映射到目标域 XtX_tXt:
- 该框架由两个生成器 Gs→tG_{s \rightarrow t}Gs→t和 Gt→sG_{t \rightarrow s}Gt→s 以及两个鉴别器 DsD_sDs 和DtD_tDt 组成;
- 将注意力模块集成到生成器和鉴别器中;
- 判别器中的注意力模块引导生成器关注对生成逼真图像至关重要的区域;
- 生成器中的注意力模块关注与其他域不同的区域(判别器注意力模块已经引导生成器聚焦了一个域,那么生成器的注意力模块则聚焦其它的域)。
3.1 生成器
在这里,我们只解释Gs→tG_{s \rightarrow t}Gs→t和 DtD_tDt(见图 1),反之亦然。
符号说明:
x∈{Xs,Xt}x \in\left\{X_{s}, X_{t}\right\}x∈{Xs,Xt}:来自源域和目标域的样本;
Gs→tG_{s \rightarrow t}Gs→t:包括一个编码器EsE_sEs,一个解码器GtG_tGt,和一个辅助分类器ηs\eta_sηs;
ηs(x)\eta_s(x)ηs(x):表示xxx来自XsX_sXs的概率;
Esk(x)E_{s}^{k}(x)Esk(x):编码器的第 kkk 个激活映射(map);
Eskij(x)E_{s}^{k_{i j}}(x)Eskij(x):在(i,j)(i, j)(i,j)上的值;
wskw_s^kwsk:通过使用全局平均池化和全局最大池化训练辅助分类器以学习源域的第kkk 个特征图的权重,例如:ηs(x)=σ(ΣkwskΣijEskij(x))\eta_{s}(x)=\sigma\left(\Sigma_{k} w_{s}^{k} \Sigma_{i j} E_{s}^{k_{i j}}(x)\right)ηs(x)=σ(ΣkwskΣijEskij(x));
利用 wskw_s^kwsk,可以计算一组特定领域的注意力特征图:
as(x)=ws∗Es(x)={wsk∗Esk(x)∣1≤k≤n}a_{s}(x)=w_{s} * E_{s}(x)=\left\{w_{s}^{k} * E_{s}^{k}(x) \mid 1 \leq k \leq n\right\}as(x)=ws∗Es(x)={wsk∗Esk(x)∣1≤k≤n}。
nnn:编码特征图的数量。
AdaLIN(a,γ,β)=γ⋅(ρ⋅aI^+(1−ρ)⋅aL^)+β,aI^=a−μIσI2+ϵ,aL^=a−μLσL2+ϵρ←clip[0,1](ρ−τΔρ)(1)\begin{array}{c} \operatorname{AdaLIN}(a, \gamma, \beta)=\gamma \cdot\left(\rho \cdot \hat{a_{I}}+(1-\rho) \cdot \hat{a_{L}}\right)+\beta, \\ \hat{a_{I}}=\frac{a-\mu_{I}}{\sqrt{\sigma_{I}^{2}+\epsilon}}, \hat{a_{L}}=\frac{a-\mu_{L}}{\sqrt{\sigma_{L}^{2}+\epsilon}} \\ \rho \leftarrow \operatorname{clip}_{[0,1]}(\rho-\tau \Delta \rho) \end{array}\tag1 AdaLIN(a,γ,β)=γ⋅(ρ⋅aI^+(1−ρ)⋅aL^)+β,aI^=σI2+ϵa−μI,aL^=σL2+ϵa−μLρ←clip[0,1](ρ−τΔρ)(1)
公式(1)的符号说明:
- γ\gammaγ和β\betaβ由注意力图的全连接层动态计算;
- μI\mu_IμI , μL\mu_LμL 和σI\sigma_IσI, σL\sigma_LσL 分别是通道方式、层方式均值和标准差;
- τ\tauτ为学习速率;
- ΔρΔ \rhoΔρ 表示优化器确定的参数更新向量(如梯度);
- ρ\rhoρ的值被限制在[0,1][0,1][0,1]的范围内,只需在参数更新步骤中设置界限即可;生成器调整该值,以便在实例规范化很重要的任务中ρ\rhoρ的值接近1,而在层归一化(LN)很重要的任务中ρ\rhoρ的值接近0。在解码器的残差块中,ρ\rhoρ的值初始化为1,在解码器的上采样块中,ρ\rhoρ的值初始化为0。
公式(1)中最核心的部分是:
ρ⋅IN+(1−ρ)⋅LNaI^=a−μIσI2+ϵaL^=a−μLσL2+ϵ(2)\begin{array}{c} \rho \cdot IN+(1-\rho) \cdot LN \\ \hat{a_{I}}=\frac{a-\mu_{I}}{\sqrt{\sigma_{I}^{2}+\epsilon}} \\ \hat{a_{L}}=\frac{a-\mu_{L}}{\sqrt{\sigma_{L}^{2}+\epsilon}} \\ \end{array}\tag2 ρ⋅IN+(1−ρ)⋅LNaI^=σI2+ϵa−μIaL^=σL2+ϵa−μL(2)
- 层归一化(Layer Norm,LN):通道(channel)方向做归一化,算CHW(通道、高、宽)的均值,主要对RNN作用明显;更多的考虑输入特征通道之间的相关性,LN比IN风格转换更彻底,但是语义信息保存不足;
- 实例归一化(Instance Norm,IN):一个通道(channel)内做归一化,算H*W的均值,用在风格化迁移;因为在图像风格化中,生成结果主要依赖于某个图像实例,所以对整个batch归一化不适合图像风格化中,因而对HW做归一化。可以加速模型收敛,并且保持每个图像实例之间的独立;更多考虑单个特征通道的内容,IN比LN更好的保存原图像的语义信息,但是风格转换不彻底。
3.2 判别器
3.3 损失函数
模型包括四个损失函数:
- 对抗损失:Llsgans→t=(Ex∼Xt[(Dt(x))2]+Ex∼Xs[(1−Dt(Gs→t(x)))2])L_{l s g a n}^{s \rightarrow t}=\left(\mathbb{E}_{x \sim X_{t}}\left[\left(D_{t}(x)\right)^{2}\right]+\mathbb{E}_{x \sim X_{s}}\left[\left(1-D_{t}\left(G_{s \rightarrow t}(x)\right)\right)^{2}\right]\right)Llsgans→t=(Ex∼Xt[(Dt(x))2]+Ex∼Xs[(1−Dt(Gs→t(x)))2]),保证风格迁移图像的分布与目标图像分布相匹配;
- 循环损失:Lcycle s→t=Ex∼Xs[∥x−Gt→s(Gs→t(x)))∥1]\left.L_{\text {cycle }}^{s \rightarrow t}=\mathrm{E}_{x \sim X_{s}}\left[\| x-G_{t \rightarrow s}\left(G_{s \rightarrow t}(x)\right)\right) \|_{1}\right]Lcycle s→t=Ex∼Xs[∥x−Gt→s(Gs→t(x)))∥1],保证一个图像x∈Xsx \in X_sx∈Xs,在从XsX_sXs到XtX_tXt,XtX_tXt到XsX_sXs一系列转化后,该图像能成功的转化回原始域;
- 一致性损失:Lidentity s→t=Ex∼Xt[∥x−Gs→t(x)∥1]L_{\text {identity }}^{s \rightarrow t}=\mathrm{E}_{x \sim X t}\left[\left\|x-G_{s \rightarrow t}(x)\right\|_{1}\right]Lidentity s→t=Ex∼Xt[∥x−Gs→t(x)∥1],保证输入图像与输出图像的颜色分布相似,给定一个图像x∈Xtx \in X_tx∈Xt,在使用Gs→tG_{s→t}Gs→t翻译之后,图像不应该改变;
- 分类激活映射损失:Lcams→t=−(Ex∼Xs[log(ηs(x))]+Ex∼Xt[log(1−ηs(x))]LcamDt=Ex∼Xt[(ηDt(x))2]+Ex∼Xs[(1−ηDt(Gs→t(x))2]\begin{array}{l} L_{c a m}^{s \rightarrow t}=-\left(\mathrm{E}_{x \sim X_{s}}\left[\log \left(\eta_{s}(x)\right)\right]+\mathrm{E}_{x \sim X_{t}}\left[\log \left(1-\eta_{s}(x)\right)\right]\right. \\ L_{c a m}^{D t}=\mathrm{E}_{x \sim X_{t}}\left[\left(\eta_{D t}(x)\right)^{2}\right]+\mathrm{E}_{x \sim X_{s}}\left[\left(1-\eta_{D t}\left(G_{s \rightarrow t}(x)\right)^{2}\right]\right. \end{array}Lcams→t=−(Ex∼Xs[log(ηs(x))]+Ex∼Xt[log(1−ηs(x))]LcamDt=Ex∼Xt[(ηDt(x))2]+Ex∼Xs[(1−ηDt(Gs→t(x))2],辅助分类器ηsη_sηs和ηDtη_{D_t}ηDt带来的损失。
最后,联合训练编码器、解码器、判别器和辅助分类器以优化最终目标函数:
minGs→t,Gt→s,ηs,ηtmaxDs,Dt,ηDs,ηDtλ1Llsgan +λ2Lcycle +λ3Lidentity +λ4Lcam \min _{G_{s \rightarrow t}, G_{t \rightarrow s}, \eta_{s}, \eta_{t}} \max _{D_{s}, D_{t}, \eta_{D_{s}}, \eta_{D_{t}}} \lambda_{1} L_{\text {lsgan }}+\lambda_{2} L_{\text {cycle }}+\lambda_{3} L_{\text {identity }}+\lambda_{4} L_{\text {cam }} Gs→t,Gt→s,ηs,ηtminDs,Dt,ηDs,ηDtmaxλ1Llsgan +λ2Lcycle +λ3Lidentity +λ4Lcam
其中λ1=1,λ2=10,λ3=10,λ4=1000\lambda_{1}=1, \lambda_{2}=10, \lambda_{3}=10, \lambda_{4}=1000λ1=1,λ2=10,λ3=10,λ4=1000, Llsgan =Llsgan s→t+Llsgan t→s,Lcycle =Lcycle s→t+Lcycle t→s,Lidentity =Ldentity s→t+Lidentity t→s,Lcam =Lcam s→t+Lcam t→sL_{\text {lsgan }}=L_{\text {lsgan }}^{s \rightarrow t}+L_{\text {lsgan }}^{t \rightarrow s}, L_{\text {cycle }}=L_{\text {cycle }}^{s \rightarrow t}+L_{\text {cycle }}^{t \rightarrow s}, L_{\text {identity }}=L_{\text {dentity }}^{s \rightarrow t}+L_{\text {identity }}^{t \rightarrow s}, L_{\text {cam }}=L_{\text {cam }}^{s \rightarrow t}+L_{\text {cam }}^{t \rightarrow s}Llsgan =Llsgan s→t+Llsgan t→s,Lcycle =Lcycle s→t+Lcycle t→s,Lidentity =Ldentity s→t+Lidentity t→s,Lcam =Lcam s→t+Lcam t→s。