circlegan_CycleGAN原理以及代码全解析

许多名画造假者费尽毕生的心血,试图模仿出艺术名家的风格。如今,CycleGAN就可以初步实现这个神奇的功能。这个功能就是风格迁移,比如下图,照片可以被赋予莫奈,梵高等人的绘画风格

这属于是无配对数据(unpaired)产生的图片,也就是说你有一些名人名家的作品,也有一些你想转换风格的真实图片,这两种图片是没有任何交集的。在之前的文章(用AI增强人类想象力)中提到的Pix2Pix方法的关键是提供了在这两个域中有相同数据的训练样本。CycleGAN的创新点在于能够在源域和目标域之间,无须建立训练数据间一对一的映射,就可实现这种迁移

想要做到这点,有两个比较重要的点,第一个就是双判别器。如上图a所示,两个分布X,Y,生成器G,F分别是X到Y和Y到X的映射,两个判别器Dx,Dy可以对转换后的图片进行判别。第二个点就是cycle-consistency loss,用数据集中其他的图来检验生成器,这是防止G和F过拟合,比如想把一个小狗照片转化成梵高风格,如果没有cycle-consistency loss,生成器可能会生成一张梵高真实画作来骗过Dx,而无视输入的小狗。

需要注意的是,广为流传的下图,有个容易让人理解错误的地方,那就是下图中的input和output那几张图,两匹马应该除了花纹其他一致的,除此之外,结构还是挺清晰的

对抗损失

生成器和判别器的loss函数和GAN是一样的,判别器D尽力检测出生成器G产生的假图片,生成器尽力生成图片骗过判别器,具体数理推导可以看我专栏之前的文章李刚:GAN 对抗生成网络入门辅助理解​zhuanlan.zhihu.com

对抗loss由两部分组成:

以及

Cycle Consistency 损失

作者说:理论上,对抗训练可以学习映射输出G和F,它们分别作为目标域Y和X产生相同的分布。然而,具有足够大的容量,网络可以将相同的输入图像集合映射到目标域中的任何图像的随机排列。因此,单独的对抗性loss不能保证可以映射单个输入。需要另外来一个loss,保证G和F不仅能满足各自的判别器,还能应用于其他图片。也就是说,G和F可能合伙偷懒骗人,给G一个图,G偷偷把小狗变成梵高自画像,F再把梵高自画像变成输入。Cycle Consistency loss的到来制止了这种投机取巧的行为,他用梵高其他的画作测试FG,用另外真实照片测试GF,看看能否变回到原来的样子,这样保证了GF在整个X,Y分布区间的普适性。

整体

所以,整个loss就是下面的式子,就像训练两个auoto-encoder一样

作者在后文比对了单独拿出不同部分的效果,比如只用Cycle Consistency loss,只用对抗,GAN + 前向cycle-consistency loss (F(G(x)) ≈ x),, GAN + 后向 cycle-consistency loss (G(F(y)) ≈ y),以及cycleGAN的效果。

代码实现

首先是一些参数

ngf = 32 # Number of filters in first layer of generator

ndf = 64 # Number of filters in first layer of discriminator

batch_size = 1 # batch_size

pool_size = 50 # pool_size

img_width = 256 # Imput image will of width 256

img_height = 256 # Input image will be of height 256

img_depth = 3 # RGB format

构造生成器Generator(Encoder+Transformer+Decoder)

假设所有图片都是256*256的彩图,需要先用卷积神经网络提取特征,在这里,input_gen是输入图像,num_features是我们从卷积层中提取出的输出特征的数量(滤波器的数量)window_width,window_height代表滤波器尺寸。 stride_width,strideheight是滤波器如何在整个图上移动的参数。输出的O_C1是尺寸[256,256,32]的矩阵。也可以在后边自行添加Relu等函数。

o_c1 = general_conv2d(input_gen,

num_features=ngf,

window_width=7,

window_height=7,

stride_width=1,

stride_height=1)

#定义卷积层函数

def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1):

with tf.variable_scope(name):

conv = tf.contrib.layers.conv2d(inputconv, num_features, [window_width, window_height], [stride_width, stride_height],

padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),

biases_initializer=tf.constant_initializer(0.0))

后面是相似的卷积步骤,最后一层输出o_enc_A是(64,64,256)的矩阵

o_c2 = general_conv2d(o_c1, num_features=64*2, window_width=3, window_height=3, stride_width=2, stride_height=2)

# o_c2.shape = (128, 128, 128)

o_enc_A = general_conv2d(o_c2, num_features=64*4, window_width=3, window_height=3, stride_width=2, stride_height=2)

# o_enc_A.shape = (64, 64, 256)

Transformer可以将这些层视为图像的不同附近特征的组合,然后基于这些特征来决定如何将图像的特征向量转换到另一个分布。作者使用了6层resnet块,其中输入的残差被添加到输出中。这样做是为了确保先前层的输入的属性也可用于以后的层,因此它们的输出不会偏离原始输入,否则原始图像的特性将不被保留在输出中。任务的主要目的之一是保留原始输入的特性,如对象的大小和形状,因此残差网络非常适合这些类型的变换。关于resnet,详见 ResNet原理及其在TF-Slim中的实现

o_r1 = build_resnet_block(o_enc_A, num_features=64*4)

o_r2 = build_resnet_block(o_r1, num_features=64*4)

o_r3 = build_resnet_block(o_r2, num_features=64*4)

o_r4 = build_resnet_block(o_r3, num_features=64*4)

o_r5 = build_resnet_block(o_r4, num_features=64*4)

o_enc_B = build_resnet_block(o_r5, num_features=64*4)

#定义resnet

def resnet_blocks(input_res, num_features):

out_res_1 = general_conv2d(input_res, num_features,

window_width=3,

window_heigth=3,

stride_width=1,

stride_heigth=1)

out_res_2 = general_conv2d(out_res_1, num_features,

window_width=3,

window_heigth=3,

stride_width=1,

stride_heigth=1)

return (out_res_2 + input_res)

下面是decoder,用反卷积把这些特征变回成图片

o_d1 = general_deconv2d(o_enc_B, num_features=ngf*2 window_width=3, window_height=3, stride_width=2, stride_height=2)

o_d2 = general_deconv2d(o_d1, num_features=ngf, window_width=3, window_height=3, stride_width=2, stride_height=2)

gen_B = general_conv2d(o_d2, num_features=3, window_width=7, window_height=7, stride_width=1, stride_height=1)

#定义反卷积层

def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0):

with tf.variable_scope(name):

conv = tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))

if do_norm:

conv = instance_norm(conv)

# conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope="batch_norm")

if do_relu:

if(relufactor == 0):

conv = tf.nn.relu(conv,"relu")

else:

conv = lrelu(conv, relufactor, "lrelu")

return conv

判别器的构成在这里救不赘述了,无非就是用CNN把生成的图片变成一些特征图,再用全连接变成最后的decision(真或假)

定义loss function

判别器loss:loss_1是对于真图的判定,越接近1越好,loss_2是对于假图的判定,越接近0越好,loss是两个loss相加

D_A_loss_1 = tf.reduce_mean(tf.squared_difference(dec_A,1))

D_B_loss_1 = tf.reduce_mean(tf.squared_difference(dec_B,1))

D_A_loss_2 = tf.reduce_mean(tf.square(dec_gen_A))

D_B_loss_2 = tf.reduce_mean(tf.square(dec_gen_B))

D_A_loss = (D_A_loss_1 + D_A_loss_2)/2

D_B_loss = (D_B_loss_1 + D_B_loss_2)/2

生成器loss:

g_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))

g_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))

Cycle Consistency loss: 保证原始图像和循环图像之间的差异应该尽可能小,注意10*cyc_loss是赋予Cycle Consistency loss更大的权值,作者并没有讨论这个参数是怎么确定下来的

cyc_loss = tf.reduce_mean(tf.abs(input_A-cyc_A)) + tf.reduce_mean(tf.abs(input_B-cyc_B))

g_loss_A = g_loss_A_1 + 10*cyc_loss

g_loss_B = g_loss_B_1 + 10*cyc_loss

模型训练

for epoch in range(0,100):

# Define the learning rate schedule. The learning rate is kept

# constant upto 100 epochs and then slowly decayed

if(epoch < 100) :

curr_lr = 0.0002

else:

curr_lr = 0.0002 - 0.0002*(epoch-100)/100

# Running the training loop for all batches

for ptr in range(0,num_images):

# Train generator G_A->B

_, gen_B_temp = sess.run([g_A_trainer, gen_B],

feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

# We need gen_B_temp because to calculate the error in training D_B

_ = sess.run([d_B_trainer],

feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

# Same for G_B->A and D_A as follow

_, gen_A_temp = sess.run([g_B_trainer, gen_A],

feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

_ = sess.run([d_A_trainer],

feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/498743.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

魅族16无信号服务器,魅族16信号差的解决办法

手机信号问题一直都是人们关注的问题&#xff0c;在日常使用时有些地方手机可能出现突然信号变差&#xff0c;可能别人的手机信号一直很好只有你的出现了问题。魅族手机最早的几个版本都很容易出现这种问题&#xff0c;新款的魅族16怎么样呢&#xff1f;魅族16信号差怎么解决呢…

服务器系统核心和带gui区别,Windows Server 2012图形用户界面(GUI)和服务器核心(Server Core)之间的切换...

当安装 Windows Server 2012 时&#xff0c;咱们能够在“服务器核心安装”和“彻底安装”之间任选其一。“带 GUI 选项的服务器”选项Windows Server 2012 等效于 Windows Server 2008 R2 中的彻底安装选项。“服务器核心安装”选项可减小所需的磁盘空间、潜在的***面&#xff…

用python控制键盘_【python黑科技-pyautogui】教你用python控制滑鼠与键盘

今天来聊聊python中非常酷炫的功能&#xff0c;控制滑鼠与键盘&#xff0c;理论上&#xff0c;如果你程序功力非常强的话&#xff0c;甚至可以用这个模组写出一个「游戏插件」&#xff0c;让程序操控你的滑鼠与键盘帮你玩游戏本文测试环境: anaconda, spyder, python3.7安装pya…

ulimit限制 新系统_Linux查看及修改系统的资源限制命令ulimit

在Linux&#xff0c;查看系统对资源使用的显示可以使用命令ulimit&#xff0c;其中参数-a会列出所有的资源使用限制。[demoserver ~]$ ulimit -acore file size (blocks, -c) 0data seg size (kbytes, -d) unlimitedscheduling priority (-e) 0file size (blocks, -f) unlimite…

华为策略路由加等价路由_华为——防火墙——策略路由配置及思路

华为——策略路由(校园网配置)作用&#xff1a;通过分析数据报的源地址和目标地址&#xff0c;按照策略规则选择不同的网关&#xff0c;进行数据转发。提供冗余&#xff0c;负载&#xff0c;但是还是单线路的速度。只是提供了不同的方向&#xff0c;并没有进行合并线路。拓扑图…

简述数学建模的过程_中文字幕乱码文字2020_MDTM-198加勒比中文字幕合集 - 第5页...

Well, you won’t get one from me.Nah, I never let a lady treat.I try to ease away and create a larger space cushion, but he steps toward me again. I don’t feel threatened by him, however. He’s a big guy, but not menacing. He isn’t trying to bully me wit…

头条自己提问的问题在哪看_在头条的这三十天

文、图&#xff1a;书海履痕今天入头条三十天&#xff0c;按民间俗语&#xff0c;满月了。 三十个日子&#xff0c;真得是感慨万千。特别是昨日的文章&#xff0c;经头条君和各位友友们的厚爱&#xff0c;让我经历了过山车的感觉&#xff0c;各种滋味存于心底&#xff0c;在此谢…

c可以 char* 赋值但是c++不可以_雷佳音的妻子完全可以女团C位出道,这么有气质的女人,谁能不爱...

导读&#xff1a;雷佳音的妻子完全可以女团C位出道&#xff0c;这么有气质的女人&#xff0c;谁能不爱各位点开这篇文章的朋友们&#xff0c;想必都是很高的颜值吧&#xff0c;我们真的是很有缘哦&#xff0c;小编每天都会给大家带来不一样的汽车资讯&#xff0c;如果对小编的文…

【加解密学习笔记:第一天】操作系统基础知识

加解密相关系统基础知识 Unicode编码格式 Unicode编码中使用2字节对字符进行编码&#xff0c;对ASCLL码的支持通过愿为不变&#xff0c;高位补零实现一个字有2字节&#xff0c;Intel在存入储存器时低位入低地址&#xff0c;高位入高地址&#xff08;Little-endian&#xff09…

oracle sequence 不同 会话 不连续_序列 Sequence

Sequence是一个数据库对象&#xff0c;多个用户可以从中生成唯一的整数&#xff0c;可以使用序列自动生成主键值。生成序列号时&#xff0c;序列号将递增&#xff0c;独立于事务提交或回滚;如果两个用户同时递增同一序列&#xff0c;因为序列号是由另一个用户生成的&#xff0c…

【加解密学习笔记:第二天】动态调试工具OllyDbg使用基础介绍

首先说一下OllyDbg的界面&#xff0c;如下图所示 下面依次介绍&#xff1a; 反汇编面板&#xff1a;有四列&#xff0c;从左到右依次为&#xff1a;地址&#xff08;Address&#xff09;&#xff0c;机器码&#xff08;Hex dump&#xff09;&#xff0c;反汇编代码&#xff08…

dmp只导数据不导结构_今日头条快消食品推广CVR为何高达4.40%?原来DMP定向这么好...

摘要&#xff1a;据艾媒报告显示&#xff0c;当前快消品消费在中国居民消费的比重已经占到34.6%&#xff0c;无疑是一支重要力量。虽然消费者的消费能力在不断提升&#xff0c;但快消行业的推广仍多受制于传统模式&#xff0c;应该怎么寻找出路呢&#xff1f;一、企业介绍客户L…

【加解密学习笔记:第三天】OllyDbg断点介绍

INT 3 断点 常用断点&#xff0c;使用“F2”快捷键设置的就是 INT 3 断点。这类断点采用修改机器码的方式&#xff0c;将设断处的代码更改为 “CC”&#xff0c;当程序运行至设断处时&#xff0c;会抛出一个异常&#xff0c;OllyDbg会捕捉到这个异常&#xff0c;使得程序暂停&a…

c# 十六进制转为字节_C# 16进制与字符串、字节数组之间的转换

1.请问c#中如何将十进制数的字符串转化成十六进制数的字符串//十进制转二进制Console.WriteLine("十进制166的二进制表示: "Convert.ToString(166, 2));//十进制转八进制Console.WriteLine("十进制166的八进制表示: "Convert.ToString(166, 8));//十进制转…

SHA-1算法详解和C++实现

SHA-1算法详解和C实现 背景介绍 SHA-1算法也称安全散列算法1&#xff0c;可以将一个最大264−12^{64}-1264−1的数据生成一个160位的数据摘要。尽管SHA-1算法已经被认为不再安全&#xff0c;但仍有部分应用使用SHA-1算法验证文件。 算法原理 类型定义 在介绍算法原理之前&…

python的socket模块_python模块:socket模块

1.Socket类型socket(family,type[,protocal]) 使用给定的地址族,套接字类型,协议编号(默认是0)来创建套接字socket类型描述socket.AF_UNIX只能够用于单一的Unix系统进程间通信socket.AF_INET服务器之间网络通信socket.AF_INET6IPv6socket.SOCK_STREAM流式socket , for TCPs…

完整性校验用到常见的算法_校验数据的完整性,校验数据完整性,使用MD5/SHA算法校...

校验数据的完整性&#xff0c;校验数据完整性,使用MD5/SHA算法校使用MD5/SHA算法校验数据的完整性package cn.itcast.gz;import java.io.File;import java.io.FileInputStream;import java.security.DigestInputStream;import java.security.MessageDigest;/** * 主要用于验证数…

python超市买苹果_官网购买的iPhone12pro还没发货?试着用Python快速入手

引言​iPhone12pro有望成为2020年末真香机&#xff0c;动辄3000元的溢价让不少消费者选择了等待官网调货。除了官方与电商线上平台&#xff0c;苹果还采用了线下预约制提货。但每天少的可怜的出货量&#xff0c;和不到一秒钟就抢空的预约名额让“老年人”手速的各位望而却步。传…

server多列转行 sql_sql server 行转列及列转行的使用

在我们使用的数据库表中经常需要用到行列互相转换的情况&#xff0c;使用sql 的关键词 UNPIVOT(列转行)和PIVOT(行转列)可轻松实现行列转换。一、列转行&#xff1a;员工月份排班表存储是采用1号~31号作为列的方式进行存储的现通过 UNPIVOT 将每天的班次用行进行展示&#xff0…

git 拉取远程其他分支代码_git切换远程分支并拉取远程分支代码

Git一般有很多分支&#xff0c;我们clone到本地的时候一般都是master分支&#xff0c;那么如何切换到其他分支呢&#xff1f;主要命令如下&#xff1a;1. 查看远程分支$ git branch -a我在mxnet根目录下运行以上命令&#xff1a;~/mxnet$ git branch -a* master可以看到&#x…