昇思25天学习打卡营第8天|DCGAN生成漫画头像

文章目录

      • 昇思MindSpore应用实践
        • 基于MindSpore的DCGAN生成漫画头像
          • 1、DCGAN 概述
            • 零和博弈 vs 极大极小博弈
            • GAN的生成对抗损失
            • DCGAN原理
          • 2、数据预处理
          • 3、DCGAN模型构建
            • 生成器部分
            • 判别器部分
          • 4、模型训练
      • Reference

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

基于MindSpore的DCGAN生成漫画头像
1、DCGAN 概述

这部分原理介绍参考昇思官方文档GAN图像生成和昇思25天学习打卡营第5天_GAN图像生成

生成对抗网络简介:

零和博弈 vs 极大极小博弈

生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)
这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)

零和博弈(Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。

极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。

假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。

在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布;
  2. 判别器通过求取梯度和损失函数对网络进行优化,将接近真实数据分布的数据判定为1 D ( x ) = 1 D(x)=1 D(x)=1),将接近生成器生成数据分布数据判定为0(( G ( z ) = 0 G(z)=0 G(z)=0)),即希望 min ⁡ G max ⁡ D V ( G , D ) \underset{G}{\min} \underset{D}{\max}V(G, D) GminDmaxV(G,D)
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据;
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失

min ⁡ G max ⁡ D V ( G , D ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。

当生成对抗损失达到纳什均衡时,判别器对真假数据的判别概率都是0.5,即 D ( x ) = 1 − G ( z ) = 0.5 D(x)=1-G(z)=0.5 D(x)=1G(z)=0.5

l o g ( D ( x ) ) = l o g ( 1 − G ( z ) ) ≈ 0.693 log(D(x))=log(1-G(z))\approx0.693 log(D(x))=log(1G(z))0.693

由于数据x和G(z)不仅是一张图片,再分别取两者的均值 E \mathbb{E} E,相加,就得到了生成对抗损失。

近十年来著名的GAN网络结构:
在这里插入图片描述

DCGAN原理

如上图所示,DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。
不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量 z z z,输出是3x64x64的RGB图像。

本教程将使用动漫头像数据集来训练一个生成式对抗网络,接着使用该网络生成动漫头像图片。

2、数据预处理
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')# 通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

在这里插入图片描述

3、DCGAN模型构建
生成器部分

生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。在实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN生成器生成图像的大致流程如下:

1、将一个1x100的高斯潜在噪声向量投影变换为一个4x4x1024的特征图;
2、在经过CONV1卷积输出为8x8x512的特征图;
3、逐步增大分辨率,缩小通道数,经过CONV2卷积输出为16x16x256的特征图;
4、经过CONV3卷积输出为32x32x128的特征图;
5、最后经过CONV4卷积输出为64x64x3的生成图像,与真实图像一起送入判别器进行鉴定;
6、在训练过程中尽可能地生成逼近真实图像分布的效果从而欺骗判别器,令其失效,这样生成对抗就达到了平衡状态,生成器的训练过程完毕,拿去用作模型推理。
在这里插入图片描述

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()
判别器部分

在这里插入图片描述

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()
4、模型训练
# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')# 定义训练时要用到的功能函数
def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

cpu训练5个epoch的训练效果:
在这里插入图片描述
可以明显看出Loss_D和Loss_G的分数并没有达到0.5:0.5的纳什平衡状态,生成图像自然是很可怕的抽象二次元漫画头像,这里忘了截图了就不放效果了。

申请了Ascend910 NPU的算力,训练50轮效果:

910太快了啊,吃顿饭回来就跑完了,不过结果还是蚌埠住了…
在这里插入图片描述
还是很糊,练崩了,今天先到这里了,先打次卡,有时间再调整一下网络结构试试,DCGAN可能对Anime数据集来说还是太简单了,不太好控制的样子。
在这里插入图片描述
两个网络训练的log:

在这里插入图片描述

Reference

昇思大模型平台
什么是GAN生成对抗网络,使用DCGAN生成动漫头像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38894.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习基础概念

1.机器学习定义 2.机器学习工作流程 (1)数据集 ①一行数据:一个样本 ②一列数据:一个特征 ③目标值(标签值):有些数据集有目标值,有些数据集没有。因此数据类型由特征值目标值构成或…

Java实现图书管理系统

一、框架 1. 创建类 用户:管理员AdminUser 普通用户NormalUser 继承抽象类User 书:书Book 书架BookList 操作对象:书Book 2. 知识点 主要涉及的知识点:数据类型 变量 if for 数组 方法 类和对象 封装继承多态 抽象类和接口 …

Linux运维之需掌握的基本Linux命令

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 目录 一、SHELL 二、执行命令 三、常用系统工作命令 四、系统状态检测命令 五、查找定位文件命令 六、文本文件编辑命令 七、文件目录管理命令…

【JavaWeb】登录校验-会话技术(一)Cookie与Session

登录校验 实现登陆后才能访问后端系统页面,不登陆则跳转登陆页面进行登陆。 首先我们在宏观上先有一个认知: HTTP协议是无状态协议。即每一次请求都是独立的,下一次请求并不会携带上一次请求的数据。 因此当我们通过浏览器访问登录后&#…

Simulink 模型生成 C 代码(一):使用 Embedded Coder 快速向导生成代码

以matlab自带的示例模型RollAxisAutopilot为例进行讲解。RollAxisAutopilot为飞机自动驾驶控制系统模型。 使用快速向导工具生成代码 通过键入以下命令打开模型 RollAxisAutopilot: openExample(RollAxisAutopilot); 如果 C 代码选项卡尚未打开,请在 …

【C++】宏定义

严格来说,这个题目起名为C是不合适的,因为宏定义是C语言的遗留特性。CleanCode并不推荐C中使用宏定义。我当时还在公司做过宏定义为什么应该被取代的报告。但是适当使用宏定义对代码是有好处的。坏处也有一些。 无参宏定义 最常见的一种宏定义&#xf…

新声创新20年:无线技术给助听器插上“娱乐”的翅膀

听力损失并非现代人的专利,古代人也会有听力损失。助听器距今发展已经有二百多年了,从当初单纯的声音放大器到如今的全数字时代助听器,助听器发生了翻天覆地的变化,现代助听器除了助听功能,还具有看电视,听…

C++ 和C#的差别

首先把眼睛瞪大,然后憋住一口气,读下去: 1、CPP 就是C plus plus的缩写,中国大陆的程序员圈子中通常被读做"C加加",而西方的程序员通常读做"C plus plus",它是一种使用非常广泛的计算…

Maya崩溃闪退常见原因及解决方案

Autodesk Maya 是一款功能强大的 3D 计算机图形程序,被电影、游戏和建筑等各个领域的设计师广泛使用。然而,Maya 就像任何其他软件一样可能会发生崩溃问题。在前文中,小编给大家介绍了3ds Max使用V-Ray渲染时的崩溃闪退解决方案: …

后端之路第三站(Mybatis)——JDBC跟Mybatis、lombok

一、什么是JDBC JDBC就是sun公司研发的一套通过java来操控数据库的工具,对应不同的数据库系统有不同的JDBC,而他们统称【驱动】,这就是上一篇我们提到创建Mybatis项目时要引入的依赖、以及连接数据库四要素里的第一要素。 JDBC有自己一套原始…

Elasticsearch:Painless scripting 语言(一)

Painless 是一种高性能、安全的脚本语言,专为 Elasticsearch 设计。你可以使用 Painless 在 Elasticsearch 支持脚本的任何地方安全地编写内联和存储脚本。 Painless 提供众多功能,这些功能围绕以下核心原则: 安全性:确保集群的…

近红外光谱脑功能成像(fNIRS):1.光学原理、变量选取与预处理

一、朗伯-比尔定律与修正的朗伯-比尔定律 朗伯-比尔定律 是一个描述光通过溶液时被吸收的规律。想象你有一杯有色液体,比如一杯红茶。当你用一束光照射这杯液体时,光的一部分会被液体吸收,导致透过液体的光变弱。朗伯-比尔定律告诉我们&#…

redis主从复制哨兵模式集群管理

主从复制: 主从复制是高可用Redis的基础,哨兵和集群都是在主从复制基础上实现高可用的。主从复制主要实现了数据的多机备份,以及对于读操作的负载均衡和简单的故障恢复。缺陷:故障恢复无法自动化;写操作无法负载均衡&…

HbuilderX:安卓打包证书.keystore生成与使用

前置条件 已安装jdk或配置好jre环境。 .keystore生成 打开cmd,切换到目标路径,输入以下命令, keytool -genkey -alias testalias -keyalg RSA -keysize 2048 -validity 36500 -keystore test.keystore 输入密钥库口令(要记住), 然后输入一系列信息, …

ui.perfetto.dev sql 查询某个事件范围内,某个事件的耗时并降序排列

ui.perfetto.dev sql 查询某个事件范围内,某个事件的耗时并降序排列 1.打开https://ui.perfetto.dev 导入Chrome Trace Json文件2.ParallelMLP.forward下的RowParallelLinear.forward3.点击Query(SQL),在输入框中输入以下内容,按CtrlEnter,显示查询结果4.点击Show timeline,点击…

2024年07年01日 Redis数据类型以及使用场景

String Hash List Set Sorted Set String,用的最多,对象序列化成json然后存储 1.对象缓存,单值缓存 2.分布式锁 Hash,不怎么用到 1.可缓存经常需要修改值的对象,可单独对对象某个属性进行修改 HMSET user {userI…

C++基础(三):C++入门(二)

上一篇博客我们正式进入C的学习,这一篇博客我们继续学习C入门的基础内容,一定要学好入门阶段的内容,这是后续学习C的基础,方便我们后续更加容易的理解C。 目录 一、内联函数 1.0 产生的原因 1.1 概念 1.2 特性 1.3 面试题 …

用随机森林算法进行的一次故障预测

本案例将带大家使用一份开源的S.M.A.R.T.数据集和机器学习中的随机森林算法,来训练一个硬盘故障预测模型,并测试效果。 实验目标 掌握使用机器学习方法训练模型的基本流程;掌握使用pandas做数据分析的基本方法;掌握使用scikit-l…

珠江电缆,承载您梦想的每一度电

在现代社会,电力无处不在,它不仅是经济发展的动力,更是每个人生活中不可或缺的能量来源。而在这个电力驱动的世界里,有一家企业默默地承载着千家万户的梦想,它就是珠江电缆。 连接梦想的每一度电 珠江电缆成立于2001…

绝区零国际服下载 一键下载绝区零国际服教程

绝区零是一款米哈游倾情打造的全新都市幻想动作角色扮演游戏。在游戏中,我们将扮演一名绳匠,这是为出于各种原因需要进入危险空洞的人提供指引的专业人士。您将与独特的角色一起踏上冒险之旅,携手探索空洞,对战强大敌人&#xff0…