昇思25天学习打卡营第23天 | Pix2Pix实现图像转换

内容介绍:

Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器判别器

传统上,尽管此类任务的目标都是相同的从像素预测像素,但每项都是用单独的专用机器来处理的。而Pix2Pix使用的网络作为一个通用框架,使用相同的架构和目标,只在不同的数据上进行训练,即可得到令人满意的结果,鉴于此许多人已经使用此网络发布了他们自己的艺术作品。

cGAN的生成器与传统GAN的生成器在原理上有一些区别,cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成,这是cGAN和GAN的在图像翻译任务中的差异。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。

具体内容:

1. 导包:

from download import download
from mindspore import dataset as ds
import matplotlib.pyplot as plt
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.nn as nn
from mindspore.common import initializer as init
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
from mindspore import load_checkpoint, load_param_into_net

2. 下载数据集

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"download(url, "./dataset", kind="tar", replace=True)

3. 数据显示

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

4. 网络构建

class UNetSkipConnectionBlock(nn.Cell):def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):super(UNetSkipConnectionBlock, self).__init__()down_norm = nn.BatchNorm2d(inner_nc)up_norm = nn.BatchNorm2d(outer_nc)use_bias = Falseif norm_mode == 'instance':down_norm = nn.BatchNorm2d(inner_nc, affine=False)up_norm = nn.BatchNorm2d(outer_nc, affine=False)use_bias = Trueif in_planes is None:in_planes = outer_ncdown_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,stride=2, padding=1, has_bias=use_bias, pad_mode='pad')down_relu = nn.LeakyReLU(alpha)up_relu = nn.ReLU()if outermost:up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,kernel_size=4, stride=2,padding=1, pad_mode='pad')down = [down_conv]up = [up_relu, up_conv, nn.Tanh()]model = down + [submodule] + upelif innermost:up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,kernel_size=4, stride=2,padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv]up = [up_relu, up_conv, up_norm]model = down + upelse:up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,kernel_size=4, stride=2,padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv, down_norm]up = [up_relu, up_conv, up_norm]model = down + [submodule] + upif dropout:model.append(nn.Dropout(p=0.5))self.model = nn.SequentialCell(model)self.skip_connections = not outermostdef construct(self, x):out = self.model(x)if self.skip_connections:out = ops.concat((out, x), axis=1)return out

5. UNet生成器

class UNetGenerator(nn.Cell):def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):super(UNetGenerator, self).__init__()unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,norm_mode=norm_mode, innermost=True)for _ in range(n_layers - 5):unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode, dropout=dropout)unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,norm_mode=norm_mode)self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,outermost=True, norm_mode=norm_mode)def construct(self, x):return self.model(x)

6. PatchGAN判别器

import mindspore.nn as nnclass ConvNormRelu(nn.Cell):def __init__(self,in_planes,out_planes,kernel_size=4,stride=2,alpha=0.2,norm_mode='batch',pad_mode='CONSTANT',use_relu=True,padding=None):super(ConvNormRelu, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if not padding:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding)layers = [conv, norm]else:paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)layers = [pad, conv, norm]if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return outputclass Discriminator(nn.Cell):def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):super(Discriminator, self).__init__()kernel_size = 4layers = [nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),nn.LeakyReLU(alpha)]nf_mult = ndffor i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))self.features = nn.SequentialCell(layers)def construct(self, x, y):x_y = ops.concat((x, y), axis=1)output = self.features(x_y)return output

7. Pix2Pix生成器和判别器初始化

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)elif isinstance(cell, nn.BatchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)elif isinstance(cell, nn.BatchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))class Pix2Pix(nn.Cell):"""Pix2Pix模型网络"""def __init__(self, discriminator, generator):super(Pix2Pix, self).__init__(auto_prefix=True)self.net_discriminator = discriminatorself.net_generator = generatordef construct(self, reala):fakeb = self.net_generator(reala)return fakeb

8. 训练

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100def get_lr():lrs = [lr] * dataset_size * n_epochslr_epoch = 0for epoch in range(n_epochs_decay):lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decaylrs += [lr_epoch] * dataset_sizelrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)return Tensor(np.array(lrs).astype(np.float32))dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()def forword_dis(reala, realb):lambda_dis = 0.5fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)pred1 = net_discriminator(reala, realb)loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))loss_dis = loss_d * lambda_disreturn loss_disdef forword_gan(reala, realb):lambda_gan = 0.5lambda_l1 = 100fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)loss_1 = loss_f(pred0, ops.ones_like(pred0))loss_2 = l1_loss(fakeb, realb)loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1return loss_gand_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())def train_step(reala, realb):loss_dis, d_grads = grad_d(reala, realb)loss_gan, g_grads = grad_g(reala, realb)d_opt(d_grads)g_opt(g_grads)return loss_dis, loss_ganif not os.path.isdir(ckpt_dir):os.makedirs(ckpt_dir)g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)for epoch in range(epoch_num):for i, data in enumerate(data_loader):start_time = datetime.datetime.now()input_image = Tensor(data["input_images"])target_image = Tensor(data["target_images"])dis_loss, gen_loss = train_step(input_image, target_image)end_time = datetime.datetime.now()delta = (end_time - start_time).microsecondsif i % 2 == 0:print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))d_losses.append(dis_loss.asnumpy())g_losses.append(gen_loss.asnumpy())if (epoch + 1) == epoch_num:mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

9. 推理

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):plt.subplot(2, 10, i + 1)plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)plt.subplot(2, 10, i + 11)plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

Pix2Pix的学习过程让我深刻体会到了深度学习在图像处理领域的强大能力。通过训练一对相互竞争的神经网络——生成器与判别器,Pix2Pix能够学习到输入图像与输出图像之间复杂的映射关系。这种端到端的学习方式,无需人工设计复杂的特征提取与转换规则,极大地简化了图像转换的流程,同时也提高了转换结果的质量和多样性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870117.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python酷库之旅-第三方库Pandas(016)

目录 一、用法精讲 39、pandas.DataFrame.to_stata函数 39-1、语法 39-2、参数 39-3、功能 39-4、返回值 39-5、说明 39-6、用法 39-6-1、数据准备 39-6-2、代码示例 39-6-3、结果输出 40、pandas.read_stata函数 40-1、语法 40-2、参数 40-3、功能 40-4、返回…

nssm的下载和使用

nssm(Non-Sucking Service Manager)是一个用于在Windows系统上管理服务的工具。它允许你将.exe文件和.bat文件转换为Windows服务,并提供了一些功能来管理这些服务。 下载和安装 首先,你需要从nssm官方网站(https://n…

【ARM】MDK安装ARM_compiler5无法打开安装程序

【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 在客户安装了最新版本的MDK5.37及后续更新版本,但原工程使用ARM_Compiler_5.06进行编译和调试,需安装ARM_Compiler_5.06的编译器版本,但在解压缩的过程中后续无法打开ARM_Compiler…

解释 C 语言中的递归函数

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会! 📙C 语言百万年薪修炼课程 通俗易懂,深入浅出,匠心打磨,死磕细节,6年迭代,看过的人都说好。 文章目…

AcWing 3381:手机键盘

【题目来源】https://www.acwing.com/problem/content/3384/【题目描述】 请你计算按照手机键盘(9键输入法)输入字母的方式,键入给定字符串(由小写字母构成)所花费的时间。 具体键入规则和花费时间如下描述&#xff1a…

确保智慧校园安全,充分利用操作日志功能

智慧校园基础平台系统的操作日志功能是确保整个平台运行透明、安全及可追溯的核心组件。它自动且详尽地记录下系统内的每一次关键操作细节,涵盖操作的具体时间、执行操作的用户账号、涉及的数据对象(例如学生信息更新、课程调度变动等)、操作…

十、函数栈帧的创建和销毁

前期学习的时候我们可能会有很多困惑,如: (1)局部变量的值是随机值? (2)为什么局部变量的值是随机值? (3)函数是怎么传参的?传参的顺序是怎样的…

Burp Suite Professional 2024.6 for macOS x64 ARM64 - 领先的 Web 渗透测试软件

Burp Suite Professional 2024.6 for macOS x64 & ARM64 - 领先的 Web 渗透测试软件 世界排名第一的 Web 渗透测试工具包 请访问原文链接:https://sysin.org/blog/burp-suite-pro-mac/,查看最新版。原创作品,转载请保留出处。 作者主页…

万字学习——DCU编程实战补充

参考资料 2.1 详解DCU架构 DCU 开发与使用文档 (hpccube.com) DCU架构是什么样的 计算单元阵列,如图CU0、CU1等缓存系统(L1一级缓存,L2二级缓存)全局内存(global memory)CPU和DCU数据通路(DMA) 我的理解…

通过图像高频信息保留图像细节,能保留多少细节-Comfyui

🧨前情提要 如果还不了解comfyui中图像高频信息保留细节的内容,可以参考上一篇文章: 图像中高频信息、低频信息与ComfyUI中图像细节保留的简单研究-CSDN博客 这次主要是简单测试下保留图像细节,能保留到什么程度; …

江波龙 128G msata量产

一小主机不断重启,用DG格式化 无法完成,应该是有坏块了 找一个usb转msata转换板 查了一下是2246en aa主控 颗粒应该是三星的 缓存是现代的 找到量产工具sm22XMPToolP0219B 打开量产工具 用镊子先短接一下jp1 插入usb口,再拿走镊子 scan …

每天五分钟计算机视觉:目标检测算法之R-CNN

本文重点 在计算机视觉领域,目标检测一直是一个核心问题,旨在识别图像中的物体并定位其位置。随着深度学习技术的发展,基于卷积神经网络(CNN)的目标检测算法取得了显著的进步。其中,R-CNN(Regions with CNN features)是一种开创性的目标检测框架,为后续的研究提供了重…

微积分-导数6(隐式导数)

隐式导数 前面我们学了如何求这些方程的导数: y x 3 1 or y x sin ⁡ x y \sqrt{x^31} \quad \text{or} \quad y x\sin x yx31 ​oryxsinx 但是如果是下面的方程,又该如何求导呢? x 3 y 3 6 x y x^3 y^3 6xy x3y36xy 这个方程的图…

【Linux】进程的基本概念(已经进程地址空间的初步了解)

目录 一.什么是进程 进程和程序的区别 Linux查看进程 进程的信息 fork函数 二.进程状态 操作系统上进程状态的概念 运行 阻塞 挂起 Linux中的进程状态 R状态 S状态和D状态 T状态 t状态 X状态 Z状态 三.进程的优先级 修改进程优先级 四.环境变量 常见的环境变量 PATH HOME PW…

科普文:jvm笔记

一、JVM概述# 1. JVM内部结构# 跨语言的平台,只要遵循编译出来的字节码的规范,都可以由JVM运行 虚拟机 系统虚拟机 VMvare 程序虚拟机 JVM JVM结构 HotSpot虚拟机 详细结构图 前端编译器是编译为字节码文件 执行引擎中的JIT Compiler编译器是把字节…

关于无法定位程序输入点 SetDefaultDllDirectories于动态链接库KERNEL32.dll 上 解决方法

文章目录 1. ERNEL32.dll 下载2. 解决方法 👍 个人网站:【 洛秋小站】 1. ERNEL32.dll 下载 Windows 7 在安装postman时报错缺少动态链接库,提示缺少.NET Framework,这是因为本地缺少相应的dll文件导致的,这时就需要下载ERNEL32.dll文件,在解…

前端/python脚本/转换-使用天地图下载的geojson(echarts4+如果直接使用会导致坐标和其他信息不全)

解决echarts4如果直接使用天地图下载的geojson会导致坐标和其他信息不全 解决方法是使用python脚本来补全其他信息:center,level,adcode等内容 前提是必须有一个之前使用的json文件(需要全一点的数据供echarts使用) …

Linux文件编程应用

目录 一、实现cp命令 二、修改程序的配置文件 三、写一个整数/结构体到文件 1.写一个整数到文件 2.写一个结构体到文件 四、写结构体数组到文件 我们学习了文件编程的常用指令以及了解文件编程的基本步骤后,试着来写一些程序实现某些功能。(没有学…

记录一次微信小程序申诉定位权限过程

1 小程序接到通知,检测到违规,需要及时处理,给一周的缓冲时间,如果到期未处理,会封禁能力(2023-11-17) 2 到期后,仍未处理,封禁能力(2023-11-24) …