昇思25天学习打卡营第19天|生成式-DCGAN生成漫画头像

打卡

目录

打卡

GAN基础原理

DCGAN原理

案例说明

数据集操作

数据准备

数据处理和增强

部分训练数据的展示

构造网络

生成器

生成器代码

​编辑

判别器

判别器代码

模型训练

训练代码

结果展示(3 epoch)

模型推理


GAN基础原理

原理介绍参考 GAN图像生成 。

DCGAN原理

DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层

判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。

生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量𝑧,输出是3x64x64的RGB图像。

Radford等人提出,论文:Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks

案例说明

目的:用动漫头像数据集来训练一个生成式对抗网络,使用该网络生成动漫头像图片。

数据集操作

数据准备

  • 使用的动漫头像数据集共有70,171张动漫头像图片,图片大小均为96*96。
  • 数据来源:https://download.mindspore.cn/dataset/Faces/faces.zip
from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)

数据处理和增强

如尺度变换、裁剪、格式变换。

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import matplotlib.pyplot as pltbatch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3            # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数def create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()dataset = create_dataset_imagenet('./faces')## 可视化部分训练数据。
sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)
部分训练数据的展示

构造网络

按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

生成器

生成器 G 的功能:将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。

通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。

生成器代码
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

  • 判别器D:一个二分类网络模型,输出判定该图像为真实图的概率。通过一系列的Conv2dBatchNorm2dLeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。
  • DCGAN论文提到,使用卷积而不是通过池化来进行下采样可以让网络学习自己的池化特征。
判别器代码
class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型训练

  • 使用MindSpore中定义的二进制交叉熵损失函数BCELoss。
  • 设置了两个单独的优化器,一个用于D,另一个用于G。这两个都是lr=0.0002 和 beta1 = 0.5 的Adam优化器。
  • 训练分为两个主要部分:训练判别器和训练生成器。
    • 训练判别器的目的是最大程度地提高判别图像真伪的概率。按照Good fellow的方法,是希望通过提高其随机梯度来更新判别器,所以我们要最大化𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1−𝐷(𝐺(𝑧)) 的值。
    • 如DCGAN论文所述,我们希望通过最小化𝑙𝑜𝑔(1−𝐷(𝐺(𝑧))) 来训练生成器,以产生更好的虚假图像。
    • 在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计,将fixed_noise批量推送到生成器中,以直观地跟踪G的训练进度。

训练代码

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []### 训练步骤
total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")## 描绘D和G损失与训练迭代的关系图
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()## 可视化训练过程中通过隐向量fixed_noise生成的图像。
import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)

训练过程图示:

结果展示(3 epoch)

1、描绘DG损失与训练迭代的关系图。

2、可视化训练过程中通过隐向量fixed_noise生成的图像。

随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当num_epochs达到50以上时,生成的动漫头像图片与数据集中的较为相似。

模型推理

通过加载生成器网络模型参数文件来生成图。

# 从文件中获取模型参数并加载到网络中
mindspore.load_checkpoint("./generator.ckpt", generator)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
img64 = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/48403.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#实战 | 天行健、上下而求索

本文介绍C#开发入门案例。 01、项目一:创建控制台应用“天行健,君子以自强不息” 项目说明: 奋斗是中华民族的底色,见山开山,遇水架桥,正是因为自强不息的奋斗,才有了辉煌灿烂的中华民族。今…

xmind--如何快速将Excel表中多列数据,复制到XMind分成多级主题

每次要将表格中的数据分成多级时,只能复制粘贴吗 快来试试这个简易的方法吧 这个是原始的表格,分成了4级 步骤: 1、我们可以先按照这个层级设置下空列(后买你会用到这个空列) 二级不用加、三级前面加一列、四级前面加…

MAT使用

概念 Shallow heap & Retained Heap Shallow Heap就是对象本身占用内存的大小。 Retained Heap就是当前对象被GC后,从Heap上总共能释放掉的内存(表示如果一个对象被释放掉,那会因为该对象的释放而减少引用进而被释放的所有的对象(包括…

【React】JSX 实现列表渲染

文章目录 一、基础语法1. 使用 map() 方法2. key 属性的使用 二、常见错误和注意事项1. 忘记使用 key 属性2. key 属性的选择 三、列表渲染的高级用法1. 渲染嵌套列表2. 条件渲染列表项3. 动态生成组件 四、最佳实践 在 React 开发中,列表渲染是一个非常常见的需求。…

【多模态】CLIP-KD: An Empirical Study of CLIP Model Distillation

论文:CLIP-KD: An Empirical Study of CLIP Model Distillation 链接:https://arxiv.org/pdf/2307.12732 CVPR 2024 Introduction Motivation:使用大的Teacher CLIP模型有监督蒸馏小CLIP模型,出发点基于在资源受限的应用中&…

【网络】tcp_socket

tcp_socket 一、tcp_server与udp_server一样的部分二、listen接口(监听)三、accept接收套接字1、为什么还要多一个套接字(明明已经有了个socket套接字文件了,为什么要多一个accept套接字文件?)2、底层拿到新…

从R-CNN到Faster-R-CNN的简单介绍

1、R-CNN RCNN算法4个步骤 1、一张图像生成1K~2K个候选区域(使用Selective Search方法) 2、对每个候选区域,使用深度网络提取特征 3、特征送入每一类的SVM分类器,判别是否属于该类 4、使用回归器精细修正候选框位置 R-CNN 缺陷 : 1.训练…

Java使用AsposePDF和AsposeWords进行表单填充

声明:本文为作者Huathy原创文章,禁止转载、爬取!否则,本人将保留追究法律责任的权力! 文章目录 AsposePDF填充表单adobe pdf表单准备引入依赖编写测试类 AsposeWord表单填充表单模板准备与生成效果引入依赖编码 参考文…

【C语言】链式队列的实现

队列基本概念 首先我们要了解什么是队列,队列里面包含什么。 队列是线性表的一种是一种先进先出(First In Fi Out)的数据结构。在需要排队的场景下有很强的应用性。有数组队列也有链式队列,数组实现的队列时间复杂度太大&#x…

【数据结构 | 哈希表】一文了解哈希表(散列表)

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 🤣本文内容🤣&a…

虚拟局域网配置与分析-VLAN

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 一、相关知识 虚拟局域网(Virtual Local Area Network,VLAN)是一组逻辑上的设备和用户;不受物理位置的…

浅谈监听器之保存响应到文件

浅谈监听器之保存响应到文件 JMeter 提供了一个实用的监听器——“保存响应到文件”,该监听器能够自动将取样器的响应数据直接保存到指定的文件中,便于后续分析或存档。本文档旨在详细介绍如何配置和使用此监听器功能。 适用场景 ● 长时间运行的测试…

昇思25天学习打卡营第n天|本地安装mindspore之二|开始第一课的代码。以及对比xshell,MobaXterm

开始准备在本地的系统上跑例子了。从第一课开始吧。 1,下载代码 打开课程。 下载样例代码 https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3/tutorials/zh_cn/beginner/mindspore_quick_start.py 2,在本地Linux上输入并运…

Python新手如何制作植物大战僵尸?这篇文章教会你!

引言 《植物大战僵尸》是一款非常受欢迎的塔防游戏,玩家需要种植各种植物来抵御僵尸的进攻。在这篇文章中,我们将使用Python编写一个简化版的植物大战僵尸游戏,以展示如何使用Python创建游戏。 游戏规则 玩家将种植不同类型的植物来防御僵尸…

好用的电脑录屏软件免费推荐,拥有这3款就能高效录屏!

电脑录屏软件已成为我们记录生活、分享知识的得力助手。但是,市面上琳琅满目的录屏软件令人眼花缭乱,如何才能选择到一款好用的电脑录屏软件免费神器呢?今天,就让我来为您揭晓这个秘密! 首先,我们得明确一…

胖东来也要加入“打水仗”?瓶装水品牌又该如何出招

今年瓶装水行业的“战场”似乎格外热闹,比武汉的天气好像还要火热......从年头农夫山泉打出“纯净水”的牌,再到如今掀起价格内卷战,一箱12瓶的纯净水在某宝平台上仅售9.9元,平均下来每瓶单价不超一元,农夫山泉都出击了…

自动化网络爬虫:如何它成为提升数据收集效率的终极武器?

摘要 本文深入探讨了自动化网络爬虫技术如何彻底改变数据收集领域的游戏规则,揭示其作为提升工作效率的终极工具的奥秘。通过分析其工作原理、优势及实际应用案例,我们向读者展示了如何利用这一强大工具加速业务决策过程,同时保持数据收集的…

5G mmWave PAAM 开发平台

Avnet-Fujikura-AMD 5G 毫米波相控阵天线模块开发平台 Avnet 和 Fujikura 为毫米波频段创建了一个领先的 5G FR2 相控阵天线开发平台。该平台使开发人员能够使用 AMD Xilinx 的 Zynq UltraScale™ RFSoC Gen3 和 Fujikura 的 FutureAcess™ 相控阵天线模块 (PAAM) 快速创建和制…

算法日记day 18(二叉树的所有路径|左叶子之和)

一、二叉树的所有路径 题目: 给你一个二叉树的根节点 root ,按 任意顺序 ,返回所有从根节点到叶子节点的路径。 叶子节点 是指没有子节点的节点。 示例 1: 输入:root [1,2,3,null,5] 输出:["1->…

抖音矩阵管理系统解决方案:一站式服务

在当今社交媒体蓬勃发展的时代,抖音作为一款短视频平台,凭借其独特的魅力和庞大的用户群体,已成为众多企业、个人乃至网红达人展示自我、推广品牌的重要舞台。然而,随着抖音账号数量的不断增加,如何高效、专业地管理这…