circlegan_【源码解读】cycleGAN(二) :训练

训练的代码见于train.py,首先定义好网络,两个生成器A2B, B2A和两个判别器A, B,以及对应的优化器(优化器的设置保证了只更新生成器或判别器,不会互相影响)

###### Definition of variables #######Networks

netG_A2B =Generator(opt.input_nc, opt.output_nc)

netG_B2A=Generator(opt.output_nc, opt.input_nc)

netD_A=Discriminator(opt.input_nc)

netD_B= Discriminator(opt.output_nc)

#Optimizers & LR schedulers

optimizer_G =torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),

lr=opt.lr, betas=(0.5, 0.999))

optimizer_D_A= torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))

optimizer_D_B= torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

然后是数据

#Dataset loader

transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC),

transforms.RandomCrop(opt.size),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]

dataloader= DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),

batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

接着就可以求取损失,反传梯度,更新网络,更新网络的时候首先更新生成器,然后分别更新两个判别器

生成器:损失函数=身份损失+对抗损失+循环一致损失

###### Generators A2B and B2A ######

optimizer_G.zero_grad()#Identity loss

#G_A2B(B) should equal B if real B is fed

same_B =netG_A2B(real_B)

loss_identity_B= criterion_identity(same_B, real_B)*5.0

#G_B2A(A) should equal A if real A is fed

same_A =netG_B2A(real_A)

loss_identity_A= criterion_identity(same_A, real_A)*5.0

#GAN loss

fake_B =netG_A2B(real_A)

pred_fake=netD_B(fake_B)

loss_GAN_A2B=criterion_GAN(pred_fake, target_real)

fake_A=netG_B2A(real_B)

pred_fake=netD_A(fake_A)

loss_GAN_B2A=criterion_GAN(pred_fake, target_real)#Cycle loss

recovered_A =netG_B2A(fake_B)

loss_cycle_ABA= criterion_cycle(recovered_A, real_A)*10.0recovered_B=netG_A2B(fake_A)

loss_cycle_BAB= criterion_cycle(recovered_B, real_B)*10.0

#Total loss

loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA +loss_cycle_BAB

loss_G.backward()

optimizer_G.step()

判别器A 损失函数= 真实样本分类损失 + 虚假样本分类损失

###### Discriminator A ######

optimizer_D_A.zero_grad()#Real loss

pred_real =netD_A(real_A)

loss_D_real=criterion_GAN(pred_real, target_real)#Fake loss

fake_A =fake_A_buffer.push_and_pop(fake_A)

pred_fake=netD_A(fake_A.detach())

loss_D_fake=criterion_GAN(pred_fake, target_fake)#Total loss

loss_D_A = (loss_D_real + loss_D_fake)*0.5loss_D_A.backward()

optimizer_D_A.step()###################################

判别器B损失函数= 真实样本分类损失 + 虚假样本分类损失

###### Discriminator B ######

optimizer_D_B.zero_grad()#Real loss

pred_real =netD_B(real_B)

loss_D_real=criterion_GAN(pred_real, target_real)#Fake loss

fake_B =fake_B_buffer.push_and_pop(fake_B)

pred_fake=netD_B(fake_B.detach())

loss_D_fake=criterion_GAN(pred_fake, target_fake)#Total loss

loss_D_B = (loss_D_real + loss_D_fake)*0.5loss_D_B.backward()

optimizer_D_B.step()###################################

可以注意到,判别器损失中,虚假样本fake_A,fake_B都采用detach()操作,脱离计算图,这样判别器的损失进行反向传播不会对整个网络计算梯度,避免了不必要的计算

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/527432.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电脑重启bootmgr_电脑出现bootmgr is missing怎么办

展开全部电脑开机,或者重启以后显示:Bootmgr is missing, 是代表硬盘的主引导记录(MBR)出错,从而导致无法引e68a8462616964757a686964616f31333337613931导系统,哪只能重建主引导记录,还有一个可能是丢失系统文件&…

python tkinter text改变文本字体颜色_如何更改Tkinter中文本的颜色?

在Tkinter图形用户界面中,我无法确定如何更改文本的颜色。我试着让Label1变成红色,Label2变成蓝色,Label3变成棕色,Label4变成黄色,但我似乎想不出来。提前谢谢:)import randomfrom Tkinter import * #dont…

qt 在label上以光标位置进行缩放_缩放|位移|渐变简单动画

本文简单介绍Qt的一些动画效果(缩放,位移,渐变)。缩放动画将窗口的geometry(位置,大小)属性作为动画参考实现缩放动画。代码QWidget *w new QWidget;w->setWindowTitle(QStringLiteral("缩放动画Qt君"));w->resize(320, 240)…

虚拟机中ubuntu可以使用显卡吗_在KVM下使用ubuntu19.10安装Anbox

导言:Anbox是一个Android模拟器,可以从linux系统运行Android应用程序或游戏。对于Anbox的安装已经有了各种教程,主要针对ubuntu18.04之前的版本。最近在做一个关于虚拟机中跑安卓的项目,因此在虚拟机中使用ubuntu18.04系统&#x…

mysql tree_MySQL树形遍历(二)

转载自:http://blog.csdn.net/dreamer0924/article/details/7580278英文原文:http://mikehillyer.com/articles/managing-hierarchical-data-in-mysql/预排序遍历树算法:modified preorder tree traversal algorithm这个算法有如下几个数据结构1 lft 代表左 left2 r…

python导入pillow模块_Python:argparse模块和pillow-image

刚入门学python,最近照着实验楼做了一个基础的练手项目:图像转字符画,里面用到了argparse和pillow-image。看了python关于这个函数的介绍和网上的一些教程,想把重点整理出来,正好最近发现前一天还挺明白的内容&#xf…

mysql 常用数据库连接池_常见的数据库连接池

欢迎进入Java社区论坛,与200万技术人员互动交流 >>进入 2.C3P0 在Hibernate和Spring中默认支持该数据库连接池 需要引入:c3p0-0.9.1.2.jar包,如果报错再引入mchange-commons-0.2.jar 1. 在类路径下编写一个c3p0-config.xml文件 c3p0-co…

win32_bios 的对象编辑器无法保存对象_怎样创建Femap对象

创建Femap对象主要有两种方式,一是直接在Femap内置的API程序窗体中创建,二是在API程序窗口以外的开发环境中创建。一、使用FEMAP集成的API程序窗口开始使用FEMAP API的最快方法是打开API编程窗口。它提供了一个完整的编辑、调试和运行的环境,…

nginx 在阿里云怎么安装mysql_阿里云Linux服务器安装 nginx+mysql+php

阿里云Linux服务器安装 nginxmysqlphp步骤1、登录服务器2、下载安装包3、将安装包上传到服务器的/home目录下注:使用rz sz命令进行本地和服务器间的上传、下载,安装命令yum install -y lrzsz4、解压安装包注:使用yum install unzip -y安装解压工具,安装完…

未定义变量: data_三、变量声明

三、变量声明var声明主要特点: - var是函数作用域,只针对函数声明 - 可以多次声明同一个变量不会报错 - 捕获变量怪异之处function fnVar(flag: boolean) {if(flag) {var x 10;}return x; } fnVar(true); // 10 fnVar(false); // undefinedvar isDone: …

阿帕奇链接mysql_apache guacamole 使用mysql 连接

1.创建一个临时文件夹,用来存放mysql-java连接器mkdir tempauth2.下载相关文件cd tempauthwget https://jaist.dl.sourceforge.net/project/guacamole/current/extensions/guacamole-auth-jdbc-0.9.14.tar.gzwget https://cdn.mysql.com//Downloads/Connector-J/mys…

表单的默认提交方式_对于PHP表单提交有哪集中方式讲解

PHP 做网页后端还是很优秀的&#xff0c;PHP 表单提交&#xff0c;不外乎两种方法&#xff0c;即 GET 和 POST 方法&#xff1b;PHP后台使用全局变量$_POST;$_GET;来获取提交数据。代码&#xff1a;<!DOCTYPE HTML> <html> <head><meta charset"utf-…

spring中怎么让事物提交_Spring怎么在一个事务中开启另一个事务

点击上方“Java知音”&#xff0c;选择“置顶公众号”技术文章第一时间送达&#xff01;作者&#xff1a;Mazinmy.oschina.net/u/3441184/blog/893628Spring项目&#xff0c;需要在一个事务中开启另一个事务。上面提到的情景可能不常见&#xff0c;但是还是会有的&#xff0c;一…

通过对象指针的方式强行指定到子类_C++中的虚指针与虚函数表

​ 最近在逛B站的时候发现有候捷老师的课程&#xff0c;如获至宝。因此&#xff0c;跟随他的讲解又复习了一遍关于C的内容&#xff0c;收获也非常的大&#xff0c;对于某些模糊的概念及遗忘的内容又有了更深的认识。以下内容是关于虚函数表、虚函数指针&#xff0c;而C中的动态…

datax oracle mysql_从 MySQL 到 Lindorm时序引擎 的数据迁移

背景本文主要介绍如何使用阿里巴巴的开源工具Datax 实现从 MySQL 到 时序引擎 的数据迁移。DataX相关使用介绍请参阅 DataX 的下面将首先介绍 DataX 工具本身&#xff0c;以及本次迁移工作涉及到的两个插件(MySQL Reader 和 TSDB Writer)。DataXDataX 是阿里巴巴集团内被广泛使…

如何手动输入给数组赋值_你是否真的了解VBA数组呢?让我带你认识一下真正的数组...

大家好&#xff0c;我们今日继续讲解VBA代码解决方案的第110讲内容&#xff1a;VBA数组讲解&#xff0c;什么是数组&#xff0c;如何定义数组&#xff0c;如何创建数组一、什么是数组 就是数组共享一个名字&#xff0c;有着多个元素按顺序排列的变量。在数组中&#xff0c;元素…

redhat9安装mysql_redhat 9.0 安装mysql

在官网上下载了MySQL-5.5.9-1.rhel5.i386.tar包 &#xff0c;将文件以二进制的形式ftp到虚拟机rehat上解压文件到MySQY-5文件夹下&#xff1a;然后将路径切换到解压目录下运行 rpm -ivh *.rpm --force报如下错&#xff1a;rootlocalhost MySQL-5]# rpm -ivh *.rpm --forceerror…

为什么整数在python中表示d_python中整数的缓存机制

在python中&#xff0c;如下代码结果一定不会让你吃惊&#xff1a;Python 3.3.2 (v3.3.2:d047928ae3f6, May 16 2013, 00:06:53) [MSC v.1600 64 bit (AMD64)] on win32Type "copyright", "credits" or "license()" for more information.>&g…

MySQL中序列的作用_MySql中序列的应用和总结

Mysql中的序列主要用于主键&#xff0c;主键是递增的字段&#xff0c;不可重复。Mysql与Oracle不同的是&#xff0c;它不支持原生态的sequence&#xff0c;需要用表和函数的组合来实现类似序列的功能。1.首先创建序列的主表/*Navicat Premium Data TransferSource Server : MyS…

python内置模块重要程度排名_python常用内置模块

#持续更新#在使用内置模块的时候需要导入&#xff0c;例如import abc&#xff0c;则导入abc模块&#xff0c;当然模块也可以自己写&#xff0c;相当于一个类&#xff0c;后面放到类里说&#xff0c;这个因为环境闲置&#xff0c;有些无法执行&#xff0c;只能理解了#os系统操作…