- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目录
- 模型结构
模型结构
之前几期打卡中,已经介绍过GAN CGAN SGAN,而ACGAN属于上述几种GAN的缝合怪,其模型的结构图如下:
通过对模型图的分析我们可以发现,它是在SGAN的基础上,引入了CGAN的思想,这也说明我上一期中对SGAN的一些理解是错误的,SGAN实际上只是让生成的图像更具有不同分类的差异,效果更好,但是没有对应的控制生成的能力。
想要给SGAN增加控制生成的能力,我们就需要像CGAN一样引入控制量C
在生成器中:它使用一个label标签传入,经过一个嵌入层+全连接层转换为向量并合并到特征向量中。
在判别器中也是如此,将标签传入,通过嵌入后合并到特征向量中,然后传入到判别网络。
需要注意的是生成器中的输入z的维度和判别器中的潜在维度是不一样的。
代码大致为
# 在生成器的__init__方法中添加一个属性
self.condition_embedding = nn.Sequential(nn.Embedding(n_classes, embedding_dim),nn.Linear(embedding_dim, self.init_size*self.init*size)
)# forward方法
out = xxxx
label_output = self.condition_embedding(label)
features = torch.concat((out, label_output), dim =1)
return self.model(features)# 在判别器的__init__方法中添加一个属性
self.condition_embedding = nn.Sequential(nn.Embedding(n_classes, embedding_dim),nn.Linear(embedding_dim, 128*ds_size**2)
)# forward方法
out = xxx
label_output = self.condition_embedding(label)
features = torch.concat((out, label_output), dim=1)
validity = self.adv_layer(features)
label = self.aux_layer(features)
return validity, label
由于下次打卡才是真正的代码,这节就先把思路写一下,验证下周再发。