昇思25天学习打卡营第23天 | CycleGAN图像风格迁移互换

昇思25天学习打卡营第23天 | CycleGAN图像风格迁移互换

文章目录

  • 昇思25天学习打卡营第23天 | CycleGAN图像风格迁移互换
    • CycleGAN模型
      • 模型结构
      • 循环一致损失函数
    • 数据集
      • 数据下载
      • 创建数据集
    • 网络构建
      • 生成器
      • 判别器
      • 损失函数和优化器
      • 前向计算
      • 梯度计算与反向传播
    • 总结
    • 打卡

CycleGAN模型

循环对抗生成网络(Cycle Generative Adversarial Network)的一个重要应用是域迁移(Domain Adaptation),即图像风格迁移。

CycleGAN实现了一种在没有配对示例的情况下学习将图像从源域X转换到目标域Y的方法,是一种无监督的图像迁移网络。

模型结构

CycleGAN网络是由两个镜像对称的GAN网络组成:
CycleGAN
图中, G G G为将X生成Y风格的生成器, F F F为将Y生成X风格的生成器, D X D_X DX D Y D_Y DY为相应的判别器。

该模型最终输出两个模型的权重,分别将两种图像风格进行彼此迁移,生成新的图像。

循环一致损失函数

循环一致损失(Cycle Consistency Loss)的计算方法如下:

Cycle Consistency Loss
x ∈ X x\in X xX经过生成器 G G G得到 Y Y Y风格的 Y ^ \hat Y Y^,然后将 Y ^ \hat Y Y^送进生成器 F F F产生X风格的新图片 x ^ \hat x x^,最后使用 x ^ \hat x x^ x x x一起计算出循环一致损失。

数据集

实验使用ImageNet中的苹果橘子部分图片,图像缩放为 256 × 256 256\times 256 256×256大小。

数据下载

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)

创建数据集

from mindspore.dataset import MindDataset# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()

网络构建

生成器

CycleGAN Generator

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normalweight_init = Normal(sigma=0.02)class ConvNormReLU(nn.Cell):def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):super(ConvNormReLU, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if padding is None:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding, weight_init=weight_init)layers = [conv, norm]else:paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)layers = [pad, conv, norm]if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return outputclass ResidualBlock(nn.Cell):def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):super(ResidualBlock, self).__init__()self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)self.dropout = dropoutif dropout:self.dropout = nn.Dropout(p=0.5)def construct(self, x):out = self.conv1(x)if self.dropout:out = self.dropout(out)out = self.conv2(out)return x + outclass ResNetGenerator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,pad_mode="CONSTANT"):super(ResNetGenerator, self).__init__()self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layersself.residuals = nn.SequentialCell(layers)self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)if pad_mode == "CONSTANT":self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',padding=3, weight_init=weight_init)else:pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)self.conv_out = nn.SequentialCell([pad, conv])def construct(self, x):x = self.conv_in(x)x = self.down_1(x)x = self.down_2(x)x = self.residuals(x)x = self.up_2(x)x = self.up_1(x)output = self.conv_out(x)return ops.tanh(output)# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

判别器

# 定义判别器
class Discriminator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):super(Discriminator, self).__init__()kernel_size = 4layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),nn.LeakyReLU(alpha)]nf_mult = output_channelfor i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output# 判别器初始化
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

损失函数和优化器

对于生成器 G G G和对应的判别器 D Y D_Y DY,目标损失函数定义为:
L G A N ( D , D Y , X , Y ) = E y − p d a t a ( y ) [ log ⁡ D Y ( y ) ] + E x − p d a t a ( x ) [ log ⁡ ( 1 − D Y ( G ( x ) ) ) ] L_{GAN}(D,D_Y,X,Y)=E_{y- p_{data}(y)}[\log{D_Y(y)}] + E_{x-p_{data}(x)}[\log(1-D_Y(G(x))) ] LGAN(D,DY,X,Y)=Eypdata(y)[logDY(y)]+Expdata(x)[log(1DY(G(x)))]
生成器的目标是最小化这个损失函数,即:
min ⁡ G max ⁡ D Y L G A N ( G , D Y , X , Y ) \min_G\max_{D_{Y}}L_{GAN}(G,D_Y,X,Y) GminDYmaxLGAN(G,DY,X,Y)

对于 X X X的每个图象 x x x,图像转换周期应该能将 x x x带回原始图像,称之为正向循环一致性,即 x → G ( x ) → F ( G ( x ) ) ≈ x x\to G(x)\to F(G(x))\approx x xG(x)F(G(x))x
循环一致损失函数定义为:
L c y c ( G , F ) = E x − p d a t a ( x ) [ ∣ ∣ F ( G ( x ) ) − x ∣ ∣ 1 ] + E y − p d a t a ( y ) [ ∣ ∣ G ( F ( y ) ) − y ∣ ∣ 1 ] L_{cyc}(G,F)=E_{x-p_{data}(x)}[||F(G(x))-x||_1]+E_{y-p_{data}(y)}[||G(F(y))-y||_1] Lcyc(G,F)=Expdata(x)[∣∣F(G(x))x1]+Eypdata(y)[∣∣G(F(y))y1]
循环一致损失能够保证重建图像 F ( G ( x ) ) F(G(x)) F(G(x))与输入图像 x x x紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")def gan_loss(predict, target):target = ops.ones_like(predict) * targetloss = loss_fn(predict, target)return loss

前向计算

import mindspore as ms# 前向计算def generator(img_a, img_b):fake_a = net_rg_b(img_b)fake_b = net_rg_a(img_a)rec_a = net_rg_b(fake_b)rec_b = net_rg_a(fake_a)identity_a = net_rg_b(img_a)identity_b = net_rg_a(img_b)return fake_a, fake_b, rec_a, rec_b, identity_a, identity_blambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5def generator_forward(img_a, img_b):true = Tensor(True, dtype.bool_)fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)loss_g_a = gan_loss(net_d_b(fake_b), true)loss_g_b = gan_loss(net_d_a(fake_a), true)loss_c_a = l1_loss(rec_a, img_a) * lambda_aloss_c_b = l1_loss(rec_b, img_b) * lambda_bloss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idtloss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_breturn fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_bdef generator_forward_grad(img_a, img_b):_, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)return loss_gdef discriminator_forward(img_a, img_b, fake_a, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)loss_d = (loss_d_a + loss_d_b) * 0.5return loss_ddef discriminator_forward_a(img_a, fake_a):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)return loss_d_adef discriminator_forward_b(img_b, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):num_imgs = 0image1 = []if isinstance(images, Tensor):images = images.asnumpy()return_images = []for image in images:if num_imgs < pool_size:num_imgs = num_imgs + 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) > 0.5:random_id = random.randint(0, pool_size - 1)tmp = image1[random_id].copy()image1[random_id] = imagereturn_images.append(tmp)else:return_images.append(image)output = Tensor(return_images, ms.float32)if output.ndim != 4:raise ValueError("img should be 4d, but get shape {}".format(output.shape))return output

梯度计算与反向传播

from mindspore import value_and_grad# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):net_d_a.set_grad(False)net_d_b.set_grad(False)fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)_, grads_g_a = grad_g_a(img_a, img_b)_, grads_g_b = grad_g_b(img_a, img_b)optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):net_d_a.set_grad(True)net_d_b.set_grad(True)loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)loss_d = (loss_d_a + loss_d_b) * 0.5optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d

总结

这一节介绍了CycleGAN网络,用来将图像风格从源域 X X X转移到目标域 Y Y Y上。CycleGAN由两个镜像对称的GAN网络组成,两个GAN网络各自有一个生成器(用来相互转移风格)和判别器。此外还引入了循环一致损失函数用来捕获生成图像和输入图像的关系。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/49871.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【办公软件】Office 2019以上版本PPT 做平滑切换

Office2019以上版本可以在切页面时做平滑切换&#xff0c;做到一些简单的动画效果。如下在快捷菜单栏中的切换里选择平滑。 比如&#xff0c;在两页PPT中&#xff0c;使用同一个形状对象&#xff0c;修改了大小和颜色。 选择切换为平滑后&#xff0c;可以完成如下的动画显示。 …

java-poi实现excel自定义注解生成数据并导出

因为项目很多地方需要使用导出数据excel的功能&#xff0c;所以开发了一个简易的统一生成导出方法。 依赖 <dependency> <groupId>org.apache.poi</groupId> <artifactId>poi-ooxml</artifactId> <version>4.0.1</version…

【TortoiseGit】合并单个commit(提交)到指定分支上

0、前言 当我们用Git的时候经常用到多个分支&#xff0c;会经常有如下情况&#xff1a;一个dev分支下面会有多个test分支&#xff0c;而每个test分支由不同的开发者。而我们会有这样的需求&#xff1a; 当某个test分支完成了相应功能验证&#xff0c;就要把成功验证的功能代码…

智能卡芯片载带条带AOI外观检测设备及系统

智能卡及其芯片载带简介 我国智能卡产业的发展始于1993年的“金卡工程”&#xff0c;它是一项把货币电子化&#xff0c;运用芯片技术来搭载电子货币应用&#xff0c;运用互联网技术建立从发行到受理的电子货币系统&#xff0c;以提高社会运作效率&#xff0c;方便人们工作生活为…

Mac m1安装 MongoDB 7.0.12

一、下载MongoDB MongoDB 社区版官网下载 二、安装配置MongoDB 1.解压下载的压缩包文件&#xff0c;文件夹重命名为mongodb; 2.将重命名为mongodb的文件夹&#xff0c;放在/usr/local 目录下 3.在/usr/local/mongodb 目录下&#xff0c;新建data 和 log这两个文件夹&#…

09-optee内核-线程处理

快速链接: . 👉👉👉 个人博客笔记导读目录(全部) 👈👈👈 付费专栏-付费课程 【购买须知】:【精选】TEE从入门到精通-[目录] 👈👈👈线程处理 OP-TEE内核使用几个线程来支持在 并行(未完全启用!有用于不同目的的处理程序。在thread.c中,您将找到一个名为…

MySQL练手 --- 1251. 平均售价

题目链接&#xff1a;1251. 平均售价 思路&#xff1a; 由题意可知&#xff0c;Prices表和UnitsSold表&#xff0c;表的连接关系为一对一&#xff0c;连接字段&#xff08;匹配字段&#xff09;为product_id 要求&#xff1a;查找每种产品的平均售价。而Prices表含有价格还有…

【简历】吉林某一本大学:JAVA秋招简历指导,简历通过率比较低

注&#xff1a;为保证用户信息安全&#xff0c;姓名和学校等信息已经进行同层次变更&#xff0c;内容部分细节也进行了部分隐藏 简历说明 这是一份吉林某一本大学25届计算机专业同学的Java简历。因为学校是一本&#xff0c;所以求职目标以中厂为主。因为学校背景在中厂是正常…

【Linux】gcc简介+编译过程

gcc是Linux系统下一款专门针对于C语言的代码编译软件。g则是Linux下针对于CPP语言的代码编译软件&#xff0c;实际上g底层也大量用了gcc代码。 目录 1.gcc基本认识与安装2.gcc编译过程2.1编译 和 链接2.2编译步骤形成的原因2.3编译器的自举2.4链接 1.gcc基本认识与安装 gcc是一…

Adobe国际认证详解-从零开始学做视频剪辑

从零开始学做视频剪辑&#xff0c;是许多初学者面临的挑战。在这个数字媒体时代&#xff0c;视频剪辑已经成为一种重要的技能&#xff0c;无论是个人爱好还是职业发展&#xff0c;掌握视频剪辑技能都是非常有价值的。 视频剪辑&#xff0c;简称“剪辑”&#xff0c;是视频制作过…

高三了,无计算机基础能学计算机吗?

高三阶段学习计算机编程是完全可行的&#xff0c;即使你没有任何计算机基础。我收集制作一份C语言学习包&#xff0c;对于新手而言简直不要太棒&#xff0c;里面包括了新手各个时期的学习方向&#xff0c;包括了编程教学&#xff0c;数据处理&#xff0c;通信处理&#xff0c;技…

《算法笔记》总结No.11——数字处理(上)欧拉筛选

机试中存在部分涉及到较复杂数字的问题&#xff0c;这是编码的基本功&#xff0c;各位一定要得心应手。 目录 一.最大公约数和最小公倍数 1.最大公约数 2.最小公倍数 二.素数 1.判断指定数 2.输出所有素数 3.精进不休——埃拉托斯特尼筛法 4.达到更优&#xff01;——…

VMWare 16 安装

1、下载地址 VMware-workstation-full-16.2.4-20089737 2、激活码 VM16&#xff1a;ZF3R0-FHED2-M80TY-8QYGC-NPKYF 3、安装步骤 修改一下【安装位置】&#xff0c;将【增强型键盘驱动程序(需要重新引导以使用此功能()此功能要求主机驱动器上具有 10MB 空间。】【将 wMware…

JUC-synchorized与锁原理、锁的升级与膨胀

syn-ed 是一个可重入、不公平的重量级锁&#xff1b;synchronized使用对象锁保证了临界区代码的原子性&#xff0c;无论使用synchorized锁的是代码块还是方法&#xff0c;其本质都是锁住一个对象。 同步代码块&#xff0c;锁住的是括号里的对象同步方法 普通方法&#xff0c;…

Adobe“加速”创意人士开启设计新篇章

近日&#xff0c;Adobe公司宣布了其行业领先的专业设计应用程序——Adobe Illustrator和Adobe Photoshop的突破性创新。这一重大更新不仅为创意专业人士带来了前所未有的设计可能性和工作效率提升&#xff0c;还让不论是插画师、设计师还是摄影师&#xff0c;都能从中受益并创作…

GO内存分配详解

文章目录 GO内存分配详解一. 物理内存(Physical Memory)和虚拟内存(Virtual Memory)二. 内存分配器三. TCMalloc线程内存(thread memory)页堆(page heap)四. Go内存分配器mspanmcachemcentralmheap五. 对象分配流程六. Go虚拟内存ArenaGO内存分配详解 这篇文章中我将抽丝剥茧,…

RK3568 Linux 平台开发系列讲解(内核入门篇):如何高效地阅读 Linux 内核设备驱动

在嵌入式 Linux 开发中,设备驱动是实现操作系统与硬件之间交互的关键。对于 RK3568 这样的平台,理解和阅读 Linux 内核中的设备驱动程序至关重要。 1. 理解内核架构 在阅读设备驱动之前,首先要了解 Linux 内核的基本架构。内核主要由以下几个部分组成: 内核核心:处理系…

【word转pdf】【最新版本jar】Java使用aspose-words实现word文档转pdf

【aspose-words-22.12-jdk17.jar】word文档转pdf 前置工作1、下载依赖2、安装依赖到本地仓库 项目1、配置pom.xml2、配置许可码文件&#xff08;不配置会有水印&#xff09;3、工具类4、效果 踩坑1、pdf乱码2、word中带有图片转换 前置工作 1、下载依赖 通过百度网盘分享的文…

Golang实现免费天气预报获取(OpenWeatherMap)

最近接到公司的一个小需求&#xff0c;需要天气数据&#xff0c;所以就做了一个小接口&#xff0c;供前端调用 这些数据包括六个元素&#xff0c;如降水、风、大气压力、云量和温度。有了这些&#xff0c;你可以分析趋势&#xff0c;知道明天的数据来预测天气。 1.1 工具简介 …

tinyxml2的入门教程

tinyxml2的入门教程 前言一、tinyxml2 创建xml 文件二、tinyxml2 添加数据三、tinyxml2 更改数据四、tinyxml2 删除数据五、tinyxml2 打印总结 前言 xml 是一种标记型文档&#xff0c;有两种基本解析方式&#xff1a;DOM(Document Object Model&#xff0c;文档对象模型)和SAX…