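The training code lives in train.py. It first defines the networks: two generators, A2B and B2A, and two discriminators, A and B, together with their optimizers. Each optimizer is given only the parameters of the generators or of a single discriminator, so a step updates one group without affecting the others.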
###### Definition of variables ######
# Networks
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
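To illustrate why separate optimizers keep the updates isolated, here is a minimal sketch using tiny stand-in networks. The `nn.Linear` layers, the dummy loss, and the fixed learning rate are assumptions for illustration, not the `Generator` and `Discriminator` classes used in train.py. Gradients are computed for every network on the path, but only the parameters registered with the optimizer that is stepped actually change.

```python
import itertools
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real networks, kept tiny for illustration.
g_a2b = nn.Linear(4, 4)
g_b2a = nn.Linear(4, 4)
d_a = nn.Linear(4, 1)

# One optimizer for both generators, a separate one for the discriminator.
opt_g = torch.optim.Adam(itertools.chain(g_a2b.parameters(), g_b2a.parameters()),
                         lr=0.0002, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(d_a.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Snapshot the weights so we can see what the step changes.
d_before = d_a.weight.detach().clone()
g_before = g_b2a.weight.detach().clone()

# A dummy loss that flows through both generators and the discriminator.
x = torch.randn(8, 4)
loss = d_a(g_b2a(g_a2b(x))).pow(2).mean()
opt_g.zero_grad()
loss.backward()
opt_g.step()  # updates only the parameters registered with opt_g

d_unchanged = torch.equal(d_a.weight, d_before)
g_changed = not torch.equal(g_b2a.weight, g_before)
print(d_unchanged, g_changed)
```

The discriminator still receives a gradient in `d_a.weight.grad` here; it is simply never applied, because `opt_d.step()` is not called.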
Next comes the data loading:
# Dataset loader
transforms_ = [transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
               transforms.RandomCrop(opt.size),
               transforms.RandomHorizontalFlip(),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
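As a quick check on the numbers in this pipeline, the sketch below works through the arithmetic in plain Python. It assumes `opt.size` is 256, a value picked here for illustration since the actual setting is not shown in this section. The image is resized so that its shorter edge is 1.12 times the target size, then randomly cropped back down, and `Normalize` with mean 0.5 and std 0.5 maps the [0, 1] output of `ToTensor` to [-1, 1].

```python
size = 256  # assumed value of opt.size, for illustration only

# Resize target for the shorter edge: 12% larger than the crop, truncated to int.
resize_to = int(size * 1.12)

# Number of distinct crop positions RandomCrop can pick along the shorter edge.
crop_offsets = resize_to - size + 1

def normalize(x, mean=0.5, std=0.5):
    """Per-channel rule applied by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

print(resize_to, crop_offsets)
print(normalize(0.0), normalize(0.5), normalize(1.0))
```

The slight oversize followed by a random crop acts as a mild translation augmentation on top of the horizontal flip.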
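With that in place, each iteration computes the losses, backpropagates the gradients, and updates the networks. The generators are updated first, followed by each of the two discriminators in turn.
Generators: loss = identity loss + adversarial loss + cycle-consistency loss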
###### Generators A2B and B2A ######
optimizer_G.zero_grad()
# Identity loss
# G_A2B(B) should equal B if real B is fed
same_B = netG_A2B(real_B)
loss_identity_B = criterion_identity(same_B, real_B)*5.0
# G_B2A(A) should equal A if real A is fed
same_A = netG_B2A(real_A)
loss_identity_A = criterion_identity(same_A, real_A)*5.0
# GAN loss
fake_B = netG_A2B(real_A)
pred_fake = netD_B(fake_B)
loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
fake_A = netG_B2A(real_B)
pred_fake = netD_A(fake_A)
loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
# Cycle loss
recovered_A = netG_B2A(fake_B)
loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0
recovered_B = netG_A2B(fake_A)
loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0
# Total loss
loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
loss_G.backward()
optimizer_G.step()
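The generator objective computed above can be written compactly as follows. The notation is introduced here for illustration: $a$ and $b$ stand for `real_A` and `real_B`, $t_{\mathrm{real}}$ for `target_real`, and $\ell_{\mathrm{id}}$, $\ell_{\mathrm{GAN}}$, $\ell_{\mathrm{cyc}}$ for `criterion_identity`, `criterion_GAN`, and `criterion_cycle`. The concrete loss functions are defined elsewhere in train.py and are not shown in this section, so they are left abstract.

```latex
\begin{aligned}
\mathcal{L}_G ={}& 5\,\ell_{\mathrm{id}}\big(G_{B\to A}(a),\,a\big)
                 + 5\,\ell_{\mathrm{id}}\big(G_{A\to B}(b),\,b\big) \\
               &+ \ell_{\mathrm{GAN}}\big(D_B(G_{A\to B}(a)),\,t_{\mathrm{real}}\big)
                 + \ell_{\mathrm{GAN}}\big(D_A(G_{B\to A}(b)),\,t_{\mathrm{real}}\big) \\
               &+ 10\,\ell_{\mathrm{cyc}}\big(G_{B\to A}(G_{A\to B}(a)),\,a\big)
                 + 10\,\ell_{\mathrm{cyc}}\big(G_{A\to B}(G_{B\to A}(b)),\,b\big)
\end{aligned}
```

The weights 5 and 10 are the constants multiplied into the identity and cycle terms in the code. The adversarial terms compare the discriminator's output on generated images against the "real" target, because the generators are trained to fool the discriminators.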
Discriminator A: loss = classification loss on real samples + classification loss on fake samples
###### Discriminator A ######
optimizer_D_A.zero_grad()
# Real loss
pred_real = netD_A(real_A)
loss_D_real = criterion_GAN(pred_real, target_real)
# Fake loss
fake_A = fake_A_buffer.push_and_pop(fake_A)
pred_fake = netD_A(fake_A.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake)
# Total loss
loss_D_A = (loss_D_real + loss_D_fake)*0.5
loss_D_A.backward()
optimizer_D_A.step()
###################################
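`fake_A_buffer.push_and_pop` is defined outside this section. A common design for such a buffer is an image history pool: it keeps a limited number of past fakes and, once full, hands the discriminator either the new sample or a randomly chosen older one. The sketch below is an assumption about that behavior, not the code behind `fake_A_buffer`. The class name `HistoryBuffer`, the default capacity of 50, and the 0.5 swap probability are all hypothetical choices made for illustration.

```python
import random

class HistoryBuffer:
    """Hypothetical stand-in for the buffer behind push_and_pop."""

    def __init__(self, max_size=50, swap_prob=0.5, rng=None):
        assert max_size > 0, "buffer must hold at least one item"
        self.max_size = max_size
        self.swap_prob = swap_prob
        self.rng = rng or random.Random()
        self.data = []

    def push_and_pop(self, items):
        """Store new items and return a same-length list for the discriminator."""
        out = []
        for item in items:
            if len(self.data) < self.max_size:
                # Still filling up: remember the item and pass it through.
                self.data.append(item)
                out.append(item)
            elif self.rng.random() < self.swap_prob:
                # Return an older item and store the new one in its place.
                idx = self.rng.randrange(self.max_size)
                out.append(self.data[idx])
                self.data[idx] = item
            else:
                # Pass the new item through without storing it.
                out.append(item)
        return out

buf = HistoryBuffer(max_size=3, rng=random.Random(0))
first = buf.push_and_pop(["f0", "f1", "f2"])  # buffer not yet full
later = buf.push_and_pop(["f3", "f4", "f5"])  # may mix old and new fakes
print(first, later)
```

In training the items would be generated image tensors instead of strings. Showing the discriminator a mix of current and older fakes is intended to keep it from adapting only to the generator's most recent outputs.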
Discriminator B: loss = classification loss on real samples + classification loss on fake samples
###### Discriminator B ######
optimizer_D_B.zero_grad()
# Real loss
pred_real = netD_B(real_B)
loss_D_real = criterion_GAN(pred_real, target_real)
# Fake loss
fake_B = fake_B_buffer.push_and_pop(fake_B)
pred_fake = netD_B(fake_B.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake)
# Total loss
loss_D_B = (loss_D_real + loss_D_fake)*0.5
loss_D_B.backward()
optimizer_D_B.step()
###################################
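Both discriminator losses follow the same pattern: the average of a real-sample term and a fake-sample term. The worked example below shows how that average behaves. It assumes `criterion_GAN` is a mean squared error and that `target_real` and `target_fake` are 1 and 0; those definitions are not shown in this section, so treat them as assumptions. The helper names are made up for illustration.

```python
def mse(preds, target):
    """Mean squared error between a list of predictions and a scalar target."""
    return sum((p - target) ** 2 for p in preds) / len(preds)

def discriminator_loss(pred_real, pred_fake, target_real=1.0, target_fake=0.0):
    """Average of the real-sample and fake-sample losses, mirroring loss_D_A."""
    loss_real = mse(pred_real, target_real)
    loss_fake = mse(pred_fake, target_fake)
    return (loss_real + loss_fake) * 0.5

# A discriminator that separates real from fake perfectly.
perfect = discriminator_loss(pred_real=[1.0, 1.0], pred_fake=[0.0, 0.0])
# A discriminator that outputs 0.5 for everything.
undecided = discriminator_loss(pred_real=[0.5, 0.5], pred_fake=[0.5, 0.5])
# A discriminator that gets every sample wrong.
fooled = discriminator_loss(pred_real=[0.0, 0.0], pred_fake=[1.0, 1.0])
print(perfect, undecided, fooled)
```

Under these assumptions the loss is 0 for a perfect discriminator, 0.25 for one that cannot tell the two apart, and 1 for one that is wrong on every sample.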
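Note that in the discriminator losses the fake samples fake_A and fake_B are passed through detach(), which cuts them out of the computation graph. Backpropagating a discriminator loss therefore stops at the fake sample and does not compute gradients for the generators, which avoids unnecessary computation.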
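The effect of `detach()` can be seen in a minimal sketch with tiny stand-in networks. The `nn.Linear` layers and the squared-error loss are assumptions for illustration, not the networks or criterion from train.py. After the discriminator loss is backpropagated, the discriminator has gradients but the generator has none.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny stand-ins for one generator and one discriminator.
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)

real = torch.randn(8, 4)
fake = gen(torch.randn(8, 4))  # still attached to the generator's graph

# Discriminator loss on a detached fake: backprop stops at fake.detach().
pred_real = disc(real)
pred_fake = disc(fake.detach())
loss_d = 0.5 * ((pred_real - 1).pow(2).mean() + pred_fake.pow(2).mean())
loss_d.backward()

gen_has_grad = gen.weight.grad is not None
disc_has_grad = disc.weight.grad is not None
print(gen_has_grad, disc_has_grad)
```

Without the `detach()` call, `loss_d.backward()` would also walk back through `gen` and fill in its gradients, which is wasted work when only the discriminator is being updated.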