昇思25天学习打卡营第8天|DCGAN生成漫画头像

文章目录

      • 昇思MindSpore应用实践
        • 基于MindSpore的DCGAN生成漫画头像
          • 1、DCGAN 概述
            • 零和博弈 vs 极大极小博弈
            • GAN的生成对抗损失
            • DCGAN原理
          • 2、数据预处理
          • 3、DCGAN模型构建
            • 生成器部分
            • 判别器部分
          • 4、模型训练
      • Reference

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

基于MindSpore的DCGAN生成漫画头像
1、DCGAN 概述

这部分原理介绍参考昇思官方文档GAN图像生成和昇思25天学习打卡营第5天_GAN图像生成

生成对抗网络简介:

零和博弈 vs 极大极小博弈

生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)
这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)

零和博弈(Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。

极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。

假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。

在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布;
  2. 判别器通过求取梯度和损失函数对网络进行优化,将接近真实数据分布的数据判定为1 D ( x ) = 1 D(x)=1 D(x)=1),将接近生成器生成数据分布数据判定为0(( G ( z ) = 0 G(z)=0 G(z)=0)),即希望 min ⁡ G max ⁡ D V ( G , D ) \underset{G}{\min} \underset{D}{\max}V(G, D) GminDmaxV(G,D)
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据;
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失

min ⁡ G max ⁡ D V ( G , D ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。

当生成对抗损失达到纳什均衡时,判别器对真假数据的判别概率都是0.5,即 D ( x ) = 1 − G ( z ) = 0.5 D(x)=1-G(z)=0.5 D(x)=1G(z)=0.5

l o g ( D ( x ) ) = l o g ( 1 − G ( z ) ) ≈ 0.693 log(D(x))=log(1-G(z))\approx0.693 log(D(x))=log(1G(z))0.693

由于数据x和G(z)不仅是一张图片,再分别取两者的均值 E \mathbb{E} E,相加,就得到了生成对抗损失。

近十年来著名的GAN网络结构:
在这里插入图片描述

DCGAN原理

如上图所示,DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。
不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量 z z z,输出是3x64x64的RGB图像。

本教程将使用动漫头像数据集来训练一个生成式对抗网络,接着使用该网络生成动漫头像图片。

2、数据预处理
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')# 通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

在这里插入图片描述

3、DCGAN模型构建
生成器部分

生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。在实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN生成器生成图像的大致流程如下:

1、将一个1x100的高斯潜在噪声向量投影变换为一个4x4x1024的特征图;
2、在经过CONV1卷积输出为8x8x512的特征图;
3、逐步增大分辨率,缩小通道数,经过CONV2卷积输出为16x16x256的特征图;
4、经过CONV3卷积输出为32x32x128的特征图;
5、最后经过CONV4卷积输出为64x64x3的生成图像,与真实图像一起送入判别器进行鉴定;
6、在训练过程中尽可能地生成逼近真实图像分布的效果从而欺骗判别器,令其失效,这样生成对抗就达到了平衡状态,生成器的训练过程完毕,拿去用作模型推理。
在这里插入图片描述

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()
判别器部分

在这里插入图片描述

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()
4、模型训练
# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')# 定义训练时要用到的功能函数
def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

cpu训练5个epoch的训练效果:
在这里插入图片描述
可以明显看出Loss_D和Loss_G的分数并没有达到0.5:0.5的纳什平衡状态,生成图像自然是很可怕的抽象二次元漫画头像,这里忘了截图了就不放效果了。

申请了Ascend910 NPU的算力,训练50轮效果:

910太快了啊,吃顿饭回来就跑完了,不过结果还是蚌埠住了…
在这里插入图片描述
还是很糊,练崩了,今天先到这里了,先打次卡,有时间再调整一下网络结构试试,DCGAN可能对Anime数据集来说还是太简单了,不太好控制的样子。
在这里插入图片描述
两个网络训练的log:

在这里插入图片描述

Reference

昇思大模型平台
什么是GAN生成对抗网络,使用DCGAN生成动漫头像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38894.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习基础概念

1.机器学习定义 2.机器学习工作流程 (1)数据集 ①一行数据:一个样本 ②一列数据:一个特征 ③目标值(标签值):有些数据集有目标值,有些数据集没有。因此数据类型由特征值目标值构成或…

Java实现图书管理系统

一、框架 1. 创建类 用户:管理员AdminUser 普通用户NormalUser 继承抽象类User 书:书Book 书架BookList 操作对象:书Book 2. 知识点 主要涉及的知识点:数据类型 变量 if for 数组 方法 类和对象 封装继承多态 抽象类和接口 …

Linux运维之需掌握的基本Linux命令

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 目录 一、SHELL 二、执行命令 三、常用系统工作命令 四、系统状态检测命令 五、查找定位文件命令 六、文本文件编辑命令 七、文件目录管理命令…

【JavaWeb】登录校验-会话技术(一)Cookie与Session

登录校验 实现登陆后才能访问后端系统页面,不登陆则跳转登陆页面进行登陆。 首先我们在宏观上先有一个认知: HTTP协议是无状态协议。即每一次请求都是独立的,下一次请求并不会携带上一次请求的数据。 因此当我们通过浏览器访问登录后&#…

go语言怎么获取文件的大小并且转化为kb为单位呢?

在Go语言中,你可以使用os包中的IsExist和Stat函数来获取文件的信息,包括文件的大小。文件的大小通常是以字节为单位的,但你可以很容易地将其转换为KB(千字节)。 下面是一个简单的Go程序示例,该程序打开指定…

Simulink 模型生成 C 代码(一):使用 Embedded Coder 快速向导生成代码

以matlab自带的示例模型RollAxisAutopilot为例进行讲解。RollAxisAutopilot为飞机自动驾驶控制系统模型。 使用快速向导工具生成代码 通过键入以下命令打开模型 RollAxisAutopilot: openExample(RollAxisAutopilot); 如果 C 代码选项卡尚未打开,请在 …

【C++】宏定义

严格来说,这个题目起名为C是不合适的,因为宏定义是C语言的遗留特性。CleanCode并不推荐C中使用宏定义。我当时还在公司做过宏定义为什么应该被取代的报告。但是适当使用宏定义对代码是有好处的。坏处也有一些。 无参宏定义 最常见的一种宏定义&#xf…

makefile总结

1,Makefile规则介绍 一个简单的 Makefile 描述规则组成: TARGET... : PREREQUISITES... COMMAND 注意: 每一个命令行必须以[Tab]字符开始, [Tab]字符告诉 make 此行是一个命令行。 make 按照命令完成相应的动作。这也是书写 Makefile 中容易产生,而且比较隐蔽的错…

油烟净化器:餐饮业健康环保的守护者

我最近分析了餐饮市场的油烟净化器等产品报告,解决了餐饮业厨房油腻的难题,更加方便了在餐饮业和商业场所有需求的小伙伴们。 在现代餐饮业,油烟净化器已经成为不可或缺的重要设备。它不仅是保障餐饮环境清洁的利器,更是守护健康…

新声创新20年:无线技术给助听器插上“娱乐”的翅膀

听力损失并非现代人的专利,古代人也会有听力损失。助听器距今发展已经有二百多年了,从当初单纯的声音放大器到如今的全数字时代助听器,助听器发生了翻天覆地的变化,现代助听器除了助听功能,还具有看电视,听…

【LeetCode】368. 最大整除子集

虽然这题挺难写的,但是仍然提醒了我:解题要注意方法。在明确分析当一条道路走不通的时候,就不要再犹豫了,就要果断的换方法,尝试用其它方法解决。否则一味的消耗时间,得不偿失。换方法的前提是明确的分析&a…

C++ 和C#的差别

首先把眼睛瞪大,然后憋住一口气,读下去: 1、CPP 就是C plus plus的缩写,中国大陆的程序员圈子中通常被读做"C加加",而西方的程序员通常读做"C plus plus",它是一种使用非常广泛的计算…

Maya崩溃闪退常见原因及解决方案

Autodesk Maya 是一款功能强大的 3D 计算机图形程序,被电影、游戏和建筑等各个领域的设计师广泛使用。然而,Maya 就像任何其他软件一样可能会发生崩溃问题。在前文中,小编给大家介绍了3ds Max使用V-Ray渲染时的崩溃闪退解决方案: …

Neo4j 图数据库 高级操作

Neo4j 图数据库 高级操作 文章目录 Neo4j 图数据库 高级操作1 批量添加节点、关系1.1 直接使用 UNWIND 批量创建关系1.2 使用 CSV 文件批量创建关系1.3 选择方法 2 索引2.1 创建单一属性索引2.2 创建组合属性索引2.3 创建全文索引2.4 列出所有索引2.5 删除索引2.6 注意事项 3 清…

后端之路第三站(Mybatis)——JDBC跟Mybatis、lombok

一、什么是JDBC JDBC就是sun公司研发的一套通过java来操控数据库的工具,对应不同的数据库系统有不同的JDBC,而他们统称【驱动】,这就是上一篇我们提到创建Mybatis项目时要引入的依赖、以及连接数据库四要素里的第一要素。 JDBC有自己一套原始…

SerialportToTCP② 全

效果补全(代码): namespace SerialportToTCP {public partial class Form1 : Form{IniHelper Ini;string[] botelvs new string[] { "1200", "4800", "9600", "13200" };public Form1(){Initializ…

Elasticsearch:Painless scripting 语言(一)

Painless 是一种高性能、安全的脚本语言,专为 Elasticsearch 设计。你可以使用 Painless 在 Elasticsearch 支持脚本的任何地方安全地编写内联和存储脚本。 Painless 提供众多功能,这些功能围绕以下核心原则: 安全性:确保集群的…

安卓gdb 建立链接

adbshell gdbserver :1234 testdcam --sensor 0 --workmode 0 --args preview-size1024x600,picture-size640x480, --time 10 adb forwardtcp:1234 tcp:1234 //设置adb的转发 ./prebuilts/gcc/linux-x86/arm/arm-linux-androideabi-4.7/bin/arm-linux-androideabi-gdb out/tar…

近红外光谱脑功能成像(fNIRS):1.光学原理、变量选取与预处理

一、朗伯-比尔定律与修正的朗伯-比尔定律 朗伯-比尔定律 是一个描述光通过溶液时被吸收的规律。想象你有一杯有色液体,比如一杯红茶。当你用一束光照射这杯液体时,光的一部分会被液体吸收,导致透过液体的光变弱。朗伯-比尔定律告诉我们&#…

mmdetection3D指定版本安装指南

1. 下载指定版本号 选择指定版本号下载mmdetection3d的源码,如这里选择的是0.17.2版本 git clone https://github.com/open-mmlab/mmdetection3d.git -b v0.17.22. 安装 cd mmdetection3d安装依赖库 pip install -r requirment.txt编译安装 pip install -v e .…