昇思学习打卡-20-生成式/GAN图像生成

文章目录

  • 网络介绍
  • 生成器和判别器的博弈过程
  • 数据集可视化
  • 模型细节
  • 训练过程
  • 网络优缺点
    • 优点
    • 缺点

网络介绍

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型𝐺和估计样本是否来自训练数据的判别模型𝐷。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

生成器和判别器的博弈过程

  • 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  • 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  • 生成器通过优化,生成出更加贴近真实数据分布的数据。
  • 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。

过程如下图所示:
在这里插入图片描述
在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布,𝑧表示隐码,𝑥表示生成的虚假图像 𝐺(𝑧) 。

数据集可视化

import matplotlib.pyplot as pltdata_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述

模型细节

下面介绍本网络用到的生成器、判别器及损失函数和优化器:

  • 生成器 Generator 的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。
  • 判别器 Discriminator 是一个二分类网络模型,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。
  • 损失函数使用MindSpore中二进制交叉熵损失函数BCELoss ;
  • 这里生成器和判别器都是使用Adam优化器,但是需要构建两个不同名称的优化器,分别用于更新两个模型的参数,详情见下文代码。注意优化器的参数名称也需要修改。

训练过程

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 10:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

在这里插入图片描述

网络优缺点

优点

  • 生成数据自然:
    GAN通过生成器和判别器的对抗训练,生成的图像数据具有高度的自然性和逼真度。这种自然性使得GAN在图像生成、图像修复、图像超分辨率等领域具有广泛应用。
  • 训练效率高:
    GAN的创新之处在于其两个神经网络的对抗训练方式,这种训练方式简单易控,且能够显著改善生成式模型的训练效率。

缺点

  • 训练不稳定:
    GAN的训练过程可能不稳定,容易出现模式崩溃等问题。模式崩溃表现为生成器开始退化,总是生成同样的样本点,无法继续学习。这可能是由于生成器和判别器之间的对抗关系过于复杂,导致训练难以达到稳定的纳什均衡状态。
  • 计算资源需求高:
    GAN的训练过程需要大量的计算资源和时间。特别是对于大规模的数据集和高分辨率的图像,GAN的训练成本可能非常高昂。此外,GAN中的神经网络结构通常较为复杂,这也需要大量的存储空间来支持。

此章节学习到此结束,感谢昇思平台。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/47372.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RK3568笔记三十九:多个LED驱动开发测试(设备树)

若该文为原创文章,转载请注明原文出处。 通过设备树配置一个节点下两个子节点控制两个IO口,一个板载LED,一个外接LED。 一、介绍 通过学习设备树控制GPIO,发现有多种方式 一、直接通过寄存器控制 二、通过设备树,但…

基于STC89C52RC单片机的大棚温控系统(含文档、源码与proteus仿真,以及系统详细介绍)

本篇文章论述的是基于STC89C52RC单片机的大棚温控系统的详情介绍,如果对您有帮助的话,还请关注一下哦,如果有资源方面的需要可以联系我。 目录 摘要 原理图 仿真图 系统总体设计图 代码 系统论文 参考文献 资源下载 摘要 本文介绍的…

CSA笔记3-文件管理命令(补充)+vim+打包解包压缩解压缩命令

grep(-i -n -v -w) [rootxxx ~]# grep root anaconda-ks.cfg #匹配关键字所在的行 [rootxxx ~]# grep -i root anaconda-ks.cfg #-i 忽略大小写 [rootxxx ~]# grep -n root anaconda-ks.cfg #显示匹配到的行号 [rootxxx ~]# grep -v root anaconda-ks.cfg #-v 不匹配有…

甄选范文“论软件维护方法及其应用”软考高级论文,系统架构设计师论文

论文真题 软件维护是指在软件交付使用后,直至软件被淘汰的整个时间范围内,为了改正错误或满足 新的需求而修改软件的活动。在软件系统运行过程中,软件需要维护的原因是多种多样的, 根据维护的原因不同,可以将软件维护分为改正性维护、适应性维护、完善性维护和预防性 维护…

Linux 上 TTY 的起源

注:机翻,未校对。 What is a TTY on Linux? (and How to Use the tty Command) What does the tty command do? It prints the name of the terminal you’re using. TTY stands for “teletypewriter.” What’s the story behind the name of the co…

debian 实现离线批量安装软件包

前言 实现在线缓冲需要的软件和对应依赖的包,离线进行安装 ,用于软件封装。 测试下载一个gcc和依赖环境,关闭默认在线源,测试离线安装gcc和依赖环境 兼容 debian ubuntu/test 测试下载安装包到目录 vim /repo_download.sh #!…

【数据结构】算法复杂度

算法复杂度 数据结构算法复杂度 大o渐进表示法空间复杂度 数据结构 数据结构:是计算机存储和组织数据的方式。 比如打开一个网页,我们看到的文字就是数据,这些数据需要用一个结构来把他管理起来,我们称之为:数据结构 …

基于springboot3实现单点登录(一): 单点登录及其相关概念介绍

引言 应网友要求,从本文开始我们将实现一套基于springboot3springsecurity的单点登录认证系统。 单点登录的实现方式有多种,接下来我们会以oauth2为例来介绍和实现。 单点登录介绍 单点登录(Single Sign-On,简称SSO&#xff0…

nftables(7)集合(SETS)

简介 在nftables中,集合(sets)是一个非常有用的特性,它允许你以集合的形式管理IP地址、端口号等网络元素,从而简化规则的配置和管理。 nftables提供了两种类型的集合:匿名集合和命名集合。 匿名集合&…

使用base64通用文件上传

编写一个上传文件的组件 tuku,点击图片上传后使用FileReader异步读取文件的内容&#xff0c;读取完成后获得文件名和base64码&#xff0c;调用后端uploadApi,传入姓名和base64文件信息&#xff0c;后端存入nginx中&#xff0c;用于访问 tuku.ts组件代码&#xff1a; <templa…

系统测试-白盒测试学习

目录 1、语句覆盖法&#xff1a; 2、判定覆盖法&#xff1a; 3、条件覆盖法&#xff1a; 4、判定条件覆盖&#xff1a; 5、条件组合的覆盖&#xff1a; 6、路径覆盖&#xff1a; 黑盒&#xff1a;需求 白盒&#xff1a;主要用于单元测试 1、语句覆盖法&#xff1a; 程序…

OSU!题解(概率dp)

题目&#xff1a;OSU! - 洛谷 思路&#xff1a; 设E()表示截止到i所获得的分数&#xff1b; 对于到i点的每一个l&#xff0c;如果第i1点为1&#xff0c;那么会新增分数3*l^23*l1; 就有递推公式方程&#xff1a; E()E()p[i1]p*(3*l^23*l1);(p代表截止到i获得长度l的概率)&a…

怎样在 PostgreSQL 中优化对多表关联的连接条件选择?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01;&#x1f4da;领书&#xff1a;PostgreSQL 入门到精通.pdf 文章目录 怎样在 PostgreSQL 中优化对多表关联的连接条件选择一、理解多表关联的基本概念二、选择合适的连接条件…

【C++】拷贝构造函数及析构函数

&#x1f4e2;博客主页&#xff1a;https://blog.csdn.net/2301_779549673 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &#x1f4e2;本文由 JohnKi 原创&#xff0c;首发于 CSDN&#x1f649; &#x1f4e2;未来很长&#…

dbeaver连接mysql8异常

部署了mysql8&#xff0c;尝试用dbeaver 24.1.2连接它。结果配置完成后测试连接时报错&#xff1a;Public Key Retrieval is not allowed. 按照提示修改驱动属性&#xff1a; allowPublicKeyRetrievaltrue

【BUG】已解决:ValueError: Expected 2D array, got 1D array instead

已解决&#xff1a;ValueError: Expected 2D array, got 1D array instead 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是博主英杰&#xff0c;211科班出身&#xff0c;就职于医疗科技公司&#xff0c;热衷分享知识&#xff0c;武汉…

Python | Leetcode Python题解之第238题除自身以外数组的乘积

题目&#xff1a; 题解&#xff1a; class Solution:def productExceptSelf(self, nums: List[int]) -> List[int]:length len(nums)# L 和 R 分别表示左右两侧的乘积列表L, R, answer [0]*length, [0]*length, [0]*length# L[i] 为索引 i 左侧所有元素的乘积# 对于索引为…

人工智能 (AI) 应用:一个异常肺呼吸声辅助诊断系统

关键词&#xff1a;深度学习、肺癌、多标签、轻量级模型设计、异常肺音、音频分类 近年来&#xff0c;流感对人类的危害不断增加&#xff0c;COVID-19疾病的迅速传播加剧了这一问题&#xff0c;导致大多数患者因呼吸系统异常而死亡。在这次流行病爆发之前&#xff0c;呼吸系统…

SCI一区级 | Matlab实现GJO-CNN-LSTM-Multihead-Attention多变量时间序列预测

SCI一区级 | Matlab实现GJO-CNN-LSTM-Mutilhead-Attention多变量时间序列预测 目录 SCI一区级 | Matlab实现GJO-CNN-LSTM-Mutilhead-Attention多变量时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现GJO-CNN-LSTM-Mutilhead-Attention金豺优化算…

MongoDB自学笔记(三)

一、前文回顾 上一篇文章中我们学习了更新操作&#xff0c;以及讲解了部分的更新操作符&#xff0c;今天我们继续学习剩余的更新操作符。 二、更新操作符 1、$rename 语法&#xff1a;{ $rename: { < field1 >: < newName1 >, < field2 >: < newName2…