昇思25天学习打卡营第21天|DCGAN生成漫画头像

DCGAN原理

DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量 z z z,输出是3x64x64的RGB图像。

数据准备与处理

from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)

下载后的数据集目录结构如下:

./faces/faces
├── 0.jpg
├── 1.jpg
├── 2.jpg
├── 3.jpg
├── 4.jpg...
├── 70169.jpg
└── 70170.jpg

数据处理

首先为执行过程定义一些输入:

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3           # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

构造网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

生成器

生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。在实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN论文生成图像如下所示:

dcgangenerator

图片来源:Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.

我们通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。

以下是生成器的代码实现:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

如前所述,判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。通过一系列的Conv2dBatchNorm2dLeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。

DCGAN论文提到,使用卷积而不是通过池化来进行下采样是一个好方法,因为它可以让网络学习自己的池化特征。

判别器的代码实现如下:

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型训练

损失函数

当定义了DG后,接下来将使用MindSpore中定义的二进制交叉熵损失函数BCELoss。

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

优化器

这里设置了两个单独的优化器,一个用于D,另一个用于G。这两个都是lr = 0.0002beta1 = 0.5的Adam优化器。

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

训练模型

训练分为两个主要部分:训练判别器和训练生成器。

  • 训练判别器

    训练判别器的目的是最大程度地提高判别图像真伪的概率。按照Goodfellow的方法,是希望通过提高其随机梯度来更新判别器,所以我们要最大化 l o g D ( x ) + l o g ( 1 − D ( G ( z ) ) log D(x) + log(1 - D(G(z)) logD(x)+log(1D(G(z))的值。

  • 训练生成器

    如DCGAN论文所述,我们希望通过最小化 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z)))来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计,将fixed_noise批量推送到生成器中,以直观地跟踪G的训练进度。

下面实现模型训练正向逻辑:

def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs

循环训练网络,每经过50次迭代,就收集生成器和判别器的损失,以便于后面绘制训练过程中损失函数的图像。

import mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

结果展示

运行下面代码,描绘DG损失与训练迭代的关系图:

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

可视化训练过程中通过隐向量fixed_noise生成的图像。

import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)

dcgan

从上面的图像可以看出,随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当num_epochs达到50以上时,生成的动漫头像图片与数据集中的较为相似,下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

# 从文件中获取模型参数并加载到网络中
mindspore.load_checkpoint("./generator.ckpt", generator)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
img64 = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/46321.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构历年考研真题对应知识点(哈夫曼树和哈夫曼编码)

目录 5.5.1哈夫曼树和哈夫曼编码 1.哈夫曼树的定义 2.哈夫曼树的构造 【分析哈夫曼树的路径上权值序列的合法性(2010)】 【哈夫曼树的性质(2010、2019)】 3.哈夫曼编码 【根据哈夫曼编码对编码序列进行译码(201…

【AMBA】AXI总线中的AXLEN、AXSIZE、AXBURST和4K边界

文章目录 AXI传输层级概念AXLEN[7:0]定义突发传输长度AXSIZE[2:0]定义突发传输transfer的位宽AXBURST[1:0]定义突发传输类型4K边界 AXI传输层级概念 在手册的术语表中,与 AXI 传输相关的有三个概念,分别是 transfer(beat)、burst、transaction。用一句话…

树莓派关机

文件 shutdown.sh #!/usr/bin/bash sudo shutdown -r nowpython 文件开头添加 #!/usr/bin/python3

C字符串和内存函数介绍(三)——其他的字符串函数

在#include<string.h>的这个头文件里面&#xff0c;除了前面给大家介绍的两大类——长度固定的字符串函数和长度不固定的字符串函数。还有一些函数以其独特的用途占据一席之地。 今天要给大家介绍的是下面这三个字符串函数&#xff1a;strstr&#xff0c;strtok&#xf…

前端web性能统计

前端web性能统计 1. 背景2. 业界方案2.1 腾讯2.2 蚂蚁金服2.3 字节跳动2.4 美团 3. 相关观念3.1 RAIL模型3.2 性能指标3.3 真实用户监控3.4 performance 4. 性能监控工具介绍5. 推荐采用方案 1. 背景 在如今的数字时代&#xff0c;网站和应用程序的性能对用户体验至关重要。用…

STM32MP135裸机编程:唯一ID(UID)、设备标识号、设备版本

0 资料准备 1.STM32MP13xx参考手册1 唯一ID&#xff08;UID&#xff09;、设备标识号、设备版本 1.1 寄存器说明 &#xff08;1&#xff09;唯一ID 唯一ID可以用于生成USB序列号或者为其它应用所使用&#xff08;例如程序加密&#xff09;。 &#xff08;2&#xff09;设备…

Git代码管理工具 — 3 Git基本操作指令详解

目录 1 获取本地仓库 2 基础操作指令 2.1 基础操作指令框架 2.2 git status查看修改的状态 2.3 git add添加工作区到暂存区 2.4 提交暂存区到本地仓库 2.5 git log查看提交日志 2.6 git reflog查看已经删除的记录 2.7 git reset版本回退 2.8 添加文件至忽略列表 1 获…

0.单片机工作原理

文章目录 最小系统 单片机芯片 时钟电路 复位电路 电源 最小系统 单片机芯片 本次51单片机的芯片为&#xff1a;STC89C52 Flash(闪存)程序存储器&#xff1a;存储程序的空间 SRAM&#xff1a;数据存储器&#xff0c;可用于存放程序执行的中间结果和过程数据 DPTR&#xff1a;…

2024-07-14 Unity插件 Odin Inspector2 —— Essential Attributes

文章目录 1 说明2 重要特性2.1 AssetsOnly / SceneObjectsOnly2.2 CustomValueDrawer2.3 OnValueChanged2.4 DetailedInfoBox2.5 EnableGUI2.6 GUIColor2.7 HideLabel2.8 PropertyOrder2.9 PropertySpace2.10 ReadOnly2.11 Required2.12 RequiredIn&#xff08;*&#xff09;2.…

机器人相关工科专业课程体系

机器人相关工科专业课程体系 前言传统工科专业机械工程自动化/控制工程计算机科学与技术 新兴工科专业智能制造人工智能机器人工程 总结Reference: 前言 机器人工程专业是一个多领域交叉的前沿学科&#xff0c;涉及自然科学、工程技术、社会科学、人文科学等相关学科的理论、方…

Linux编程(三)—makefile快速编译

起因 linux环境下&#xff0c;编译c程序很麻烦&#xff0c;后面g -o demo demo.cpp ……往往跟了许多许多东西&#xff0c;这些每次编译的时候都要书写&#xff0c;所以就产生了makefile快速编译方式&#xff0c;具体操作如下。 怎么用makefile? 第一步&#xff1a;下载 m…

WPF学习(2) -- 样式基础

一、代码 <Window x:Class"学习.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expression/blend/2008&…

MySQL 聚簇索引和非聚簇索引有什么区别?

聚簇索引&#xff08;主键索引&#xff09;、非聚簇索引&#xff08;二级索引&#xff09;。 这两者之间的最主要的区别是 B 树的叶子节点存放的内容不同&#xff1a; 聚簇索引的 B 树叶子节点存放的是主键值完整的记录&#xff1b;非聚簇索引的 B 树叶子节点存放的是索引值主…

【面试八股总结】C++内存管理:内存分区、内存泄漏、new和delete、malloc和free

参考资料&#xff1a;代码随想录、阿秀 一、内存分区 &#xff08;1&#xff09;栈区 在执行函数时&#xff0c;函数内部局部变量的存储单元都可以在栈上创建&#xff0c;函数执行结束时这些存储单元自动被释放。栈内存分配运算内置于处理器的指令集中&#xff0c;效率很高&am…

Postman下载及使用说明

Postman使用说明 Postman是什么&#xff1f; ​ Postman是一款接口对接工具【接口测试工具】 接口&#xff08;前端接口&#xff09;是什么&#xff1f; ​ 前端发送的请求普遍被称为接口 ​ 通常有网页的uri参数格式json/key-value请求方式post/get响应请求的格式json 接…

关闭Ubuntu烦人的apport

先来看让人绷不住的&#xff08;恼&#xff09; 我查半天apport是啥玩意发现就一错误报告弹窗&#xff0c;十秒钟给我弹一次一天给我内存弹爆了 就算我程序就算真的不停崩溃&#xff0c;也没你这傻比apport杀伤性强啊&#xff1f;&#xff1f;&#xff1f; 原则上是不建议关闭…

牛客周赛 Round 51 解题报告 | 珂学家

前言 题解 典题场&#xff0c; EF都有很多种解法 A. 小红的同余 性质: 相邻两数互质 x ( m 1 ) / 2 x (m1)/2 x(m1)/2 m int(input())print ((m 1) // 2)B. 小红的三倍数 性质: 各个位数之和是3的倍数&#xff0c;可被3整除 和数的组合顺序无关 n int(input()) arr…

MySQL高级面试点

Explain语句结果中各个字段分别代表什么 id&#xff1a;查询语句没出现一个select关键字&#xff0c;MySQL就会给他分配一个唯一id select_type&#xff1a; select关键字对应哪个查询的类型 simple&#xff1a;简单的查询 不包含任何子查询 primary&#xff1a;查询中如果…

网络安全设备——EDR

网络安全中的EDR&#xff08;Endpoint Detection and Response&#xff0c;端点检测与响应&#xff09;是一种主动式的端点安全解决方案&#xff0c;它专注于监控、检测和响应计算机和终端设备上的安全威胁。以下是EDR的详细解释&#xff1a; 一、定义与功能 EDR是一种网络安…

repo sync同步出错解决

当出现下面提示时 e list of known hosts. Fetching: 100% (1167/1167), done in 44.619s info: A new version of repo is available warning: repo is not tracking a remote branch, so it will not receive updates Repo command failed: RepoUnhandledExceptionError …