昇思25天学习打卡营第22天 | DCGAN生成漫画头像

昇思25天学习打卡营第22天 | DCGAN生成漫画头像

文章目录

  • 昇思25天学习打卡营第22天 | DCGAN生成漫画头像
    • DCGAN模型
    • 数据集
      • 数据下载和超参数
      • 创建数据集
      • 数据集可视化
    • 搭建网络
      • 生成器
      • 判别器
      • 损失函数和优化器
    • 模型训练
    • 总结
    • 打卡

DCGAN模型

深度卷积对抗生成网络(Depp Convolutional Generative Adversarial Networks, DCGAN)是GAN的直接拓展。区别在于DCGAN使用卷积和反卷积。

  • 判别器:由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是 3 × 64 × 64 3\times 64\times 64 3×64×64的图像,输出是该图像为真图像的概率。
  • 生成器:由反卷积层、BatchNorm层和ReLU激活层组成,输入是标准正态分布中提取出的隐向量 z z z,输出是 3 × 64 × 64 3\times 64\times 64 3×64×64的RGB图像。

数据集

实验使用动漫头像数据集,共有70,171张动漫头像图片,大小均为 96 × 96 96\times 96 96×96

数据下载和超参数

from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3           # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

创建数据集

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')

数据集可视化

import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

在这里插入图片描述

搭建网络

生成器

生成器 G G G是将隐向量 z z z映射到数据空间,由一系列Conv2dTransposeBatchNorm2dReLU构成,输出数据经过tanh函数,使得返回 [ − 1 , 1 ] [-1,1] [1,1]范围的数据。

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

判别器 D D D是一个二分类网络,由Conv2dBatchNorm2dLeakyReLU构成,最后通过Sigmoid激活函数得到最终概率。

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

损失函数和优化器

使用二进制交叉熵损失函数BCELossAdam优化器:

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

模型训练

  • 判别器:最大化 log ⁡ D ( x ) + log ⁡ ( 1 − D ( G ( z ) ) \log D(x)+\log(1-D(G(z)) logD(x)+log(1D(G(z))
  • 生成器:最小化 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1-D(G(z))) log(1D(G(z)))
def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

总结

这一节介绍了深度卷积生成对抗网络DCGAN,相对于经典的GAN网络来说,将生成器中的全连接层换成了反卷积层,而将判别器中的全连接层换成了卷积层,其训练过程和GAN网络基本一样。通过在70171张动漫头像上进行训练,使得该对抗网络能够生成动漫头像图片。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/49110.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python打包exe文件-实现记录

1、使用pyinstaller库 安装库: pip install pyinstaller打包命令标注主入库程序: pyinstaller -F.\程序入口文件.py 出现了一个问题就是我在打包运行之后会出现有一些插件没有被打包。 解决问题: 通过添加--hidden-importcomtypes.strea…

GeoHash原理介绍以及在redis中的应用

GeoHash将二维信息编码成了一个一维信息。降维后有三个好处: 编码后数据长度变短,利于节省存储。利于使用前缀检索当分割的足够细致,能够快速的对双方距离进行快速查询 GeoHash是一种地址编码方法。他能够把二维的空间经纬度数据编码成一个字符串。 1…

react开发-配置开发时候@指向SRC目录

这里写目录标题 配置开发时候指向SRC目录VScode编辑器给出提示总体1.配置react的 2.配置Vscode的1.配置react的2,配置VSCode的提示支持 配置开发时候指向SRC目录VScode编辑器给出提示 总体1.配置react的 2.配置Vscode的 1.配置react的 1. 我么需要下载一个webpack的插件 这样…

判断推理1

判断推理 1.定义判断 2.类比推理 3.逻辑判断 4.图形推理 加粗文本 加粗文本

map/multimap容器及STL案例

1.map概念:map中所有元素都是pair pair中的第一个元素为key(键值)起到索引作用,第二个为value(实值) 所有元素都会根据key值自动排序 本质:map/multimap属于关联式容器,底层结构是…

C语言 | Leetcode C语言题解之第257题二叉树的所有路径

题目: 题解: char** binaryTreePaths(struct TreeNode* root, int* returnSize) {char** paths (char**)malloc(sizeof(char*) * 1001);*returnSize 0;if (root NULL) {return paths;}struct TreeNode** node_queue (struct TreeNode**)malloc(size…

vue3使用html2canvas

安装 yarn add html2canvas 代码 <template><div class"container" ref"container"><div class"left"><img :src"logo" alt"" class"logo"><h2>Contractors pass/承包商通行证&l…

Mamba-yolo|结合Mamba注意力机制的视觉检测

一、本文介绍 PDF地址&#xff1a;https://arxiv.org/pdf/2405.16605v1 代码地址&#xff1a;GitHub - LeapLabTHU/MLLA: Official repository of MLLA Demystify Mamba in Vision: A Linear AttentionPerspective一文中引入Baseline Mamba&#xff0c;指明Mamba在处理各种高…

网络通讯实验报告

拓扑图 需求 1、通过DHCP服务&#xff0c;给PC4和PC5分配IP地址、网关、掩码、DNS服务器IP地址 2、Client-1要求手工配置IP地址&#xff0c;为192.168.1.1, c 3、telnet客户端可以远程登录telnet服务器进行设备管理&#xff0c;并成功修改telnet服务器的名字为123 &#xff0c…

操作系统——进程与线程(死锁)

1&#xff09;为什么会产生死锁&#xff1f;产生死锁有什么条件&#xff1f; 2&#xff09;有什么办法解决死锁&#xff1f; 一、死锁 死锁:多个程序因竞争资源而造成的一种僵局&#xff08;互相等待对方手里的资源&#xff09;&#xff0c;使得各个进程都被阻塞&#xff0c;…

一篇文章搞懂MySQL的事务与隔离级别

事务 概述 一个事务其实就是一个完整的业务逻辑&#xff0c;是一个最小的工作单元。要么同时成功&#xff0c;要么同时失败&#xff0c;不可再分 假设转账&#xff0c;从A账户向B账户转账10000 A账户的钱减去10000&#xff08;update语句&#xff09; B账户的钱加上10000&…

【HarmonyOS学习】用户文件访问

概述 文件所有者为登录到该终端设备的用户&#xff0c;包括用户私有的图片、视频、音频、文档等。 应用对用户文件的创建、访问、删除等行为&#xff0c;需要提前获取用户授权&#xff0c;或由用户操作完成。 用户文件访问框架 是一套提供给开发者访问和管理用户文件的基础框…

无需抠图!AI绘画直接文本生成透明底图层,设计师必看的ComfyUI透明图层生成工作流教程!(附插件模型)

大家好&#xff0c;我是画画的小强 AI 绘画自出现以来一直都在不断发展完善&#xff0c;实现了很多我们在实际应用中迫切需要的功能&#xff0c;比如生成正确的手指、指定的姿势、准确的文本内容等。上周&#xff0c;又一个重磅新功能在开源的 SD 生态内实现了——直接通过文本…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第三篇 嵌入式Linux驱动开发篇-第四十四章 注册字符设备号

i.MX8MM处理器采用了先进的14LPCFinFET工艺&#xff0c;提供更快的速度和更高的电源效率;四核Cortex-A53&#xff0c;单核Cortex-M4&#xff0c;多达五个内核 &#xff0c;主频高达1.8GHz&#xff0c;2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

Springboot项目打包成镜像、使用docker-compose启动

Springboot项目打包成镜像、使用docker-compose启动 1、创建一个boot项目 1、添加依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSch…

gitee的怎么上传项目

前提 1.先下载Git Bash (如果没有下载的宝子们下载连接如下: 链接: link ) 项目上传到Gitee步骤 1.在Gitee上建立远程仓库 2.填写相关信息 3.进入本地你想要上传的文件目录下&#xff0c;右键单击空白处&#xff0c;点击Git Bash Here 4.配置你的用户名和邮箱 git con…

【leetcode】排列序列

给出集合 [1,2,3,...,n]&#xff0c;其所有元素共有 n! 种排列。 按大小顺序列出所有排列情况&#xff0c;并一一标记&#xff0c;当 n 3 时, 所有排列如下&#xff1a; "123""132""213""231""312""321" 给定…

最简单的typora+gitee+picgo配置图床

typoragiteepicgo图床 你是否因为管理图片而感到头大&#xff1f;是时候了解一下 Typora、Gitee 和 PicGo 这个超级三剑客了&#xff0c;它们可以帮你轻松打造自己的图床&#xff0c;让你的博客图片管理变得简单又有趣。让我们开始这场神奇的图床之旅吧&#xff01; Typora …

7.20 模拟赛总结 [邻项交换] + [决策单调性]

只放题解喽 题解 T1T2T3T4 T1 等价于维护差分数组&#xff0c;数据范围较小&#xff0c;map 套 vector 维护即可 更大的数据范围可以 hash 做 T2 神奇贪心 本题关键在于定序&#xff0c;考虑顺序确定后答案怎么求 设 f i f_i fi​ 表示 第 i i i 件衣服烘干完的时间&…

运放构成电压跟随器,反馈电阻作用;运放电流采集电路,单电源供电,TINA仿真

电压跟随器 使用运放构成电压跟随器可以减小负载对信号源的影响&#xff0c;还可以提高信号带负载的能力&#xff0c;这是因为运放的结构特性&#xff0c;输入电阻大&#xff0c;输出电阻小。 是否决定使用该电压跟随器&#xff0c;就要看信号源&#xff0c;以及负载的阻抗大小…