昇思训练营打卡第二十一天(DCGAN生成漫画头像)

DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引入了深度卷积神经网络(CNN)的结构,用于生成高质量、高分辨率的图像。

DCGAN的原理可以概括为两个主要部分:生成器(Generator)和判别器(Discriminator)。生成器和判别器都是深度卷积神经网络,它们通过对抗过程互相博弈,最终达到纳什均衡。

  1. 生成器(Generator):生成器的输入是一个随机的噪声向量,通过一系列的卷积、反卷积、批标准化(Batch Normalization)和激活函数(如ReLU、Tanh等)操作,生成一个与真实图像具有相同尺寸的图像。生成器的目标是生成尽可能逼真的图像,以欺骗判别器。

  2. 判别器(Discriminator):判别器的输入是一个图像,它通过一系列的卷积、批标准化和激活函数操作,判断输入图像是真实图像还是生成器生成的假图像。判别器的目标是能够准确地区分真实图像和假图像。

在训练过程中,生成器和判别器交替进行优化。生成器尝试生成逼真的图像,而判别器尝试更好地识别真实图像和假图像。这个过程可以看作是一种博弈,生成器和判别器在不断的迭代过程中提高自己的性能。最终,当生成器和判别器达到纳什均衡时,生成器能够生成高质量的逼真图像,判别器无法准确地区分真实图像和假图像。

DCGAN在计算机视觉领域有广泛的应用,如图像生成、图像修复、图像转换等。通过调整网络结构和训练策略,DCGAN还可以应用于其他领域,如自然语言处理、音频生成等。

数据准备与处理

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')
import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

构造网络

生成器

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型训练

损失函数

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

优化器

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

训练模型

训练分为两个主要部分:训练判别器和训练生成器。

def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs
import mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/44013.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【CentOS】Linux命令之docker命令(持续更新)

删除所有容器 该命令将删除所有已停止的容器。你还可以使用其他状态值,例如created、restarting或dead docker container rm $(docker container ls -aqf statusexited)删除所有镜像 该命令将删除所有镜像,包括被使用的镜像。请注意,如果某…

【深度学习】第5章——卷积神经网络(CNN)

一、卷积神经网络 1.定义 卷积神经网络(Convolutional Neural Network, CNN)是一种专门用于处理具有网格状拓扑结构数据的深度学习模型,特别适用于图像和视频处理。CNN 通过局部连接和权重共享机制,有效地减少了参数数量&#x…

使用OpencvSharp实现人脸识别

在网上有很多关于这方面的博客,但是都没有说完整,按照他们的博客做下来代码都不能跑。所以我就自己写个博客补充一下 我这使用的.NET框架版本是 .NetFramework4.7.1 使用Nuget安装这两个程序包就够了,不需要其他的配置 一定要安装OpenCvSha…

大模型日报 2024-07-09

大模型日报 2024-07-09 大模型资讯 大模型最强架构TTT问世!斯坦福UCSD等5年磨一剑,一夜推翻Transformer 斯坦福UCSD等机构研究者提出的TTT方法,直接替代了注意力机制,语言模型方法从此或将彻底改变。这个模型通过对输入token进行梯…

在亚马逊云科技AWS上利用SageMaker机器学习模型平台搭建生成式AI应用(附Llama大模型部署和测试代码)

项目简介: 接下来,小李哥将会每天介绍一个基于亚马逊云科技AWS云计算平台的全球前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS AI最佳实践,并应用到自己的日常工作里。本次介绍的是如何在Amazon …

802.11漫游流程简单解析与笔记_Part2_05_wpa_supplicant如何通过nl80211控制内核开始关联

最近在进行和802.11漫游有关的工作,需要对wpa_supplicant认证流程和漫游过程有更多的了解,所以通过阅读论文等方式,记录整理漫游相关知识。Part1将记录802.11漫游的基本流程、802.11R的基本流程、与认证和漫游都有关的三层秘钥基础。Part1将包…

Vue 3与Pinia:下一代状态管理的探索

引言 随着Vue 3的推出,Pinia应运而生,成为官方推荐的状态管理库,旨在替代Vuex。Pinia与Vuex相比,带来了以下主要区别和优势: 更简洁的API:Pinia的API设计更加直观和简洁,易于理解和使用。更好…

220V降5V芯片输出电压电流封装选型WT

220V降5V芯片输出电压电流封装选型WT 220V降5V恒压推荐:非隔离芯片选型及其应用方案 在考虑220V转低压应用方案时,以下非隔离芯片型号及其封装形式提供了不同的电压电流输出能力: 1. WT5101A(SOT23-3封装)适用于将2…

【实战场景】大文件解析入库的方案有哪些?

【实战场景】大文件解析入库的方案有哪些? 开篇词:干货篇:分块解析内存映射文件流式处理数据库集群处理分布式计算框架 总结篇:我是杰叔叔,一名沪漂的码农,下期再会! 开篇词: 需求背…

14-57 剑和诗人31 - LLM/SLM 中的高级 RAG

​​​ 首先确定几个缩写的意思 SLM 小模型 LLM 大模型 检索增强生成 (RAG) 已成为一种增强语言模型能力的强大技术。通过检索和调整外部知识,RAG 可让模型生成更准确、更相关、更全面的文本。 RAG 架构主要有三种类型:简单型、模块化和高级 RAG&…

性能测试的流程(企业真实流程详解)(二)

性能测试的流程 1.需求分析以及需求确定(指标值,场景,环境,人员) 一般提出需求的人员有:客户,产品经理,项目组领导等 2.性能测试计划和方案制定 基准测试: 负觋测试: 压力测试: 稳定性测试: 其他:配置测试…

Git安装使用教程

# 《Git 操作使用教程》 一、Git 简介 Git 是一个分布式版本控制系统,用于敏捷高效地处理任何或小或大的项目。它让开发者可以轻松地跟踪代码的更改、与团队成员协作,并管理项目的不同版本。 二、安装 Git 在 Windows 系统上,可以从 Git 官…

刷题Day47|1143.最长公共子序列、1035.不相交的线、53. 最大子序和、

1143.最长公共子序列 1143. 最长公共子序列 - 力扣(LeetCode) 思路:dp数组含义是以i-1和j-1为结尾的最长公共子序列。当text1[i - 1] text2[i - 1], dp[i][j] dp[i - 1][j - 1] 1; 否则dp[i][j] max(dp[i - 1][j], dp[i][j - 1]); 因为两…

无法连接Linux远程服务器的Mysql,解决办法

问题描述 如果是关闭虚拟机之后,二次打开无法连接Mysql,则可尝试一下方法进行解决 解决方法 关闭虚拟机的防火墙 1:查看防火墙状态 systemctl status firewalld 一下显示说明防火墙是启动的状态 2:关闭防火墙 systemctl st…

git提交emoji指南

emoji 指南 emojiemoji 代码commit 说明🎉 (庆祝)tada初次提交✨ (火花)sparkles引入新功能🔖 (书签)bookmark发行/版本标签🐛 (bug)bug修复 bug🚑 (急救车)ambulance重要补丁🌐 (地球)globe_with_meridians国际化与本…

PTA - 编写函数计算圆面积

题目描述: 1.要求编写函数getCircleArea(r)计算给定半径r的圆面积,函数返回圆的面积。 2.要求编写函数get_rList(n) 输入n个值放入列表并将列表返回 函数接口定义: getCircleArea(r); get_rList(n); 传入的参数r表示圆的半径&#xff0c…

音视频解封装demo:将FLV文件解封装(demux)得到文件中的H264数据和AAC数据(纯手工,不依赖第三方开源库)

1、README 前言 注意:flv是不支持h.265封装的。目前解封装功能正常,所得到的H.264文件与AAC文件均可正常播放。 a. demo使用 $ make clean && make DEBUG1 $ $ $ ./flv_demux_h264_aac Usage: ./flv_demux_h264_aac avfile/test1.flv./flv_d…

压缩感知1——算法简介

传统的数据采集 传统的数字信号采样定律就是有名的香农采样定理,又称那奎斯特采样定律定理内容如下:为了不失真地恢复模拟信号,采样频率应该不小于模拟信号频谱中最高频率的2倍 上述步骤得到的数字信号的数据量比较大,一方面不利…

C语言程序题(一)

一.三个整数从大到小输出 首先做这个题目需要知道理清排序的思路,通过比较三个整数的值,使之从大到小输出。解这道题有很多方法我就总结了两种方法:一是通过中间变量比较和交换,二是可以用冒泡排序法(虽然三个数字排序…

车载聚合路由器应用场景分析

乾元通QYT-X1z车载式1U多卡聚合路由器,支持最多8路聚合,无论是应急救援,还是车载交通,任何宽带服务商无法覆盖的区域,聚合路由器可提供现场需要的稳定、流畅、安全的视频传输网络,聚合路由器可无缝接入应急…