昇思MindSpore学习笔记4-02生成式--DCGAN生成漫画头像

摘要:

        记录了昇思MindSpore AI框架使用70171张动漫头像图片训练一个DCGAN神经网络生成式对抗网络,并用来生成漫画头像的过程、步骤。包括环境准备、下载数据集、加载数据和预处理、构造网络、模型训练等。

一、概念

深度卷积对抗生成网络DCGAN

Deep Convolutional Generative Adversarial Networks

        扩展GAN

        判别器

                组成

                        卷积层

                        BatchNorm层

                        LeakyReLU激活层

                功能

                        输入是3*64*64图像

                        输出是真图像概率

        生成器

                组成

                        转置卷积层

                        BatchNorm层

                        ReLU激活层

                功能

                        输入是标准正态分布中提取出的隐向量z

                        输出是3*64*64 RGB图像。

  • 环境准备
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

三、数据准备与处理

1.下载数据集

下载到指定目录下并解压代码如下

from download import download
url = "https://download.mindspore.cn/dataset/Faces/faces.zip"
path = download(url, "./faces", kind="zip", replace=True)

输出:

Downloading data from https://download-mindspore.osinfra.cn/dataset/Faces/faces.zip (274.6 MB)file_sizes: 100%|████████████████████████████| 288M/288M [00:52<00:00, 5.49MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./faces

2.数据集介绍

使用的动漫头像数据集共有70,171张动漫头像图片,图片大小均为96*96。

数据集目录结构如下:

./faces/faces
├── 0.jpg
├── 1.jpg
├── 2.jpg
├── 3.jpg
├── 4.jpg...
├── 70169.jpg
└── 70170.jpg

3.数据处理

(1) 执行过程参数定义:

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3            # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

(2) 数据处理和增强

create_dataset_imagenet函数

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
​
def create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)
​# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]
​# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')
​# 批量操作dataset = dataset.batch(batch_size)return dataset
​
dataset = create_dataset_imagenet('./faces')

(3) 查看训练数据

matplotlib模块

数据转换成字典迭代器

        create_dict_iterator函数

import matplotlib.pyplot as plt
​
def plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()
​
sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

四、构造网络

模型权重随机初始化

范围:mean为0,sigma为0.02的正态分布【数学不好】

1. 生成器

生成器G

        隐向量z映射数据空间

        数据源是图像

        生成与图像大小相同的 RGB 图像

        Conv2dTranspose转置卷积层

        每个层与BatchNorm2d层和ReLu激活层配对

        tanh函数

        输出[-1,1]范围内数据

DCGAN生成图像过程如下所示:

生成器结构参数:

        nz         隐向量z的长度

        ngf         有关生成器传播的特征图大小

        nc         输出图像通道数

生成器代码:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal
​
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
​
class Generator(nn.Cell):"""DCGAN网络生成器"""
​def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())
​def construct(self, x):return self.generator(x)
​
generator = Generator()

2. 判别器

判别器D

        二分类网络模型

                Conv2d

                BatchNorm2d

                LeakyReLU

                Sigmoid激活函数

                输出判定图像真实概率

判别器代码:

class Discriminator(nn.Cell):"""DCGAN网络判别器"""
​def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()
​def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)
​
discriminator = Discriminator()

五、模型训练

1. 损失函数

二进制交叉熵损失函数MindSpore.nn.BCELoss

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

2. 优化器

Adam优化器

        lr = 0.0002

        beta1 = 0.5

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

3. 训练模型

训练判别器

        提高判别图像真伪的概率

        Goodfellow方法:提高随机梯度更新判别器

        最大化logD(x)+log(1-D(G(z)))

训练生成器

        最小化log(1−D(G(z)))

        产生更好的虚拟图像

两个部分分别

        获取训练损失

        每个周期结束统计

        批量推送fixed_noise到生成器

        跟踪G的训练进度

模型训练正向逻辑:

def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))
​# 生成一批图像gen_imgs = generator(z)
​# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)
​return g_loss, gen_imgs
​
def discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_loss
​
grad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)
​
@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)
​(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)
​return g_loss, d_loss, gen_imgs

循环训练网络

        迭代50次收集生成器、判别器的损失一次

        绘制损失函数的图像

import mindspore
​
G_losses = []
D_losses = []
image_list = []
​
total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())
​# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())
​# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

输出:

[ 1/3][  1/549]   Loss_D: 0.2635  Loss_G: 4.8150
[ 1/3][101/549]   Loss_D: 0.4023  Loss_G: 4.9807
[ 1/3][201/549]   Loss_D: 0.2425  Loss_G: 1.6335
[ 1/3][301/549]   Loss_D: 0.5856  Loss_G: 0.6079
[ 1/3][401/549]   Loss_D: 0.1922  Loss_G: 4.3977
[ 1/3][501/549]   Loss_D: 0.1065  Loss_G: 2.3724
[ 1/3][549/549]   Loss_D: 0.1893  Loss_G: 1.6483
[ 2/3][  1/549]   Loss_D: 0.3370  Loss_G: 4.4347
[ 2/3][101/549]   Loss_D: 0.4681  Loss_G: 0.8623
[ 2/3][201/549]   Loss_D: 0.1856  Loss_G: 3.7501
[ 2/3][301/549]   Loss_D: 0.1932  Loss_G: 2.6333
[ 2/3][401/549]   Loss_D: 0.1310  Loss_G: 2.2524
[ 2/3][501/549]   Loss_D: 0.2531  Loss_G: 1.4690
[ 2/3][549/549]   Loss_D: 0.1192  Loss_G: 5.7166
[ 3/3][  1/549]   Loss_D: 0.0716  Loss_G: 2.9886
[ 3/3][101/549]   Loss_D: 0.1345  Loss_G: 2.6544
[ 3/3][201/549]   Loss_D: 0.1097  Loss_G: 2.8604
[ 3/3][301/549]   Loss_D: 0.2066  Loss_G: 6.1513
[ 3/3][401/549]   Loss_D: 0.0797  Loss_G: 3.2336
[ 3/3][501/549]   Loss_D: 0.2618  Loss_G: 4.0991
[ 3/3][549/549]   Loss_D: 0.5600  Loss_G:10.7509

4. 结果展示

描绘D和G损失与训练迭代的关系图:

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出:

显示隐向量fixed_noise训练生成的图像

import matplotlib.pyplot as plt
import matplotlib.animation as animation
​
def showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])
​ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)
​
showGif(image_list)

输出:

训练次数增多,图像质量越好

num_epochs达到50以上,生成动漫头像图片与数据集较为相似

加载生成器网络模型参数文件来生成图像代码:

# 从文件中获取模型参数并加载到网络中
mindspore.load_checkpoint("./generator.ckpt", generator)
​
fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
img64 = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()
​
fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/41464.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32+ESP8266(ESP32)+MQTT+阿里云物联网平台

1、阿里云物联网平台 - 阿里云物联网平台配置 产品-设备-编辑物模型-设备端开发-查看上报数据 在产品上添加物模型&#xff0c;然后设备是继承自产品的&#xff0c;因此也具有物模型 添加产品、添加设备、产品上添加物模型 - 使用IOT Studio 绘制界面显示温度、湿度、灯开…

Tomcat(+Servlet)笔记+代码

Tomcat安装和配置 安装在不含中文的路径&#xff0c;路径不能太长 Apache 官网&#x1f447; Apache Tomcat - Welcome! 配置部分 点击下图红框处&#xff0c;找到Tomcat安装位置 添加项目的文件 配好的话&#xff0c;红框这里有个猫 代码部分 新建jsp文件&#xff0c;里…

线程(基础概念)

文章目录 一、线程和进程&#xff1f;二、线程初识2.1 线程属性2.2 线程的调度策略2.3 线程的优先级2.3 线程实验 一、线程和进程&#xff1f; 我们经常描述进程&#xff08;process&#xff09;和线程&#xff08;thread&#xff09;&#xff1a; 进程是资源管理的最小单位&a…

昇思25天学习打卡营第07天 | 函数式自动微分

昇思25天学习打卡营第07天 | 函数式自动微分 文章目录 昇思25天学习打卡营第07天 | 函数式自动微分函数与计算图微分函数与梯度Stop GradientAuxiliary data 神经网络梯度计算总结打卡 神经网络的训练主要使用反向传播算法&#xff0c;首先计算模型预测值&#xff08;logits&am…

科普文:微服务之服务网格Service Mesh

一、ServiceMesh概念 背景 随着业务的发展&#xff0c;传统单体应用的问题越来越严重&#xff1a; 单体应用代码库庞大&#xff0c;不易于理解和修改持续部署困难&#xff0c;由于单体应用各组件间依赖性强&#xff0c;只要其中任何一个组件发生更改&#xff0c;将重新部署整…

MUNIK解读ISO26262--什么是DFA

我们在学习功能安全过程中&#xff0c;经常会听到很多安全分析方法&#xff0c;有我们熟知的FMEA(Failure Modes Effects Analysis)和FTA(Fault Tree Analysis)还有功能安全产品设计中几乎绕不开的FMEDA(Failure Modes Effects and Diagnostic Analysis)&#xff0c;相比于它们…

【OceanBase】OBProxy 无状态的理解

SueWakeup 个人主页&#xff1a;SueWakeup 系列专栏&#xff1a;为祖国的科技进步添砖Java 个性签名&#xff1a;保留赤子之心也许是种幸运吧 本文封面由 凯楠&#x1f4f8;友情提供 目录 前言 OBProxy 无状态的概述 OBProxy 无状态特性带来的优点 1. 高可用 2. 负载均衡…

2024最新版Redis常见面试题包含详细讲解

Redis适用于哪些场景&#xff1f; 缓存分布式锁降级限流消息队列延迟消息队 说一说缓存穿透 缓存穿透的概念 用户频繁的发起恶意请求查询缓存中和数据库中都不存在的数据&#xff0c;查询积累到一定量级导致数据库压力过大甚至宕机。 缓存穿透的原因 比如正常情况下用户发…

C++基础22 字符串与字符数组及其相关操作

这是《C算法宝典》C基础篇的第22节文章啦~ 如果你之前没有太多C基础&#xff0c;请点击&#x1f449;C基础&#xff0c;如果你C语法基础已经炉火纯青&#xff0c;则可以进阶算法&#x1f449;专栏&#xff1a;算法知识和数据结构&#x1f449;专栏&#xff1a;数据结构啦 ​ 目…

蓝牙传输技术的演进与发展

蓝牙模块技术&#xff0c;作为无线通信领域的重要一员&#xff0c;自其诞生之初便受到了广泛的关注和应用。随着技术的不断发展和演进&#xff0c;蓝牙模块技术已经从最初的单一功能、有限传输速度发展到现在的多功能、高速率、低功耗&#xff0c;为人们的生活和工作带来了极大…

信创-系统架构师认证

随着国家对信息技术自主创新的战略重视程度不断提升&#xff0c;信创产业迎来前所未有的发展机遇。未来几年内&#xff0c;信创产业将呈现市场规模扩大、技术创新加速、产业链完善和国产化替代加速的趋势。信创人才培养对于推动产业发展具有重要意义。应加强高校教育、建立人才…

NXP i.MX8系列平台开发讲解 - 3.18 Linux tty子系统介绍(一)

专栏文章目录传送门&#xff1a;返回专栏目录 Hi, 我是你们的老朋友&#xff0c;主要专注于嵌入式软件开发&#xff0c;有兴趣不要忘记点击关注【码思途远】 目录 1. TTY 起源 2. Linux 系统中的TTY 2.1 Linux TTY 设备形式 2.2 Linux TTY framework 2.3 驱动核心相关文件…

零基础入门怎么学习老挝语字母表?《老挝语翻译通》App真人发音教学,学习老挝语字母发音和词汇句子!

这段老挝文字翻译成中文是什么意思&#xff1f;有什么好用的老挝语翻译工具推荐吗&#xff1f; 快速翻译&#xff1a;中老语言无缝转换&#xff0c;实时翻译&#xff0c;让沟通更流畅。 学习工具&#xff1a;零基础入门到流利对话&#xff0c;老挝语真人发音&#xff0c;让你的…

MaxKB开源知识库问答系统发布v1.3.0版本,新增强大的工作流引擎

2024年4月12日&#xff0c;1Panel开源项目组正式发布官方开源子项目——MaxKB开源知识库问答系统&#xff08;github.com/1Panel-dev/MaxKB&#xff09;。MaxKB开源项目发布后迅速获得了社区用户的认可&#xff0c;成功登顶GitHub Trending趋势榜主榜。 截至2024年7月4日&…

docker仓库--centos7.9部署harbor详细过程与使用以及常见问题

文章目录 前言1.docker-compose是什么2.harbor是什么 centos7部署harbor详细过程与使用环境一、部署docker二、部署harbor1.下载docker-compose工具2.harbor安装3.拷贝样本文件&#xff0c;并修改文件4.安装harbor&#xff0c;安装完成自行启动5.查看 三、harbor的使用1.创建项…

Https网站如何申请免费的SSL证书及操作使用指南

前言 在当今互联网环境下&#xff0c;HTTPS已成为网站安全的标配&#xff0c;它通过SSL/TLS协议为网站数据传输提供加密&#xff0c;保障用户信息的安全。申请并部署免费SSL证书&#xff0c;不仅能够提升网站的专业形象&#xff0c;还能增强用户信任。本文将详细介绍如何在知名…

StreamSets: 数据采集工具详解

欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;欢迎订阅相关专栏&#xff1a; 欢迎关注微信公众号&#xff1a;野老杂谈 ⭐️ 全网最全IT互联网公司面试宝典&#xff1a;收集整理全网各大IT互联网公司技术、项目、HR面试真题. ⭐️ AIGC时代的创新与未来&a…

Golang语法规范和风格指南(一)——简单指南

1. 前引 一个语言的规范的学习是重要的&#xff0c;直接关系到你的代码是否易于维护和理解&#xff0c;同时学习好对应的语言规范可以在前期学习阶段有效规避该语言语法和未知编程风格的冲突。 这里是 Google 提供的规范&#xff0c;有助于大家在开始学习阶段对 Golang 进行一…

Tensorflow入门实战 T07-Vgg16网络进行咖啡豆识别

本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制 1、 前言 这周学习的主要内容是&#xff0c;使用tensorflow编写代码&#xff0c;使用vgg-16网络模型&#xff0c;完成咖啡豆的识别。 2、完整代码 imp…

【密码学基础】对随机不经意传输(Random Oblivious Transfer)的理解

ROT在offline阶段生成大量的OT对&#xff0c;在online阶段通过one-pad方式高效加密&#xff0c;并且只需要简单的异或运算就能实现OT过程&#xff08;去随机化&#xff09;。 在ROT中&#xff0c;有一个关键点是&#xff1a;需要考虑offline阶段的选择比特和online阶段的选择比…