Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现

什么是GAN

GAN是生成对抗网络,将会根据一个随机向量,实现数据的生成(如生成手写数字、生成文本等)。

GAN的训练过程中,需要有一个生成器G和一个鉴别器D.

生成器用于生成数据,鉴定器用于鉴定数据的准确性,其实就是在鉴别数据是人生成的还是机器生成的,因为生成器需要以假乱真。

鉴别器将会与生成器一起训练。鉴别器将会先训练,这样才有适当的能力去鉴定生成器生成数据的准确性。

鉴别器的训练过程中,需要先给它准确的数据,和通过随机向量传入生成器产生的数据(一律视为负样本),并通过损失函数对其进行训练;生成器训练过程中,会先给它一个随机向量进行前向传播,然后让鉴别器判断其正确性,并通过损失函数(不正确的数据意味着有损失)进行训练:

生成器训练过程中,需要先通过随机向量获取其结果,然后让鉴别器进行鉴别,在通过鉴别器的鉴别结果计算损失(如果鉴别器认为这是生成器生成的,则产生损失),最后更新梯度和参数:

训练过程直到生成器拟合训练集(收敛),判别器的输出总是0.5(均方误差损失函数应为0.25)为止.

形象的GAN的例子

想象一场由一位“名画伪造者”和一位“艺术鉴定家”参与的猫捉老鼠游戏。

在这个场景中,名画伪造者(即GAN中的生成器)的目标是创造出一幅足以欺骗艺术鉴定家(即GAN中的判别器)的假画。开始时,伪造者的技艺并不精湛,他制作的假画充满了破绽,很容易被鉴定家一眼识破。

然而,随着伪造者不断尝试和失败,他逐渐从每一次的失败中学习,逐渐提升了自己的技艺。他开始注意到真画的每一个细节,从笔触、色彩到构图,都尽量模仿得惟妙惟肖。每一次的失败都让他更接近成功,他制作的假画也越来越难以辨别真伪。

而艺术鉴定家也不甘示弱。他开始时能够轻易地识别出伪造者的假画,但随着伪造者技艺的提升,他也需要不断提升自己的鉴定能力。他开始深入研究真画的每一个特点,以便更准确地识别出伪造者的假画。

这个过程就像GAN中的训练过程一样。生成器不断尝试生成新的数据(在这里是假画),而判别器则不断尝试区分这些数据是真实的还是生成的。两者在相互竞争的过程中不断提升自己的能力,最终达到了一个平衡状态。

在这个例子中,名画伪造者就是GAN中的生成器,他负责生成新的数据;而艺术鉴定家则是GAN中的判别器,他负责区分数据的真伪。两者在相互竞争的过程中共同进步,使得生成的数据越来越接近真实的数据。

代码实现

本文将以基于MNIST(手写数据集)为数据集,实现一个生成手写数字的GAN模型:

首先创建models.py,用于定义判别器和生成器:

import paddle# Generator Code
class Generator(paddle.nn.Layer):def __init__(self, ):super(Generator, self).__init__()self.gen = paddle.nn.Sequential(paddle.nn.Linear(in_features=100, out_features=256),paddle.nn.ReLU(True),paddle.nn.Linear(in_features=256, out_features=512),paddle.nn.ReLU(True),paddle.nn.Linear(in_features=512, out_features=1024),paddle.nn.Tanh(),)def forward(self, x):x = self.gen(x)out = paddle.reshape(x,[-1,1,32,32])return out# Discriminator Code
class Discriminator(paddle.nn.Layer):def __init__(self, ):super(Discriminator, self).__init__()self.dis = paddle.nn.Sequential(paddle.nn.Linear(in_features=1024, out_features=512),paddle.nn.LeakyReLU(0.2),paddle.nn.Linear(in_features=512, out_features=256),paddle.nn.LeakyReLU(0.2),paddle.nn.Linear(in_features=256, out_features=1),paddle.nn.Sigmoid())def forward(self, x):x = paddle.reshape(x, [-1, 1024])out = self.dis(x)return out

其中,生成器将接收一个长度为100的张量(随机向量),输出一个长度为1024的张量(生成的图片);鉴别器将接收一个长度为1024的张量(图片) ,输出长度为1的张量(鉴别结果)

然后创建main.py,用于训练:

import paddle
import matplotlib.pyplot as plt
from models import Generator, Discriminator
import numpy as npdataset = paddle.vision.datasets.MNIST(mode='train',transform=paddle.vision.transforms.Compose([paddle.vision.transforms.Resize((32, 32)),paddle.vision.transforms.Normalize([0], [255])]))dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)netG = Generator()
netD = Discriminator()if 1:try:mydict = paddle.load('generator.params')netG.set_dict(mydict)mydict = paddle.load('discriminator.params')netD.set_dict(mydict)except:print('fail to load model')optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)# 最大迭代epoch
max_epoch = 10for epoch in range(max_epoch):now_step = 0for step, (data, label) in enumerate(dataloader):############################# (1) 更新鉴别器############################ 清除D的梯度optimizerD.clear_grad()# 传入正样本,并更新梯度pos_img = datalabel = paddle.full([pos_img.shape[0], 1], 1, dtype='float32')pre = netD(pos_img)loss_D_1 = paddle.nn.functional.mse_loss(pre, label)loss_D_1.backward()# 通过randn构造随机数,制造负样本,并传入D,更新梯度noise = paddle.randn([pos_img.shape[0], 100], 'float32')neg_img = netG(noise)label = paddle.full([pos_img.shape[0], 1], 0, dtype='float32')pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算loss_D_2 = paddle.nn.functional.mse_loss(pre, label)loss_D_2.backward()# 更新D网络参数optimizerD.step()optimizerD.clear_grad()loss_D = loss_D_1 + loss_D_2############################# (2) 更新生成器############################ 清除D的梯度optimizerG.clear_grad()noise = paddle.randn([pos_img.shape[0], 100], 'float32')fake = netG(noise)label = paddle.full((pos_img.shape[0], 1), 1, dtype=np.float32, )output = netD(fake)# 这个写法没有问题,因为这个mse_loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度loss_G = paddle.nn.functional.mse_loss(output, label)loss_G.backward()# 更新G网络参数optimizerG.step()optimizerG.clear_grad()now_step += 1############################ 输出日志###########################if now_step % 100 == 0:print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

如果是第一次训练或不使用原有训练参数,可以将if 1改成if 0.

接下来创建use.py,用于生成图片:

import paddle
from models import Generator
import matplotlib.pyplot as pltimport paddle
from models import Generator
import matplotlib.pyplot as plt
import numpy as np# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):noise = paddle.randn([1, 100], 'float32')img = netG(noise)img = img.numpy()[0][0]  # img.numpy():张量转np数组img[img < 0] = 0  # 将img中所有小于0的元素赋值为0img = np.clip(img, 0, 1)  # 将img中所有小于0的元素设为0,大于1的设为1(如果需要)# 显示图片ax.imshow(img)ax.axis('off')  # 不显示坐标轴# 显示图像
plt.show()

进行多轮训练后,生成结果:

 

可以看到,它很好的生成了我们想要的图片。

GANs

但是,我们这个模型只能随机产生数字,还不能生成指定的数字(如让机器生成一个1).为了解决这个问题,我们可以针对每一个数字生成一个对应的GAN,所有这样的GAN组合起来,就是GANs. 这里不展开讲解。

参考

MNIST数据集下用Paddle框架的动态图模式玩耍经典对抗生成网络(GAN)-使用文档-PaddlePaddle深度学习平台

【飞桨PaddlePaddle】四天搞懂生成对抗网络(一)——通俗理解经典GAN_四天搞懂生成对抗网络(一)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834487.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通过Docker Compose部署GitLab和GitLab Runner(一)

GitLab 是一个用于版本控制、项目管理和持续集成的开源软件平台&#xff0c;它提供了一整套工具&#xff0c;能够帮助团队高效地协作开发。而 GitLab Runner 则是 GitLab CI/CD 的执行者&#xff0c;用于运行持续集成和持续交付任务。 在本文中&#xff0c;我们将使用 Docker …

实现C++ Vector

手写C Vector&#xff0c;参考QVector 类声明 template<typename T >class IteratorVector;template<typename T >class IteratorVectorConst;template<typename T >class Vector final :public ContainerBase{public:explicit Vector()noexcept;explicit V…

【SpringBoot】-- 监听容器事件、Bean的前后置事件

目录 一、ApplicationContextInitializer 使用 1、自定义类&#xff0c;实现ApplicationContextInitializer接口 2、在META-INF/spring.factories配置文件中配置自定义类 二、ApplicationListener 使用 1、自定义类&#xff0c;实现ApplicationListener接口 2、在META-…

Selenium 自动化 —— 常用的定位器(Locator)

什么是定位器 定位器&#xff08;Locator&#xff09;是识别DOM中一个或多个特定元素的方法。 也可以叫选择器 Selenium 通过By类&#xff0c;提供了常见的定位器。具体语法如下&#xff1a; By.xxx("");我们选择单个元素时可以使用findByElement&#xff1a; Web…

三.Django--ORM(操作数据库)

目录 1 什么是ORM 1.1 ORM优势 1.2ORM 劣势 1.3 ORM与数据库的关系 2 ORM 2.1 作用 2.2 连接数据库 2.3 表操作--设置字段 2.4 数据库的迁移 写路由增删改查操作 项目里的urls.py: app里的views.py: 注意点: 1 什么是ORM ORM中文---对象-关系映射 在MTV,MVC设计…

Android 的 Timer 和 TimerTask

Timer 简介(来自Gemini) Timer 是 Java 中用于创建定时任务的类。它位于 java.util 包中。可以使用 Timer 来安排一次性或定期执行的任务。 每个 Timer 对象都对应一个后台线程。此线程负责从任务队列中检索任务并按计划执行它们。 使用 Timer 要使用 Timer&#xff0c;首先…

政安晨:【Keras机器学习示例演绎】(三十九)—— 使用 FNet 进行文本分类

目录 简介 模型 设置 加载数据集 对数据进行标记 格式化数据集 建立模型 训练我们的模型 与变换器模型比较 政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras机器学习实战 希望政安晨的博客能够对您有所裨益&…

IAM帮你破解密码管理难题

面对让人头疼的密码管理难题&#xff0c;全球的政府、企业、安全厂商都在想办法。比如通过定期改密、强制添加特殊字符等方式提升密码强度&#xff0c;比如开展网络安全教育、提升员工对密码的重视程度&#xff0c;并且取得了显著效果。但是&#xff0c;想从根本上解决密码管理…

第二证券|炒股是波段好还是长期好?

炒股长时间比波段好一些&#xff0c;其原因如下&#xff1a; 1、长时间持有费用低 投资者在生意过程中&#xff0c;需求交纳必定的佣金费用、过户费用、印花税&#xff0c;而长时间持有股票&#xff0c;减少生意次数&#xff0c;能够节省一笔生意成本。 2、短期持有容易卖飞…

3.ERC4626

ERC4626是一个vault&#xff0c;在DAI中&#xff0c;使用ETH换取DAI。其流程为先充值ETH到maker vault。 Vault 资产的管理、分红用户充值某项资产获取某个凭证该凭证作为分红、推出的依据Yield Farming/借贷/质押等 以太坊改进提案EIP:ethereum improvemwnt proposal 最初E…

7.基于麻雀搜索算法(SSA)优化VMD参数(SSA-VMD)

01.智能优化算法优化VMD参数的使用说明 02.基本原理 麻雀搜索算法&#xff08;SSA&#xff09;是一种基于鸟类觅食行为的启发式优化算法&#xff0c;它模拟了麻雀在觅食时的群体行为&#xff0c;通过模拟麻雀的觅食过程来寻找问题的最优解。SSA的基本原理是通过模拟麻雀的搜索…

视频监控平台:交通运输标准JTT808设备SDK接入源代码函数分享

目录 一、JT/T 808标准简介 &#xff08;一&#xff09;概述 &#xff08;二&#xff09;协议特点 1、通信方式 2、鉴权机制 3、消息分类 &#xff08;三&#xff09;协议主要内容 1、位置信息 2、报警信息 3、车辆控制 4、数据转发 二、代码和解释 &#xff08;一…

《ESP8266通信指南》13-Lua 简单入门(打印数据)

往期 《ESP8266通信指南》12-Lua 固件烧录-CSDN博客 《ESP8266通信指南》11-Lua开发环境配置-CSDN博客 《ESP8266通信指南》10-MQTT通信&#xff08;Arduino开发&#xff09;-CSDN博客 《ESP8266通信指南》9-TCP通信&#xff08;Arudino开发&#xff09;-CSDN博客 《ESP82…

AJAX知识点(前后端交互技术)

原生AJAX AJAX全称为Asynchronous JavaScript And XML,就是异步的JS和XML&#xff0c;通过AJAX可以在浏览器中向服务器发送异步请求&#xff0c;最大的优势&#xff1a;无需刷新就可获取数据。 AJAX不是新的编程语言&#xff0c;而是一种将现有的标准组合在一起使用的新方式 …

C语言【文件操作 2】

文章目录 前言顺序读写函数的介绍fputc && fgetcfputcfgetc fputs && fgetsfputsfgets fprintf && fscanffprintffscanf fwrite && freadfwritefread 文件的随机读写fseek函数偏移量ftell函数rewind函数 文件的结束判断被错误使用的feof 结语 …

Linux与windows网络管理

文章目录 一、TCP/IP1.1、TCP/IP概念TCP/IP是什么TCP/IP的作用TCP/IP的特点TCP/IP的工作原理 1.2、TCP/IP网络发展史1.3、OSI网络模型1.4、TCP/IP网络模型1.5、linux中配置网络网络配置文件位置DNS配置文件主机名配置文件常用网络查看命令 1.6、windows中配置网络CMD中网络常用…

认识卷积神经网络

我们现在开始了解卷积神经网络&#xff0c;卷积神经网络是深度学习在计算机视觉领域的突破性成果&#xff0c;在计算机视觉领域&#xff0c;往往我们输入的图像都很大&#xff0c;使用全连接网络的话&#xff0c;计算的代价较高&#xff0c;图像也很难保留原有的特征&#xff0…

python 和 MATLAB 都能绘制的母亲节花束!!

hey 母亲节快到了&#xff0c;教大家用python和MATLAB两种语言绘制花束~这段代码是我七夕节发的&#xff0c;我对代码进行了简化&#xff0c;同时自己整了个python版本 MATLAB 版本代码 function roseBouquet_M() % author : slandarer% 生成花朵数据 [xr,tr]meshgrid((0:24).…

jQuery-1.语法、选择器、节点操作

jQuery jQueryJavaScriptQuery&#xff0c;是一个JavaScript函数库&#xff0c;为编写JavaScript提供了更高效便捷的接口。 jQuery安装 去官网下载jQuery&#xff0c;1.x版本练习就够用 jQuery引用 <script src"lib/jquery-1.11.2.min.js"></script>…