昇思25天学习打卡营第23天 | CycleGAN图像风格迁移互换

昇思25天学习打卡营第23天 | CycleGAN图像风格迁移互换

文章目录

  • 昇思25天学习打卡营第23天 | CycleGAN图像风格迁移互换
    • CycleGAN模型
      • 模型结构
      • 循环一致损失函数
    • 数据集
      • 数据下载
      • 创建数据集
    • 网络构建
      • 生成器
      • 判别器
      • 损失函数和优化器
      • 前向计算
      • 梯度计算与反向传播
    • 总结
    • 打卡

CycleGAN模型

循环对抗生成网络(Cycle Generative Adversarial Network)的一个重要应用是域迁移(Domain Adaptation),即图像风格迁移。

CycleGAN实现了一种在没有配对示例的情况下学习将图像从源域X转换到目标域Y的方法,是一种无监督的图像迁移网络。

模型结构

CycleGAN网络是由两个镜像对称的GAN网络组成:
CycleGAN
图中, G G G为将X生成Y风格的生成器, F F F为将Y生成X风格的生成器, D X D_X DX D Y D_Y DY为相应的判别器。

该模型最终输出两个模型的权重,分别将两种图像风格进行彼此迁移,生成新的图像。

循环一致损失函数

循环一致损失(Cycle Consistency Loss)的计算方法如下:

Cycle Consistency Loss
x ∈ X x\in X xX经过生成器 G G G得到 Y Y Y风格的 Y ^ \hat Y Y^,然后将 Y ^ \hat Y Y^送进生成器 F F F产生X风格的新图片 x ^ \hat x x^,最后使用 x ^ \hat x x^ x x x一起计算出循环一致损失。

数据集

实验使用ImageNet中的苹果橘子部分图片,图像缩放为 256 × 256 256\times 256 256×256大小。

数据下载

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)

创建数据集

from mindspore.dataset import MindDataset# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()

网络构建

生成器

CycleGAN Generator

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normalweight_init = Normal(sigma=0.02)class ConvNormReLU(nn.Cell):def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):super(ConvNormReLU, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if padding is None:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding, weight_init=weight_init)layers = [conv, norm]else:paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)layers = [pad, conv, norm]if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return outputclass ResidualBlock(nn.Cell):def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):super(ResidualBlock, self).__init__()self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)self.dropout = dropoutif dropout:self.dropout = nn.Dropout(p=0.5)def construct(self, x):out = self.conv1(x)if self.dropout:out = self.dropout(out)out = self.conv2(out)return x + outclass ResNetGenerator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,pad_mode="CONSTANT"):super(ResNetGenerator, self).__init__()self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layersself.residuals = nn.SequentialCell(layers)self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)if pad_mode == "CONSTANT":self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',padding=3, weight_init=weight_init)else:pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)self.conv_out = nn.SequentialCell([pad, conv])def construct(self, x):x = self.conv_in(x)x = self.down_1(x)x = self.down_2(x)x = self.residuals(x)x = self.up_2(x)x = self.up_1(x)output = self.conv_out(x)return ops.tanh(output)# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

判别器

# 定义判别器
class Discriminator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):super(Discriminator, self).__init__()kernel_size = 4layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),nn.LeakyReLU(alpha)]nf_mult = output_channelfor i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output# 判别器初始化
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

损失函数和优化器

对于生成器 G G G和对应的判别器 D Y D_Y DY,目标损失函数定义为:
L G A N ( D , D Y , X , Y ) = E y − p d a t a ( y ) [ log ⁡ D Y ( y ) ] + E x − p d a t a ( x ) [ log ⁡ ( 1 − D Y ( G ( x ) ) ) ] L_{GAN}(D,D_Y,X,Y)=E_{y- p_{data}(y)}[\log{D_Y(y)}] + E_{x-p_{data}(x)}[\log(1-D_Y(G(x))) ] LGAN(D,DY,X,Y)=Eypdata(y)[logDY(y)]+Expdata(x)[log(1DY(G(x)))]
生成器的目标是最小化这个损失函数,即:
min ⁡ G max ⁡ D Y L G A N ( G , D Y , X , Y ) \min_G\max_{D_{Y}}L_{GAN}(G,D_Y,X,Y) GminDYmaxLGAN(G,DY,X,Y)

对于 X X X的每个图象 x x x,图像转换周期应该能将 x x x带回原始图像,称之为正向循环一致性,即 x → G ( x ) → F ( G ( x ) ) ≈ x x\to G(x)\to F(G(x))\approx x xG(x)F(G(x))x
循环一致损失函数定义为:
L c y c ( G , F ) = E x − p d a t a ( x ) [ ∣ ∣ F ( G ( x ) ) − x ∣ ∣ 1 ] + E y − p d a t a ( y ) [ ∣ ∣ G ( F ( y ) ) − y ∣ ∣ 1 ] L_{cyc}(G,F)=E_{x-p_{data}(x)}[||F(G(x))-x||_1]+E_{y-p_{data}(y)}[||G(F(y))-y||_1] Lcyc(G,F)=Expdata(x)[∣∣F(G(x))x1]+Eypdata(y)[∣∣G(F(y))y1]
循环一致损失能够保证重建图像 F ( G ( x ) ) F(G(x)) F(G(x))与输入图像 x x x紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")def gan_loss(predict, target):target = ops.ones_like(predict) * targetloss = loss_fn(predict, target)return loss

前向计算

import mindspore as ms# 前向计算def generator(img_a, img_b):fake_a = net_rg_b(img_b)fake_b = net_rg_a(img_a)rec_a = net_rg_b(fake_b)rec_b = net_rg_a(fake_a)identity_a = net_rg_b(img_a)identity_b = net_rg_a(img_b)return fake_a, fake_b, rec_a, rec_b, identity_a, identity_blambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5def generator_forward(img_a, img_b):true = Tensor(True, dtype.bool_)fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)loss_g_a = gan_loss(net_d_b(fake_b), true)loss_g_b = gan_loss(net_d_a(fake_a), true)loss_c_a = l1_loss(rec_a, img_a) * lambda_aloss_c_b = l1_loss(rec_b, img_b) * lambda_bloss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idtloss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_breturn fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_bdef generator_forward_grad(img_a, img_b):_, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)return loss_gdef discriminator_forward(img_a, img_b, fake_a, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)loss_d = (loss_d_a + loss_d_b) * 0.5return loss_ddef discriminator_forward_a(img_a, fake_a):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)return loss_d_adef discriminator_forward_b(img_b, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):num_imgs = 0image1 = []if isinstance(images, Tensor):images = images.asnumpy()return_images = []for image in images:if num_imgs < pool_size:num_imgs = num_imgs + 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) > 0.5:random_id = random.randint(0, pool_size - 1)tmp = image1[random_id].copy()image1[random_id] = imagereturn_images.append(tmp)else:return_images.append(image)output = Tensor(return_images, ms.float32)if output.ndim != 4:raise ValueError("img should be 4d, but get shape {}".format(output.shape))return output

梯度计算与反向传播

from mindspore import value_and_grad# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):net_d_a.set_grad(False)net_d_b.set_grad(False)fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)_, grads_g_a = grad_g_a(img_a, img_b)_, grads_g_b = grad_g_b(img_a, img_b)optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):net_d_a.set_grad(True)net_d_b.set_grad(True)loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)loss_d = (loss_d_a + loss_d_b) * 0.5optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d

总结

这一节介绍了CycleGAN网络,用来将图像风格从源域 X X X转移到目标域 Y Y Y上。CycleGAN由两个镜像对称的GAN网络组成,两个GAN网络各自有一个生成器(用来相互转移风格)和判别器。此外还引入了循环一致损失函数用来捕获生成图像和输入图像的关系。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/49871.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【办公软件】Office 2019以上版本PPT 做平滑切换

Office2019以上版本可以在切页面时做平滑切换&#xff0c;做到一些简单的动画效果。如下在快捷菜单栏中的切换里选择平滑。 比如&#xff0c;在两页PPT中&#xff0c;使用同一个形状对象&#xff0c;修改了大小和颜色。 选择切换为平滑后&#xff0c;可以完成如下的动画显示。 …

java-poi实现excel自定义注解生成数据并导出

因为项目很多地方需要使用导出数据excel的功能&#xff0c;所以开发了一个简易的统一生成导出方法。 依赖 <dependency> <groupId>org.apache.poi</groupId> <artifactId>poi-ooxml</artifactId> <version>4.0.1</version…

数学基础 -- 分段函数的导数

分段函数的导数 分段函数的导数是指分段函数在每一段上的导数情况&#xff0c;以及在分段点上的导数情况。具体可以通过以下步骤进行计算&#xff1a; 确定分段函数的定义&#xff1a;首先明确分段函数的每一段的表达式和对应的区间。 逐段求导&#xff1a;在每个区间内&…

【TortoiseGit】合并单个commit(提交)到指定分支上

0、前言 当我们用Git的时候经常用到多个分支&#xff0c;会经常有如下情况&#xff1a;一个dev分支下面会有多个test分支&#xff0c;而每个test分支由不同的开发者。而我们会有这样的需求&#xff1a; 当某个test分支完成了相应功能验证&#xff0c;就要把成功验证的功能代码…

智能卡芯片载带条带AOI外观检测设备及系统

智能卡及其芯片载带简介 我国智能卡产业的发展始于1993年的“金卡工程”&#xff0c;它是一项把货币电子化&#xff0c;运用芯片技术来搭载电子货币应用&#xff0c;运用互联网技术建立从发行到受理的电子货币系统&#xff0c;以提高社会运作效率&#xff0c;方便人们工作生活为…

多角度解析高防CDN防御DDOS及CC攻击

网络攻击的形式也日益多样化&#xff0c;其中DDoS&#xff08;分布式拒绝服务&#xff09;和CC&#xff08;Challenge Collapsar&#xff09;攻击尤为突出&#xff0c;给网站和企业带来了巨大的安全威胁。高防CDN&#xff08;Content Delivery Network&#xff09;作为一种专业…

Mac m1安装 MongoDB 7.0.12

一、下载MongoDB MongoDB 社区版官网下载 二、安装配置MongoDB 1.解压下载的压缩包文件&#xff0c;文件夹重命名为mongodb; 2.将重命名为mongodb的文件夹&#xff0c;放在/usr/local 目录下 3.在/usr/local/mongodb 目录下&#xff0c;新建data 和 log这两个文件夹&#…

09-optee内核-线程处理

快速链接: . 👉👉👉 个人博客笔记导读目录(全部) 👈👈👈 付费专栏-付费课程 【购买须知】:【精选】TEE从入门到精通-[目录] 👈👈👈线程处理 OP-TEE内核使用几个线程来支持在 并行(未完全启用!有用于不同目的的处理程序。在thread.c中,您将找到一个名为…

Flutter开发Dart 中的 mixin、extends 和 implements

目录 ​​​​​​​前言 1.extends 2.implements 3.mixin 前言 在 Dart 中&#xff0c;mixin、extends 和 implements 是面向对象编程中常用的关键字&#xff0c;它们分别用于不同的继承和实现方式。理解它们的用法和区别对于编写高质量、可维护的 Dart 代码至关重要。本文…

使用内网穿透工具 frp 发布内网 web 站点

使用内网穿透工具 frp 发布内网 web 站点 文章目录 使用内网穿透工具 frp 发布内网 web 站点一、基础环境二、适用场景三、过程和方法四、参考资料 版权声明&#xff1a;本文为CSDN博主「杨群」的原创文章&#xff0c;遵循 CC 4.0 BY-SA版权协议&#xff0c;于2024年7月20日首…

MySQL练手 --- 1251. 平均售价

题目链接&#xff1a;1251. 平均售价 思路&#xff1a; 由题意可知&#xff0c;Prices表和UnitsSold表&#xff0c;表的连接关系为一对一&#xff0c;连接字段&#xff08;匹配字段&#xff09;为product_id 要求&#xff1a;查找每种产品的平均售价。而Prices表含有价格还有…

【简历】吉林某一本大学:JAVA秋招简历指导,简历通过率比较低

注&#xff1a;为保证用户信息安全&#xff0c;姓名和学校等信息已经进行同层次变更&#xff0c;内容部分细节也进行了部分隐藏 简历说明 这是一份吉林某一本大学25届计算机专业同学的Java简历。因为学校是一本&#xff0c;所以求职目标以中厂为主。因为学校背景在中厂是正常…

MySQL查询优化:提升数据库性能的策略

在数据库管理和应用中&#xff0c;优化查询是提高MySQL数据库性能的关键环节。随着数据量的不断增长&#xff0c;如何高效地检索和处理数据成为了一个重要的挑战。本文将介绍一系列优化MySQL查询的策略&#xff0c;帮助开发者和管理员提升数据库的性能。 案例1: 使用索引优化查…

Oracle(21)什么是聚集索引和非聚集索引?

在数据库管理系统中&#xff0c;索引是一种用于加速数据检索的结构。主要有两种类型的索引&#xff1a;聚集索引&#xff08;Clustered Index&#xff09;和非聚集索引&#xff08;Non-Clustered Index&#xff09;。它们在数据存储和检索方面有不同的实现和用途。 聚集索引&a…

自学YOLO前置知识

YOLO前置知识 学习YOLO&#xff08;You Only Look Once&#xff09;之前&#xff0c;掌握一些前置知识会帮助你更好地理解和应用该技术。以下是一些推荐的前置知识领域&#xff1a; 计算机视觉基础&#xff1a; 图像处理&#xff1a;了解图像的基本处理技术&#xff0c;如滤波…

【Linux】gcc简介+编译过程

gcc是Linux系统下一款专门针对于C语言的代码编译软件。g则是Linux下针对于CPP语言的代码编译软件&#xff0c;实际上g底层也大量用了gcc代码。 目录 1.gcc基本认识与安装2.gcc编译过程2.1编译 和 链接2.2编译步骤形成的原因2.3编译器的自举2.4链接 1.gcc基本认识与安装 gcc是一…

Neuron协议网关的北向应用插件开发

目录 概述 指令处理层开发​ 应用层开发​ .open​ .close​ .init​ .uninit​ .start​ .stop​ .setting​ .request​ 插件设置文件​ 适配华为的思路 概述 最近研究了一段时间的Neuron协议网关&#xff0c;前面的博文也提到它虽然能够把数据发到华为的IoT平台上…

Adobe国际认证详解-从零开始学做视频剪辑

从零开始学做视频剪辑&#xff0c;是许多初学者面临的挑战。在这个数字媒体时代&#xff0c;视频剪辑已经成为一种重要的技能&#xff0c;无论是个人爱好还是职业发展&#xff0c;掌握视频剪辑技能都是非常有价值的。 视频剪辑&#xff0c;简称“剪辑”&#xff0c;是视频制作过…

高三了,无计算机基础能学计算机吗?

高三阶段学习计算机编程是完全可行的&#xff0c;即使你没有任何计算机基础。我收集制作一份C语言学习包&#xff0c;对于新手而言简直不要太棒&#xff0c;里面包括了新手各个时期的学习方向&#xff0c;包括了编程教学&#xff0c;数据处理&#xff0c;通信处理&#xff0c;技…

《算法笔记》总结No.11——数字处理(上)欧拉筛选

机试中存在部分涉及到较复杂数字的问题&#xff0c;这是编码的基本功&#xff0c;各位一定要得心应手。 目录 一.最大公约数和最小公倍数 1.最大公约数 2.最小公倍数 二.素数 1.判断指定数 2.输出所有素数 3.精进不休——埃拉托斯特尼筛法 4.达到更优&#xff01;——…