使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。

第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:
在这里插入图片描述

原理

首先简单讲一下GAN的工作原理,如下为前向传播的过程:
在这里插入图片描述
GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。

而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。
下图为discriminator的反向传播的过程:
在这里插入图片描述
discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:
在这里插入图片描述

下图为generator的反向传播的过程:
在这里插入图片描述
而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:
在这里插入图片描述

代码

下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。
model.py:

from torch import nn
import torchdef get_generator_block(input_dim, output_dim):return nn.Sequential(nn.Linear(input_dim, output_dim),nn.BatchNorm1d(output_dim),nn.ReLU(inplace=True),)class Generator(nn.Module):def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):super(Generator, self).__init__()self.gen = nn.Sequential(get_generator_block(z_dim, hidden_dim),get_generator_block(hidden_dim, hidden_dim * 2),get_generator_block(hidden_dim * 2, hidden_dim * 4),get_generator_block(hidden_dim * 4, hidden_dim * 8),nn.Linear(hidden_dim * 8, im_dim),nn.Sigmoid())def forward(self, noise):return self.gen(noise)def get_gen(self):return self.gendef get_discriminator_block(input_dim, output_dim):return nn.Sequential(nn.Linear(input_dim, output_dim), #Layer 1nn.LeakyReLU(0.2, inplace=True))class Discriminator(nn.Module):def __init__(self, im_dim=784, hidden_dim=128):super(Discriminator, self).__init__()self.disc = nn.Sequential(get_discriminator_block(im_dim, hidden_dim * 4),get_discriminator_block(hidden_dim * 4, hidden_dim * 2),get_discriminator_block(hidden_dim * 2, hidden_dim),nn.Linear(hidden_dim, 1))def forward(self, image):return self.disc(image)def get_disc(self):return self.disc

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import Discriminator, Generator
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_unflat = image_tensor.detach().cpu().view(-1, *size)image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples,z_dim,device=device)criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cuda'dataloader = DataLoader(MNIST('./', download=True, transform=transforms.ToTensor()),  # 已经下载过的可以改为False跳过下载batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):fake_noise = get_noise(num_images, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2return disc_lossdef get_gen_loss(gen, disc, criterion, num_images, z_dim, device):fake_noise = get_noise(num_images, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))return gen_losscur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
gen_loss = False
error = False
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)# Flatten the batch of real images from the datasetreal = real.view(cur_batch_size, -1).to(device)### Update discriminator #### Zero out the gradients before backpropagationdisc_opt.zero_grad()# Calculate discriminator lossdisc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()### Update generator ###gen_opt.zero_grad()gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)gen_loss.backward()gen_opt.step()# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step### Visualization code ###if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1

运行结果

运行后每隔500个epoch画出fake和real,刚开始的fake和real是这样的:
在这里插入图片描述
在这里插入图片描述
到后面的fake逐渐变成这样:
在这里插入图片描述

代码解释

网络模型

model.py里面存放了generator和discriminator的网络模型,神经元使用的是简单的全连接层,后面的文章再使用卷积。
在这里插入图片描述
生成器输出为784 = 28 * 28,因为使用的是MINIST手写字体数据集,每张图的shape是28 * 28 * 1(黑白图单通道),所以输出的假数据要与真实数据的shape一致,这样输入鉴别器才不会出错。
在这里插入图片描述
生成的图片(或真实数据)直接输入鉴别器,所以鉴别器的输入也是28*28,而输出为1,即输出判别结果为真或假。
在这里插入图片描述
每个优化器仅优化一个模型的参数,所以一个模型构建一个优化器。

图像显示

在这里插入图片描述
首先将图像的tensor转到cpu上,因为PyTorch中的大部分图像处理和显示函数都是在CPU上执行的,包括我们使用的imshow。
detach() 方法将张量从计算图中分离出来,但是仍指向原变量的存放位置,不同之处只是requirse_grad为false,得到的这个tensor永远不需要计算器梯度,不具有grad,这样做的目的是避免梯度计算的影响,因为在展示图像时通常不需要计算梯度。
Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系,类似这样:
在这里插入图片描述
一个网络模型就是一个计算图,在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,求导过程就如上图这样。
make_grid 函数用于将多个图像组成一个网格,方便显示。
在这里插入图片描述

在这里插入图片描述
然后每500个batch显示一次当前模型性能所能生成的图片以及当前batch的真实图片(虽然一个batch设置了128张,但是我们只展示25张),以及print出生成器和鉴别器的loss。

损失函数

在这里插入图片描述
损失函数的原理在上面的“原理”中有讲解,这里不再赘述。
在计算鉴别器的loss里,disc_fake_pred = disc(fake.detach())是对生成图片的判别,这里也使用 .detach() 的目的是将生成器产生的假数据与生成器的参数分离,使得在计算 disc_fake_pred 时不会对生成器的梯度进行传播。这是因为在训练鉴别器的阶段,我们只希望更新鉴别器的参数,而不希望更新生成器的参数(就如上面说的生成器的训练和鉴别器的应该要隔开分别训练、交替训练)。

反向传播

在这里插入图片描述
retain_graph=True参数是用来指示 PyTorch 在反向传播时保留计算图。这个参数的作用是为了在一次反向传播之后保留计算图的状态,以便后续再次调用 backward() 函数时能够继续使用这个计算图进行梯度计算。
Pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

而在GAN中一次的迭代需要先更新鉴别器的参数,然后再更新生成器的参数;在更新生成器的参数时,我们仍然需要使用鉴别器来鉴别real or fake,只要使用到鉴别器就需要他的计算图。因此,我们需要在调用 disc_loss.backward() 后保留计算图,以便后续调用 gen_loss.backward() 时能够继续使用相同的计算图进行梯度计算。而对于生成器的梯度更新 gen_loss.backward(),不需要显式指定 retain_graph=True。
所以,在同一个计算图上多次调用 backward() 函数时才需要使用它。

下一篇构建DCGAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/777776.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

入门指南|营销中人工智能生成内容的主要类型 [新数据、示例和技巧]

由于人工智能技术的进步,内容生成不再是一项令人头疼的任务。随着人工智能越来越多地接管手动内容制作任务,营销人员明智的做法是了解现有的不同类型的人工智能生成内容,以及哪些内容从中受益最多。这些工具可以帮助我们制作对您的受众和品牌…

Synchronized锁、公平锁、悲观锁乐观锁、死锁等

悲观锁 认为自己在使用数据的时候一定会有别的线程来修改数据,所以在获取数据前会加锁,确保不会有别的线程来修改 如: Synchronized和Lock锁 适合写操作多的场景 乐观锁 适合读操作多的场景 总结: 线程8锁🔐 调用 声明 结果:先打印发送短信,后打印发送邮件 结论…

机器学习(三)

神经网络: 神经网络是由具有适应性的简单单元组成的广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所作出的交互反应。 f为激活(响应)函数: 理想激活函数是阶跃函数,0表示抑制神经元而1表示激活神经元。 多层前馈网络结构: BP(误差逆…

OpenPLC_Editor 在Ubuntu 虚拟机安装记录

1. OpenPLC_Editor在虚拟机上费劲的装了一遍,有些东西已经忘了,主要还是python3 的缺失库版本对应问题,OpenPLC_Editor使用python3编译的,虚拟机的Ubuntu 18.4 有2.7和3.6两个版本,所以需要注意。 2. OpenPLC_Editor …

Svg Flow Editor 原生svg流程图编辑器(四)

系列文章 Svg Flow Editor 原生svg流程图编辑器(一) Svg Flow Editor 原生svg流程图编辑器(二) Svg Flow Editor 原生svg流程图编辑器(三) Svg Flow Editor 原生svg流程图编辑器(四&#xf…

贪心算法--最大数

个人主页:Lei宝啊 愿所有美好如期而遇 本题链接https://leetcode.cn/problems/largest-number/description/ class Solution { public:bool static compare(int a, int b){return (to_string(a) to_string(b)) > (to_string(b) to_string(a));}bool operato…

探索 2024 年 Web 开发最佳前端框架

前端框架通过简化和结构化的网站开发过程改变了 Web 开发人员设计和实现用户界面的方法。随着 Web 应用程序变得越来越复杂,交互和动画功能越来越多,这是开发前端框架的初衷之一。 在网络的早期,网页相当简单。它们主要以静态 HTML 为特色&a…

数据库---PDO

以pikachu数据库为例&#xff0c;数据库名&#xff1a; pikachu 1.连接数据库 <?php $dsn mysql:hostlocalhost; port3306; dbnamepikachu; // 这里的空格比较敏感 $username root; $password root; try { $pdo new PDO($dsn, $username, $password); var_dump($pdo)…

【管理咨询宝藏59】某大型汽车物流战略咨询报告

本报告首发于公号“管理咨询宝藏”&#xff0c;如需阅读完整版报告内容&#xff0c;请查阅公号“管理咨询宝藏”。 【管理咨询宝藏59】某大型汽车物流战略咨询报告 【格式】PDF 【关键词】HR调研、商业分析、管理咨询 【核心观点】 - 重新评估和调整商业模式&#xff0c;开拓…

如何开始定制你自己的大型语言模型

2023年的大型语言模型领域经历了许多快速的发展和创新&#xff0c;发展出了更大的模型规模并且获得了更好的性能&#xff0c;那么我们普通用户是否可以定制我们需要的大型语言模型呢&#xff1f; 首先你需要有硬件的资源&#xff0c;对于硬件来说有2个路径可以选。高性能和低性…

StatefulBuilder 和 Builder

前言 果然了解的越多&#xff0c;越发现自己狗屁都不是。StatefulBuilder 和 Builder 之前真的不知道。还是在 对话框状态管理 中了解到了这两个东西。 简介 以下内容来自通义灵码 在Flutter中&#xff0c;StatefulBuilder 和 Builder 都是用来动态构建 widget 树的组件&am…

使用unplugin-auto-import页面不引入api飘红

解决方案&#xff1a;. tsconfig.json文件夹加上 {"compilerOptions": {"target": "ES2020","useDefineForClassFields": true,"module": "ESNext","lib": ["ES2020", "DOM", &q…

Mybatis别名 动态sql语句 分页查询

给Mybatis的实体类起别名 给Mybatis的xml文件注册mapper映射文件 动态sql语句 1 if 2 choose 3 where 4 foreach 一&#xff09;if 查询指定名称商品信息 语法&#xff1a; SELECT * FROM goods where 11 <if test "gName!null"> and g.g_name like co…

Intellij IDEA安装配置Spark与运行

目录 Scala配置教程 配置Spark运行环境 编写Spark程序 1、包和导入 2、定义对象 3、主函数 4、创建Spark配置和上下文 5、定义输入文件路径 6、单词计数逻辑 7、输出结果 8、完整代码&#xff1a; Scala配置教程 IDEA配置Scala&#xff1a;教程 配置Spark运行环境 …

Untiy 布局控制器Aspect Ratio Fitter

Aspect Ratio Fitter是Unity中的一种布局控制器组件&#xff0c;用于根据指定的宽高比来调整包含它的UI元素的大小。实际开发中&#xff0c;它可以确保UI元素保持特定的宽高比&#xff0c;无论UI元素的内容或父容器的大小如何变化。 如图为Aspect Ratio Fitter组件的基本属性&…

自然语言处理(NLP)全面指南

自然语言处理&#xff08;NLP&#xff09;是人工智能领域中最热门的技术之一&#xff0c;它通过构建能够理解和生成人类语言的机器&#xff0c;正在不断推动技术的发展。本文将为您提供NLP的全面介绍&#xff0c;包括其定义、重要性、应用场景、工作原理以及面临的挑战和争议。…

Python图像处理——计算机视觉中常用的图像预处理

概述 在计算机视觉项目中&#xff0c;使用样本时经常会遇到图像样本不统一的问题&#xff0c;比如图像质量&#xff0c;并非所有的图像都具有相同的质量水平。在开始训练模型或运行算法之前&#xff0c;通常需要对图像进行预处理&#xff0c;以确保获得最佳的结果。图像预处理…

typescript 实现RabbitMQ死信队列和延迟队列 订单10分钟未付归还库存

Manjaro安装RabbitMQ 安装 sudo pacman -S rabbitmq rabbitmqadmin启动管理模块 sudo rabbitmq-plugins enable rabbitmq_managementsudo rabbitmq-server管理界面 http://127.0.0.1:15672/ 默认用户名和密码都是guest。 要使用 rabbitmqctl 命令添加用户并分配权限&#xf…

怎样去保证 Redis 缓存与数据库双写一致性?

解决方案 那么我们这里列出来所有策略&#xff0c;并且讨论他们优劣性。 先更新数据库&#xff0c;后更新缓存先更新数据库&#xff0c;后删除缓存先更新缓存&#xff0c;后更新数据库先删除缓存&#xff0c;后更新数据库 先更新数据库&#xff0c;后更新缓存 这种方法是不推…

在scroll-view中使用input,input键盘弹出时,滚动页面,输入框内容会出现错位问题?

解决办法 <view classpages><view><scroll-view scroll-y"{{sysScroll}}" scroll-top"{{scrollTop}}" class"scroll-hei-2 bg-def">...<input bindfocus"onfocus" bindblur"onblur" placeholder&quo…