第G1周:生成对抗网络(GAN)入门

🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

一、理论基础
生成对抗网络(Generative Adversarial Networks, GAN)是近年来深度学习领域的一个热点方向。GAN并不指代某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。GAN由两个分别被称为生成器(Generator)和判别器(Discriminator)的神经网络组成。其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则为真实样本或人工样本,其目的是将人工样本与真实样本尽可能地区分出来。生成器和判别器交替运行,相互博弈,各自的能力都得到升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本都输出50%真,50%假的判断。此时,生成器输出的人工样本已经逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具有“伪造”真实样本能力的生成器。
1. 生成器

GANs中,生成器 G 选取随机噪声 z 作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本G(z)。生成器的本质是一个使用生成式方法的模型,它对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。
从数学上来说,生成式方法对于给定的真实数据,首先需要对数据的显式变量或隐含变量做分布假设;然后再将真实数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。从机器学习的角度来说,模型不会去做分布假设,而是通过不断地学习真实数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本生成任务。这种方法不同于数学方法,学习的过程对人类理解较不直观。

2. 判别器
GANs中,判别器 D 对于输入的样本 x,输出一个[0,1]之间的概率数值D(x)。x 可能是来自于原始数据集中的真实样本 x,也可能是来自于生成器 G 的人工样本G(z)。通常约定,概率值D(x)越接近于1就代表此样本为真实样本的可能性更大;反之概率值越小则此样本为伪造样本的可能性越大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管在生成器还是判别器中,样本的类别信息都没有用到,也表明 GAN 是一个无监督的学习过程。

3. 基本原理
GAN是博弈论和机器学习相结合的产物,于2014年Ian Goodfellow的论文中问世,一经问世即火爆足以看出人们对于这种算法的认可和狂热的研究热忱。想要更详细的了解GAN,就要知道它是怎么来的,以及这种算法出现的意义是什么。研究者最初想要通过计算机完成自动生成数据的功能,例如通过训练某种算法模型,让某模型学习过一些苹果的图片后能自动生成苹果的图片,具备些功能的算法即认为具有生成功能。但是GAN不是第一个生成算法,而是以往的生成算法在衡量生成图片和真实图片的差距时采用均方误差作为损失函数,但是研究者发现有时均方误差一样的两张生成图片效果却截然不同,鉴于此不足Ian Goodfellow提出了GAN。

image.png

那么GAN是如何完成生成图片这项功能的呢,如图1所示,GAN是由两个模型组成的:生成模型G和判别模型D。首先第一代生成模型1G的输入是随机噪声z,然后生成模型会生成一张初级照片,训练一代判别模型1D另其进行二分类操作,将生成的图片判别为0,而真实图片判别为1;为了欺瞒一代鉴别器,于是一代生成模型开始优化,然后它进阶成了二代,当它生成的数据成功欺瞒1D时,鉴别模型也会优化更新,进而升级为2D,按照同样的过程也会不断更新出N代的G和D。

二、前期准备工作

1. 定义超参数

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch## 创建文件夹
os.makedirs("./images/", exist_ok=True)         ## 记录训练过程的图片效果
os.makedirs("./save/", exist_ok=True)           ## 训练完成时模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True)      ## 下载数据集存放的位置## 超参数配置
n_epochs=50
batch_size=512
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500## 图像的尺寸:(1, 28, 28),  和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)
## mnist数据集下载
mnist = datasets.MNIST(root='./datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
## 配置数据到加载器
dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),         # 输入特征数为784,输出为512nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射nn.Linear(512, 256),              # 输入特征数为512,输出为256nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射nn.Linear(256, 1),                # 输入特征数为256,输出为1nn.Sigmoid(),                     # sigmoid是一个激活函数,二分类问题中可将实数映射到[0, 1],作为概率值, 多分类用softmax函数)def forward(self, img):img_flat = img.view(img.size(0), -1) # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)validity = self.model(img_flat)      # 通过鉴别器网络return validity                      # 鉴别器返回的是一个[0, 1]间的概率
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()## 模型中间块儿def block(in_feat, out_feat, normalize=True):        # block(in, out )layers = [nn.Linear(in_feat, out_feat)]          # 线性变换将输入映射到out维if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) # 正则化layers.append(nn.LeakyReLU(0.2, inplace=True))   # 非线性激活函数return layers## prod():返回给定轴上的数组元素的乘积:1*28*28=784self.model = nn.Sequential(*block(latent_dim, 128, normalize=False), # 线性变化将输入映射 100 to 128, 正则化, LeakyReLU*block(128, 256),                         # 线性变化将输入映射 128 to 256, 正则化, LeakyReLU*block(256, 512),                         # 线性变化将输入映射 256 to 512, 正则化, LeakyReLU*block(512, 1024),                        # 线性变化将输入映射 512 to 1024, 正则化, LeakyReLUnn.Linear(1024, img_area),                # 线性变化将输入映射 1024 to 784nn.Tanh()                                 # 将(784)的数据每一个都映射到[-1, 1]之间)## view():相当于numpy中的reshape,重新定义矩阵的形状:这里是reshape(64, 1, 28, 28)def forward(self, z):                           # 输入的是(64, 100)的噪声数据imgs = self.model(z)                        # 噪声数据通过生成器模型imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)return imgs                                 # 输出为64张大小为(1, 28, 28)的图像
## 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()## 首先需要定义loss的度量方式  (二分类的交叉熵)
criterion = torch.nn.BCELoss()## 其次定义 优化函数,优化函数的学习率为0.0003
## betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))## 如果有显卡,都在cuda模式中运行
if torch.cuda.is_available():generator     = generator.cuda()discriminator = discriminator.cuda()criterion     = criterion.cuda()
for epoch in range(n_epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)real_img = Variable(imgs).cuda()      # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定义真实的图片label为1fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定义假的图片的label为0real_out = discriminator(real_img)            # 将真实图片放入判别器中loss_real_D = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好## 计算假的图片的损失## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 随机生成一些噪声, 大小为(128, 100)fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的lossfake_scores = fake_out## 损失函数和优化loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0loss_D.backward()                   # 将误差反向传播optimizer_D.step()                  # 更新参数z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到随机噪声fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片output = discriminator(fake_img)                                    ## 经过判别器得到的结果## 损失函数和优化loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的lossoptimizer_G.zero_grad()                                             ## 梯度归0loss_G.backward()                                                   ## 进行反向传播optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数## 打印训练过程中的日志## item():取出单元素张量的元素值并返回该值,保持原元素类型不变if ( i + 1 ) % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')

部分运行截图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/37733.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows安装Go开发环境

Windows安装Go开发环境 一、Go语言下载地址 https://golang.google.cn/dl/ 二、设置工作空间GOPATH目录(Go语言开发的项目路径) 首先进入我的C盘(你放到其他盘也行),新建一个文件夹,名字叫做mygo(这个就是你的工作目…

ArcGIS Maps SDK for JavaScript系列之一:在Vue3中加载ArcGIS地图

目录 ArcGIS Maps SDK for JavaScript简介ArcGIS Maps SDK for JavaScript 4.x 的主要特点和功能AMD modules 和 ES modules两种方式比较Vue3中使用ArcGIS Maps SDK for JavaScript的步骤创建 Vue 3 项目安装 ArcGIS Maps SDK for JavaScript创建地图组件 ArcGIS Maps SDK for …

华为开源自研AI框架昇思MindSpore应用案例:基于MindSpore框架的UNet-2D案例实现

目录 一、环境准备1.进入ModelArts官网2.使用CodeLab体验Notebook实例 二、环境准备与数据读取三、模型解析Transformer基本原理Attention模块 Transformer EncoderViT模型的输入整体构建ViT 四、模型训练与推理模型训练模型验证模型推理 近些年,随着基于自注意&…

改造旧项目-长安分局人事费用管理系统

一、系统环境搭建 1、搭建前台环境 vue3vite构建项目复制“银税系统”页面结构,包括:路由、vuex存储、菜单、登录(复制一个干净的空架子) 2、搭建后台环境 新三大框架 SSMP聚合工程:common、admin,新的…

2023国赛数学建模E题思路分析

文章目录 0 赛题思路1 竞赛信息2 竞赛时间3 建模常见问题类型3.1 分类问题3.2 优化问题3.3 预测问题3.4 评价问题 4 建模资料 0 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 1 竞赛信息 全国大学生数学建模…

Linux服务器上配置HTTP和HTTPS代理

本文将向你分享如何在Linux服务器上配置HTTP和HTTPS代理的方法,解决可能遇到的问题,让你的爬虫项目顺利运行,畅爬互联网! 配置HTTP代理的步骤 1. 了解HTTP代理的类型:常见的有正向代理和反向代理两种类型。根据实际需求…

涉及近300个业务场景,重庆银行数字员工平台建设解析

随着数字化转型战略规划的逐步落地,重庆银行于2022年6月成功建设了数字员工平台,该平台已成为行内数字化转型的标杆应用。数字员工平台以RPA(机器人流程自动化)为基础,AI(人工智能)技术为抓手&a…

PHP最简单自定义自己的框架view使用引入smarty(8)--自定义的框架完成

1、实现效果。引入smarty, 实现assign和 display 2、下载smarty,创建缓存目录cache和扩展extend 点击下面查看具体下载使用,下载改名后放到extend PHP之Smarty使用以及框架display和assign原理_PHP隔壁老王邻居的博客-CSDN博客 3、当前控…

leetcode 力扣刷题 旋转矩阵(循环过程边界控制)

力扣刷题 旋转矩阵 二维矩阵按圈遍历(顺时针 or 逆时针)遍历59. 旋转矩阵Ⅱ54. 旋转矩阵剑指 Offer 29. 顺时针打印矩阵 二维矩阵按圈遍历(顺时针 or 逆时针)遍历 下面的题目的主要考察点都是,二维数组从左上角开始顺…

C# Linq源码分析之Take (一)

概要 在.Net 6 中引入的Take的另一个重载方法,一个基于Range的重载方法。因为该方法中涉及了很多新的概念,所以在分析源码之前,先将这些概念搞清楚。 Take方法基本介绍 public static System.Collections.Generic.IEnumerable Take (this …

【LeetCode: 2811. 判断是否能拆分数组】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

NavMeshPlus 2D寻路插件

插件地址:h8man/NavMeshPlus: Unity NavMesh 2D Pathfinding (github.com) 我对Unity官方是深恶痛觉,一个2D寻路至今都没想解决,这破引擎早点倒闭算了. 这插件是githun的开源项目,我本身是有写jps寻路的,但是无法解决多个单位互相阻挡的问题(可以解决但是有性能问…

vue3+ts使用antv/x6 + 自定义节点

使用 2.x 版本 x6.antv 新官网: 安装 npm install antv/x6 //"antv/x6": "^2.1.6",项目结构 1、初始化画布 index.vue <template><div id"container"></div> </template><script setup langts> import { onM…

Python爬虫——scrapy_基本使用

安装scrapy pip install scrapy创建scrapy项目&#xff0c;需要在终端里创建 注意&#xff1a;项目的名字开头不能是数字&#xff0c;也不能包含中文 scrapy startproject 项目名称 示例&#xff1a; scrapy startproject scra_baidu_36创建好后的文件 3. 创建爬虫文件&…

MySQL表的操作

文章目录 MySQL表的操作1. 创建表2. 查看表2.1 查看数据库中存在的表2.2 查看表的属性2.3 查看创建时表的详细信息 3. 修改表3.1 向表中添加记录3.2 添加列3.3 修改列的数据类型3.4 删除列3.5 表的重命名3.6 修改列名 4. 删除表 MySQL表的操作 1. 创建表 CREATE TABLE table_…

zabbix监控tomcat

一、zabbix监控Tomcat1.1 zbx-agent配置1.1.1 关闭防火墙&#xff0c;将安装 Tomcat 所需软件包传到/opt目录下1.1.2 安装JDK1.1.3 设置JDK环境变量1.1.4 安装启动Tomcat1.1.5 配置 JMX 1.2 zbx-server配置1.2.1 安装zabbix&#xff08;省略&#xff0c;可看上一篇博客&#xf…

Docker自动化部署安装(十)之安装SonarQube

这里选择的是&#xff1a; sonarqube:9.1.0-community (推荐使用) postgres:9.6.23 数据库(sonarqube7.9及以后便不再支持mysql&#xff0c;版本太低的话里面的一些插件会下载不成功的) 1、docker-sonarqube.yml文件 version: 3 services:sonarqube:container_name: sonar…

Redis详解

Redis 简介 Redis&#xff08;Remote Dictionary Server&#xff09;是一个开源的高性能键值对存储数据库&#xff0c;最初由 Salvatore Sanfilippo 开发&#xff0c;它在内存中存储数据&#xff0c;并提供了持久化功能&#xff0c;可以将数据保存到磁盘中&#xff0c;是一种N…

【论文阅读】DEPCOMM:用于攻击调查的系统审核日志的图摘要(SP-2022)

Xu Z, Fang P, Liu C, et al. Depcomm: Graph summarization on system audit logs for attack investigation[C]//2022 IEEE Symposium on Security and Privacy (SP). IEEE, 2022: 540-557. 1 摘要 ​ 提出了 DEPCOMM&#xff0c;这是一种图摘要方法&#xff0c;通过将大图划…

从0开始搭建ns3环境以及NetAnim简单使用

一、环境准备 ns3是基于GNU/Linux平台使用C开发的工具软件&#xff0c;在windows系统中安装使用ns3环境&#xff0c;可以使用虚拟机VMware并安装ubuntu系统来实现&#xff0c;现将本教程所用到的虚拟机和系统镜像放到网盘提供下载 名称链接提取码VMware Workstation 17 Proht…