昇思25天学习打卡营第16天 | DCGAN生成漫画头像

这两天把minspore配置到我的电脑上了,然后运行就没什么问题了✨😊

今天学这个DCGAN生成漫画头像,我超级感兴趣的嘞🦄🥰

GAN基础原理

这部分原理介绍参考GAN图像生成。

DCGAN原理

DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量𝑧𝑧,输出是3x64x64的RGB图像。

本教程将使用动漫头像数据集来训练一个生成式对抗网络,接着使用该网络生成动漫头像图片。

数据准备与处理

from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)

数据处理

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3           # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

构造网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

生成器

生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。在实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN论文生成图像如下所示:

dcgangenerator

我们通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。

以下是生成器的代码实现:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

如前所述,判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。通过一系列的Conv2dBatchNorm2dLeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。

DCGAN论文提到,使用卷积而不是通过池化来进行下采样是一个好方法,因为它可以让网络学习自己的池化特征。

判别器的代码实现如下:

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型训练

损失函数

当定义了DG后,接下来将使用MindSpore中定义的二进制交叉熵损失函数BCELoss。

优化器

这里设置了两个单独的优化器,一个用于D,另一个用于G。这两个都是lr = 0.0002beta1 = 0.5的Adam优化器。

训练模型

训练分为两个主要部分:训练判别器和训练生成器。

  • 训练判别器

    训练判别器的目的是最大程度地提高判别图像真伪的概率。按照Goodfellow的方法,是希望通过提高其随机梯度来更新判别器,所以我们要最大化𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1−𝐷(𝐺(𝑧))𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1−𝐷(𝐺(𝑧))的值。

  • 训练生成器

    如DCGAN论文所述,我们希望通过最小化𝑙𝑜𝑔(1−𝐷(𝐺(𝑧)))𝑙𝑜𝑔(1−𝐷(𝐺(𝑧)))来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计,将fixed_noise批量推送到生成器中,以直观地跟踪G的训练进度。

下面实现模型训练正向逻辑:

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/41133.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

秒验—手机号码置换接口

功能说明 提交客户端获取到的token、opToken等数据,验证后返回手机号码 服务端务必不要缓存DNS,否则可能影响服务高可用性 调用地址 POST https://identify-verify.dutils.com/auth/auth/sdkClientFreeLogin 请求头 Content-Type :appli…

图书商城系统java项目ssm项目jsp项目java课程设计java毕业设计

文章目录 图书商城系统一、项目演示二、项目介绍三、部分功能截图四、部分代码展示五、底部获取项目源码(9.9¥带走) 图书商城系统 一、项目演示 图书商城系统 二、项目介绍 语言: Java 数据库:MySQL 技术栈:SpringS…

SaaS行业的AI化征程:穿越“大模型焦虑”,拥抱“AI自信”

随着大模型技术的风起云涌,SaaS行业正站在一个充满机遇与挑战的十字路口。本文旨在深入剖析SaaS厂商在AI化升级过程中所遭遇的“大模型焦虑”,并探索通过战略性的AI应用策略,如何重拾信心,实现产品与服务的华丽转身,为…

关于虚拟机上不了网的解决办法

先ping出ip地址 或者查询ifconfig得到目前网络信息 继续输入命令Ifconfig -a查询是否能找到ip地址 明显ens33是没有打开的,所以找不到分配的ip地址,需要打开,自动随机分配ip 输入命令: sudo dhclient ens33 现在就可以开始上网…

公司“领导”们竟如此讨论工作!小伙:此事有蹊跷;|国家漏洞库CNNVD:关于OpenSSH安全漏洞的通报;

公司“领导”们竟如此讨论工作!小伙:此事有蹊跷 “当时我正在等验证码 还好你们快了一步 不然公司的93万余元就没了” 一谈到这件事 杜先生仍然心有余悸 近日 正在处理公司财务工作的杜先生 突然被拉进了一个QQ群聊 从头像、昵称上看 群聊里的竟…

累积分布函数的一些性质证明

性质1: E [ X ] ∫ 0 ∞ ( 1 − F ( x ) ) d x − ∫ − ∞ 0 F ( x ) d x ( 1 ) E[X]\int_0^{\infty}(1-F(x))dx - \int_{-\infty}^0F(x)dx\quad (1) E[X]∫0∞​(1−F(x))dx−∫−∞0​F(x)dx(1) 证明: E [ X ] ∫ − ∞ ∞ x p ( x ) d x E[X] …

SpringBoot | 大新闻项目后端(redis优化登录)

该项目的前篇内容的使用jwt令牌实现登录认证,使用Md5加密实现注册,在上一篇:http://t.csdnimg.cn/vn3rB 该篇主要内容:redis优化登录和ThreadLocal提供线程局部变量,以及该大新闻项目的主要代码。 redis优化登录 其实…

macOS版ChatGPT更新:修复AI对话纯文本存储问题

猫头虎 🐯 建联猫头虎,商务合作,产品评测,产品推广,个人自媒体创作,超级个体,涨粉秘籍,一起探索编程世界的无限可能! macOS版ChatGPT更新:修复AI对话纯文本…

JAVA高级进阶11多线程

第十一天、多线程 线程安全问题 线程安全问题 多线程给我们带来了很大性能上的提升,但是也可能引发线程安全问题 线程安全问题指的是当个多线程同时操作同一个共享资源的时候,可能会出现的操作结果不符预期问题 线程同步方案 认识线程同步 线程同步 线程同步就是让多个线…

内网渗透学习-杀入内网

1、靶机上线cs 我们已经拿到了win7的shell,执行whoami,发现win7是administrator权限,且在域中 执行ipconfig发现了win7存在内网网段192.168.52.0/24 kali开启cs服务端 客户端启动cs 先在cs中创建一个监听器 接着用cs生成后门,记…

Mysql 的第二次作业

一、数据库 1、登陆数据库 2、创建数据库zoo 3、修改数据库zoo字符集为gbk 4、选择当前数据库为zoo 5、查看创建数据库zoo信息 6、删除数据库zoo 1)登陆数据库。 打开命令行,输入登陆用户名和密码。 mysql -uroot -p123456 ​ 2)切换数据库…

利用pg_rman进行备份与恢复操作

文章目录 pg_rman简介一、安装配置pg_rman二、创建表与用户三、备份与恢复 pg_rman简介 pg_rman 是 PostgreSQL 的在线备份和恢复工具。类似oracle 的 rman pg_rman 项目的目标是提供一种与 pg_dump 一样简单的在线备份和 PITR 方法。此外,它还为每个数据库集群维护…

Day05-01-jenkins进阶

Day05-01-jenkins进阶 10. 案例07: 理解 案例06基于ans实现10.1 整体流程10.2 把shell改为Ansible剧本10.3 jk调用ansible全流程10.4 书写剧本 11. Jenkins进阶11.1 jenkins分布式1)概述2)案例08:拆分docker功能3)创建任务并绑定到…

【刷题笔记(编程题)05】另类加法、走方格的方案数、井字棋、密码强度等级

1. 另类加法 给定两个int A和B。编写一个函数返回AB的值,但不得使用或其他算数运算符。 测试样例: 1,2 返回:3 示例 1 输入 输出 思路1: 二进制0101和1101的相加 0 1 0 1 1 1 0 1 其实就是 不带进位的结果1000 和进位产生的1010相加 无进位加…

ssm校园志愿服务信息系统-计算机毕业设计源码97697

摘 要 随着社会的进步和信息技术的发展,越来越多的学校开始重视志愿服务工作,通过组织各种志愿服务活动,让学生更好地了解社会、服务社会。然而,在实际操作中,志愿服务的组织和管理面临着诸多问题,如志愿者…

dledger原理源码分析系列(一)-架构,核心组件和rpc组件

简介 dledger是openmessaging的一个组件, raft算法实现,用于分布式日志,本系列分析dledger如何实现raft概念,以及dledger在rocketmq的应用 本系列使用dledger v0.40 本文分析dledger的架构,核心组件;rpc组…

【pytorch16】MLP反向传播

链式法则回顾 多输出感知机的推导公式回顾 只与w相关的输出节点和输入节点有关 多层多输入感知机 扩展为多层感知机的话,意味着还有一些层(理解为隐藏层σ函数),暂且设置为 x j x_{j} xj​层 对于 x j x_{j} xj​层如果把前面的…

迅捷PDF编辑器合并PDF

迅捷PDF编辑器是一款专业的PDF编辑软件,不仅支持任意添加文本,而且可以任意编辑PDF原有内容,软件上方的工具栏中还有丰富的PDF标注、编辑功能,包括高亮、删除线、下划线这些基础的,还有规则或不规则框选、箭头、便利贴…

【护眼小知识】护眼台灯真的护眼吗?防近视台灯有效果吗?

当前,近视问题在人群中愈发普遍,据2024年的统计数据显示,我国儿童青少年的总体近视率已高达52.7%。并且近视背后潜藏着诸多眼部并发症的风险,例如视网膜脱离、白内障以及开角型青光眼等,严重的情况甚至可能引发失明。为…

PMP--知识卡片--波士顿矩阵

文章目录 记忆黑话概念作用图示 记忆 一说到波士顿就联想到波士顿龙虾,所以波士顿矩阵跟动物有关,狗,牛。 黑话 你公司的现金牛业务,正在逐渐变成瘦狗,应尽快采取收割策略;问题业务的储备太少&#xff0…