PyTorch训练简单的生成对抗网络GAN

文章目录

    • 原理
    • 代码
    • 结果
    • 参考

原理

同时训练两个网络:辨别器Discriminator 和 生成器Generator
Generator是 造假者,用来生成假数据。
Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。

目的:使得判别模型尽量犯错,无法判断数据是来自真实数据还是生成出来的数据。

GAN的梯度下降训练过程:

在这里插入图片描述
上图来源:https://arxiv.org/abs/1406.2661

Train 辨别器: m a x max max l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x)) + log(1 - D(G(z))) log(D(x))+log(1D(G(z)))

Train 生成器: m i n min min l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))

我们可以使用BCEloss来计算上述两个损失函数

BCEloss的表达式: m i n − [ y l n x + ( 1 − y ) l n ( 1 − x ) ] min -[ylnx + (1-y)ln(1-x)] min[ylnx+(1y)ln(1x)]
具体过程参加代码中注释

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboardclass Discriminator(nn.Module):def __init__(self, img_dim):super(Discriminator, self).__init__()self.disc = nn.Sequential(nn.Linear(img_dim, 128),nn.LeakyReLU(0.1),nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, img_dim): # z_dim 噪声的维度super(Generator, self).__init__()self.gen = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.1),nn.Linear(256, img_dim), # 28x28 -> 784nn.Tanh(),)def forward(self, x):return self.gen(x)# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4 # 3e-4是Adam最好的学习率
z_dim = 64 # 噪声维度
img_dim = 784 # 28x28x1
batch_size = 32
num_epochs = 50disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( # MNIST标准化系数:(0.1307,), (0.3081,)[transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081,))] # 不同数据集就有不同的标准化系数
)dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCE 损失
criterion = nn.BCELoss()# 打开tensorboard:在该目录下,使用 tensorboard --logdir=runs
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0for epoch in range(num_epochs):for batch_idx, (real, _) in enumerate(loader):real = real.view(-1, 784).to(device) # view相当于reshapebatch_size = real.shape[0]### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))noise = torch.randn(batch_size, z_dim).to(device)fake = gen(noise) # G(z)disc_real = disc(real).view(-1) # flatten# BCEloss的表达式:min -[ylnx + (1-y)ln(1-x)]# max log(D(real)) 相当于 min -log(D(real))# ones_like:1填充得到y=1, 即可忽略  min -[ylnx + (1-y)ln(1-x)]中的后一项# 得到 min -lnx,这里的x就是我们的real图片lossD_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake).view(-1)# max log(1 - D(G(z))) 相当于 min -log(1 - D(G(z)))# zeros_like用0填充,得到y=0,即可忽略  min -[ylnx + (1-y)ln(1-x)]中的前一项# 得到 min -ln(1-x),这里的x就是我们的fake噪声lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))lossD = (lossD_real + lossD_fake) / 2disc.zero_grad()lossD.backward(retain_graph=True)opt_disc.step()### Train Generator: min log(1-D(G(z))) <--> max log(D(G(z))) <--> min - log(D(G(z)))# 依然可使用BCEloss来做output = disc(fake).view(-1)lossG = criterion(output, torch.ones_like(output))gen.zero_grad()lossG.backward()opt_gen.step()if batch_idx == 0:print(f"Epoch [{epoch}/{num_epochs}] \ "f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")with torch.no_grad():fake = gen(fixed_noise).reshape(-1, 1, 28, 28)data = real.reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)img_grid_real = torchvision.utils.make_grid(data, normalize=True)writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)step += 1

结果

训练50轮的的损失

Epoch [0/50] \ Loss D: 0.7366, Loss G: 0.7051
Epoch [1/50] \ Loss D: 0.2483, Loss G: 1.6877
Epoch [2/50] \ Loss D: 0.1049, Loss G: 2.4980
Epoch [3/50] \ Loss D: 0.1159, Loss G: 3.4923
Epoch [4/50] \ Loss D: 0.0400, Loss G: 3.8776
Epoch [5/50] \ Loss D: 0.0450, Loss G: 4.1703
...
Epoch [43/50] \ Loss D: 0.0022, Loss G: 7.7446
Epoch [44/50] \ Loss D: 0.0007, Loss G: 9.1281
Epoch [45/50] \ Loss D: 0.0138, Loss G: 6.2177
Epoch [46/50] \ Loss D: 0.0008, Loss G: 9.1188
Epoch [47/50] \ Loss D: 0.0025, Loss G: 8.9419
Epoch [48/50] \ Loss D: 0.0010, Loss G: 8.3315
Epoch [49/50] \ Loss D: 0.0007, Loss G: 7.8302

使用

tensorboard --logdir=runs

打开tensorboard:

在这里插入图片描述
可以看到效果并不好,这是由于我们只是采用了简单的线性网络来做辨别器和生成器。后面的博文我们会使用更复杂的网络来训练GAN。

参考

[1] Building our first simple GAN
[2] https://arxiv.org/abs/1406.2661

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/51163.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[LeetCode周赛复盘] 第 111 场双周赛20230819

[LeetCode周赛复盘] 第 111 场双周赛20230819 一、本周周赛总结2824. 统计和小于目标的下标对数目1. 题目描述2. 思路分析3. 代码实现 2825. 循环增长使字符串子序列等于另一个字符串1. 题目描述2. 思路分析3. 代码实现 2826. 将三个组排序1. 题目描述2. 思路分析3. 代码实现 …

js、PHP连接外卖小票机打印机方案(调用佳博、芯烨等)

前言&#xff1a; 目前开发需要用到电脑直接连接外卖小票机打印小票&#xff0c;查阅各种资料&#xff0c;使用 6612345浏览器 终于解决了这个问题。 效果&#xff1a; PHP、js直接连接小票机并且自动出票。 支持的小票机&#xff1a; 目前测试可以的有&#xff1a;电脑A4打印…

SQL注入读写文件

文章目录 条件利用SQL注入漏洞读取hosts文件查看文件读写权限安全选项允许导入导出读取hosts文件 利用SQL注入漏洞写入一句话木马&#xff0c;并用蚁剑连接webshell写入文件 条件 SQL注入有直接SQL注入&#xff0c;也有文件读写时的注入&#xff0c;后者的主要 目的在于获取web…

回归预测 | MATLAB实现PSO-RF粒子群优化算法优化随机森林算法多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现PSO-RF粒子群优化算法优化随机森林算法多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现PSO-RF粒子群优化算法优化随机森林算法多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09;效果…

嵌入式Linux开发实操(十一):ETH网络接口开发

# 前言 嵌入式linux也有些是支持网口的,比如RGMII,嵌入式系统资源支持以太网和其他基本接口的硬件平台(板上或片上系统),有充足的NOR或NAND Flash闪存,用于容纳OS、lib库、fileSystem文件系统、APP应用程序、Bootloader引导程序等。嵌入式Linux是开源的、可修改的,并且…

个人微信AI聊天机器人

个人微信AI聊天机器人 微信AI机器人介绍产品介绍联系本人微信&#xff1a;yao_you_meng_xiang代码地址&#xff1a;https://github.com/xshxsh/weChatAiRobot 前期准备个人微信号Windows电脑注册AI模型账号 搭建使用注册AI账号注册讯飞账号创建应用申请API使用 安装微信 安装代…

【网络安全】防火墙知识点全面图解(三)

本系列文章包含&#xff1a; 【网络安全】防火墙知识点全面图解&#xff08;一&#xff09;【网络安全】防火墙知识点全面图解&#xff08;二&#xff09;【网络安全】防火墙知识点全面图解&#xff08;三&#xff09; 防火墙知识点全面图解&#xff08;三&#xff09; 39、什…

解决idea登录github copilot报错问题

试了好多方案都没用&#xff0c;但是这个有用&#xff0c; 打开idea-help-edit custonm vm options 然后在这个文件里面输入 -Dcopilot.agent.disabledtrue再打开 https://github.com/settings/copilot 把这个设置成allow&#xff0c;然后重新尝试登录copilot就行就行 解决方…

nginx代理请求到内网不同服务器

需求&#xff1a;之前用的是frp做的内网穿透&#xff0c;但是每次电脑断电重启&#xff0c;路由或者端口会冲突&#xff0c;现在使用汉土云盒替换frp。 需要把公网ip映射到任意一台内网服务器上&#xff0c;然后在这台内网服务器上用Nginx做代理即可访问内网其它服务器&#xf…

mysql使用flashback恢复数据

常在河边走&#xff0c;哪有不湿鞋。如果我们经常操作数据库&#xff0c;很有可能就会造成误操作&#xff0c;假如我们不幸误删了数据&#xff0c;有没有办法快速恢复呢&#xff1f; 这里&#xff0c;我们就以用的最多的mysql举例&#xff0c;聊聊如何快速恢复数据。mysql官方貌…

springboot里 运用 easyexcel 导出

引入pom <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>2.2.6</version> </dependency>运用 import com.alibaba.excel.EasyExcel; import org.springframework.stereotype.Contr…

YOLOv7-tracker 目标追踪 输入视频帧

参考项目&#xff1a;https://github.com/JackWoo0831/Yolov7-tracker/tree/master github链接&#xff1a;https://github.com/Whiffe/Yolov7-tracker 码云链接&#xff1a;https://gitee.com/YFwinston/Yolov7-tracker 1 项目安装 1.1 环境搭建 平台&#xff1a;AutoDL 选…

Linux TCP协议——三次握手,四次挥手

一、TCP协议介绍 TCP协议是可靠的、面向连接的、基于字节流的传输层通信协议。 TCP的头部结构&#xff1a; 源/目的端口号: 表示数据是从哪个进程来, 到哪个进程去;&#xff08;tcp是传输层的协议&#xff0c;端与端之间的数据传输&#xff0c;在TCP和UDP协议当中不会体现出I…

【Linux】一张图了解系统文件

首先先认识磁盘结构 系统文件分布图 文件查找 文件删除 文件的增删改查都是围绕inode来完成的&#xff0c;所以当我们要进行文件删除的时候&#xff0c;只需要通过inode来获取到它对应的block bitmap和inode bitmap数据块容器和保存文件属性的位置置为 0即可 &#xff0c;如果想…

【boost网络库从青铜到王者】第六篇:asio网络编程中的socket异步读(接收)写(发送)

文章目录 1、简介2、异步写 void AsyncWriteSomeToSocketErr(const std::string& buffer)3、异步写void AsyncWriteSomeToSocket(const std::string& buffer)4、异步写void AsyncSendToSocket(const std::string& buffer)5、异步读void AsyncReadSomeToSocket(cons…

一文看懂 iova、IOMMU、DMA

目录 一、概念解释 二、深入浅出 三、应用 四、常见问题 一、概念解释 IOVA&#xff08;IO Virtual Address&#xff0c;输入/输出虚拟地址&#xff09; IOMMU&#xff08;I/O Memory Management Unit&#xff09;&#xff1a;IOMMU是一种硬件单元&#xff0c;用于管理设备…

springboot sl4j2 写入日志到mysql

问题描述 springboot初始化的时候&#xff0c;会先初始化日志然后再加载数据源如果用配置文件进行初始化&#xff0c;那么会出现数据源没有加载成功&#xff0c;导致空指针异常 报错排查如下&#xff1a; 搜索报错信息&#xff0c;OBjects.invoke is Null打断点发现。dataso…

前端基础踩坑记录

前言&#xff1a;在做vue项目时&#xff0c;有时代码没有报错&#xff0c;但运行时却各种问题&#xff0c;没有报错排查起来就很费劲&#xff0c;本人感悟&#xff1a;写前端&#xff0c;需要好的眼神&#xff01;&#xff01;&#xff01;谨以此博客记录下自己的踩坑点。 一、…

【Maven教程】(三)基础使用篇:入门使用指南——POM编写、业务代码、测试代码、打包与运行、使用Archetype生成项目骨架~

Maven基础使用篇 1️⃣ 编写 POM2️⃣ 编写业务代码3️⃣ 编写测试代码4️⃣ 打包和运行5️⃣ 使用 Archetype生成项目骨架 1️⃣ 编写 POM 到目前为止&#xff0c;已经大概了解并安装好了Maven环境, 现在&#xff0c;我们开始创建一个最简单的 Hello World 项目。如果你是初次…

IDEA下SpringBoot指定环境、配置文件启动

1、idea下的SpringBoot启动&#xff1a;指定配置文件 Springboot项目有如下配置文件 主配置文件application.yml&#xff0c; 测试环境&#xff1a;application-test.yml 生产环境&#xff1a;application-pro.yml 开发环境&#xff1a;application-dev.yml 1.1.配置文件…