许多名画造假者费尽毕生的心血,试图模仿出艺术名家的风格。如今,CycleGAN就可以初步实现这个神奇的功能。这个功能就是风格迁移,比如下图,照片可以被赋予莫奈,梵高等人的绘画风格
这属于是无配对数据(unpaired)产生的图片,也就是说你有一些名人名家的作品,也有一些你想转换风格的真实图片,这两种图片是没有任何交集的。在之前的文章(用AI增强人类想象力)中提到的Pix2Pix方法的关键是提供了在这两个域中有相同数据的训练样本。CycleGAN的创新点在于能够在源域和目标域之间,无须建立训练数据间一对一的映射,就可实现这种迁移
想要做到这点,有两个比较重要的点,第一个就是双判别器。如上图a所示,两个分布X,Y,生成器G,F分别是X到Y和Y到X的映射,两个判别器Dx,Dy可以对转换后的图片进行判别。第二个点就是cycle-consistency loss,用数据集中其他的图来检验生成器,这是防止G和F过拟合,比如想把一个小狗照片转化成梵高风格,如果没有cycle-consistency loss,生成器可能会生成一张梵高真实画作来骗过Dx,而无视输入的小狗。
需要注意的是,广为流传的下图,有个容易让人理解错误的地方,那就是下图中的input和output那几张图,两匹马应该除了花纹其他一致的,除此之外,结构还是挺清晰的
对抗损失
生成器和判别器的loss函数和GAN是一样的,判别器D尽力检测出生成器G产生的假图片,生成器尽力生成图片骗过判别器,具体数理推导可以看我专栏之前的文章李刚:GAN 对抗生成网络入门辅助理解zhuanlan.zhihu.com
对抗loss由两部分组成:
以及
Cycle Consistency 损失
作者说:理论上,对抗训练可以学习映射输出G和F,它们分别作为目标域Y和X产生相同的分布。然而,具有足够大的容量,网络可以将相同的输入图像集合映射到目标域中的任何图像的随机排列。因此,单独的对抗性loss不能保证可以映射单个输入。需要另外来一个loss,保证G和F不仅能满足各自的判别器,还能应用于其他图片。也就是说,G和F可能合伙偷懒骗人,给G一个图,G偷偷把小狗变成梵高自画像,F再把梵高自画像变成输入。Cycle Consistency loss的到来制止了这种投机取巧的行为,他用梵高其他的画作测试FG,用另外真实照片测试GF,看看能否变回到原来的样子,这样保证了GF在整个X,Y分布区间的普适性。
整体
所以,整个loss就是下面的式子,就像训练两个auoto-encoder一样
作者在后文比对了单独拿出不同部分的效果,比如只用Cycle Consistency loss,只用对抗,GAN + 前向cycle-consistency loss (F(G(x)) ≈ x),, GAN + 后向 cycle-consistency loss (G(F(y)) ≈ y),以及cycleGAN的效果。
代码实现
首先是一些参数
ngf = 32 # Number of filters in first layer of generator
ndf = 64 # Number of filters in first layer of discriminator
batch_size = 1 # batch_size
pool_size = 50 # pool_size
img_width = 256 # Imput image will of width 256
img_height = 256 # Input image will be of height 256
img_depth = 3 # RGB format
构造生成器Generator(Encoder+Transformer+Decoder)
假设所有图片都是256*256的彩图,需要先用卷积神经网络提取特征,在这里,input_gen是输入图像,num_features是我们从卷积层中提取出的输出特征的数量(滤波器的数量)window_width,window_height代表滤波器尺寸。 stride_width,strideheight是滤波器如何在整个图上移动的参数。输出的O_C1是尺寸[256,256,32]的矩阵。也可以在后边自行添加Relu等函数。
o_c1 = general_conv2d(input_gen,
num_features=ngf,
window_width=7,
window_height=7,
stride_width=1,
stride_height=1)
#定义卷积层函数
def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1):
with tf.variable_scope(name):
conv = tf.contrib.layers.conv2d(inputconv, num_features, [window_width, window_height], [stride_width, stride_height],
padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
biases_initializer=tf.constant_initializer(0.0))
后面是相似的卷积步骤,最后一层输出o_enc_A是(64,64,256)的矩阵
o_c2 = general_conv2d(o_c1, num_features=64*2, window_width=3, window_height=3, stride_width=2, stride_height=2)
# o_c2.shape = (128, 128, 128)
o_enc_A = general_conv2d(o_c2, num_features=64*4, window_width=3, window_height=3, stride_width=2, stride_height=2)
# o_enc_A.shape = (64, 64, 256)
Transformer可以将这些层视为图像的不同附近特征的组合,然后基于这些特征来决定如何将图像的特征向量转换到另一个分布。作者使用了6层resnet块,其中输入的残差被添加到输出中。这样做是为了确保先前层的输入的属性也可用于以后的层,因此它们的输出不会偏离原始输入,否则原始图像的特性将不被保留在输出中。任务的主要目的之一是保留原始输入的特性,如对象的大小和形状,因此残差网络非常适合这些类型的变换。关于resnet,详见 ResNet原理及其在TF-Slim中的实现
o_r1 = build_resnet_block(o_enc_A, num_features=64*4)
o_r2 = build_resnet_block(o_r1, num_features=64*4)
o_r3 = build_resnet_block(o_r2, num_features=64*4)
o_r4 = build_resnet_block(o_r3, num_features=64*4)
o_r5 = build_resnet_block(o_r4, num_features=64*4)
o_enc_B = build_resnet_block(o_r5, num_features=64*4)
#定义resnet
def resnet_blocks(input_res, num_features):
out_res_1 = general_conv2d(input_res, num_features,
window_width=3,
window_heigth=3,
stride_width=1,
stride_heigth=1)
out_res_2 = general_conv2d(out_res_1, num_features,
window_width=3,
window_heigth=3,
stride_width=1,
stride_heigth=1)
return (out_res_2 + input_res)
下面是decoder,用反卷积把这些特征变回成图片
o_d1 = general_deconv2d(o_enc_B, num_features=ngf*2 window_width=3, window_height=3, stride_width=2, stride_height=2)
o_d2 = general_deconv2d(o_d1, num_features=ngf, window_width=3, window_height=3, stride_width=2, stride_height=2)
gen_B = general_conv2d(o_d2, num_features=3, window_width=7, window_height=7, stride_width=1, stride_height=1)
#定义反卷积层
def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0):
with tf.variable_scope(name):
conv = tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))
if do_norm:
conv = instance_norm(conv)
# conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope="batch_norm")
if do_relu:
if(relufactor == 0):
conv = tf.nn.relu(conv,"relu")
else:
conv = lrelu(conv, relufactor, "lrelu")
return conv
判别器的构成在这里救不赘述了,无非就是用CNN把生成的图片变成一些特征图,再用全连接变成最后的decision(真或假)
定义loss function
判别器loss:loss_1是对于真图的判定,越接近1越好,loss_2是对于假图的判定,越接近0越好,loss是两个loss相加
D_A_loss_1 = tf.reduce_mean(tf.squared_difference(dec_A,1))
D_B_loss_1 = tf.reduce_mean(tf.squared_difference(dec_B,1))
D_A_loss_2 = tf.reduce_mean(tf.square(dec_gen_A))
D_B_loss_2 = tf.reduce_mean(tf.square(dec_gen_B))
D_A_loss = (D_A_loss_1 + D_A_loss_2)/2
D_B_loss = (D_B_loss_1 + D_B_loss_2)/2
生成器loss:
g_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))
g_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))
Cycle Consistency loss: 保证原始图像和循环图像之间的差异应该尽可能小,注意10*cyc_loss是赋予Cycle Consistency loss更大的权值,作者并没有讨论这个参数是怎么确定下来的
cyc_loss = tf.reduce_mean(tf.abs(input_A-cyc_A)) + tf.reduce_mean(tf.abs(input_B-cyc_B))
g_loss_A = g_loss_A_1 + 10*cyc_loss
g_loss_B = g_loss_B_1 + 10*cyc_loss
模型训练
for epoch in range(0,100):
# Define the learning rate schedule. The learning rate is kept
# constant upto 100 epochs and then slowly decayed
if(epoch < 100) :
curr_lr = 0.0002
else:
curr_lr = 0.0002 - 0.0002*(epoch-100)/100
# Running the training loop for all batches
for ptr in range(0,num_images):
# Train generator G_A->B
_, gen_B_temp = sess.run([g_A_trainer, gen_B],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
# We need gen_B_temp because to calculate the error in training D_B
_ = sess.run([d_B_trainer],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
# Same for G_B->A and D_A as follow
_, gen_A_temp = sess.run([g_B_trainer, gen_A],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
_ = sess.run([d_A_trainer],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})