深度学习(四):pytorch搭建GAN(对抗网络)

1.GAN

生成对抗网络(GAN)是一种深度学习模型,由两个网络组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成假数据,而判别器则负责判断数据是真实的还是 fake的。这两个网络互相竞争,生成器试图生成更真实的数据以欺骗判别器,而判别器则试图更好地识别生成的数据。
在这里插入图片描述

GAN 的基本思想是:通过训练生成器和判别器,使得生成器能够生成与真实数据非常相似的数据,同时使得判别器能够更有效地识别这些数据。

1.1 概念

  1. 生成器(Generator):生成器是一个神经网络,其目的是生成假的数据,看起来像是真实的。生成器通常包含一些神经网络层,如卷积层、全连接层等。生成器接受随机噪声作为输入,并生成看起来像是真实数据的输出。
  2. 判别器(Discriminator):判别器也是一个神经网络,其目的是识别数据是真实的还是 fake的。判别器通常也包含一些神经网络层,如卷积层、全连接层等。判别器接受输入数据,并输出一个分数,表示输入数据是真实的还是 fake的。
  3. 生成对抗训练:生成对抗训练是指同时训练生成器和判别器。生成器试图生成更真实的数据,以欺骗判别器。判别器则试图更好地识别生成的数据,以避免被欺骗。生成器和判别器之间的竞争导致它们不断改进,以提高生成数据的真实性。
  4. 生成器损失和判别器损失:生成器损失是指生成器试图生成更真实数据的损失。生成器损失通常使用生成器的对抗损失和生成损失之和来计算。判别器损失是指判别器试图更好地识别真实数据和假数据的损失。判别器损失通常使用判别器识别真实数据和假数据的损失之和来计算。
  5. 对抗性训练:对抗性训练是指在训练过程中,使用生成器生成的假数据来训练判别器,以提高判别器的识别能力。同时,使用判别器识别的反馈来训练生成器,以提高生成器生成更真实数据的能力。

1.2 优势

GAN(Generative Adversarial Network)是一种生成对抗网络,主要由生成器和判别器组成。生成器负责生成假数据,而判别器负责判断数据是真实的还是 fake的。GAN 的训练过程相对复杂,但是它可以生成非常真实的数据,并且可以用来进行数据增强、图像生成、视频生成等应用。

GAN 的优势主要体现在以下几个方面:

  1. 生成数据非常真实:GAN 可以生成非常真实的数据,可以用来进行数据增强、图像生成、视频生成等应用。
  2. 可以生成大量数据:GAN 可以生成大量的数据,可以用来进行机器学习、深度学习等应用。
  3. 可以生成不同类型的数据:GAN 可以生成不同类型的数据,可以用来进行图像生成、视频生成等应用。
  4. 可以进行对抗训练:GAN 可以进行对抗训练,可以提高模型的鲁棒性和泛化能力。

虽然 GAN 具有优势,但是也存在一些挑战,例如训练过程复杂、生成器容易过拟合、对抗训练难以实现等。因此,在实际应用中,需要根据具体情况进行优化和调整。

1.3 训练技巧

  1. 使用批归一化(Batch Normalization):批归一化是一种在卷积神经网络中常用的加速训练和提高模型性能的方法。在 GAN 的生成器和判别器中可以使用批归一化来提高性能。
  2. 使用 Leaky ReLU 激活函数:Leaky ReLU 激活函数是一种在 ReLU 激活函数中加入一个小于 1 的常数,以避免神经元死亡的方法。在 GAN 的生成器和判别器中可以使用 Leaky ReLU 激活函数来提高性能。
  3. 使用 U-Net 结构:U-Net 是一种用于图像分割的网络结构,其结构可以同时实现编码器和解码器。在 GAN 的生成器中可以使用 U-Net 结构来提高生成图像的质量。
  4. 使用对抗性损失(Adversarial Loss):对抗性损失是一种可以增加生成器损失的方法,通过在损失函数中加入一个与真实数据接近的噪声来增加生成器的难度。在 GAN 的训练过程中可以使用对抗性损失来提高性能。
  5. 使用预训练模型:预训练模型是一种在已有数据集上训练好的模型,可以用于迁移学习和提高性能。在 GAN 的生成器和判别器中可以使用预训练模型来提高性能。
  6. 使用注意力机制(Attention):注意力机制是一种可以提高模型性能和泛化能力的方法,可以在 GAN 的生成器和判别器中使用注意力机制来提高性能。

总结起来,GAN 的训练过程需要综合考虑多个方面,包括数据预处理、损失函数选择、正则化、梯度裁剪、对抗性训练、数据增强和 early stopping 等技巧。同时,还可以使用一些额外的技巧,如批归一化、Leaky ReLU 激活函数、U-Net 结构、对抗性损失、预训练模型和注意力机制等来进一步提高 GAN 的性能。

2 代码实现

步骤:

  1. 导入所需的库和模块。
  2. 定义生成器的网络结构,包括全连接层和激活函数。
  3. 定义判别器的网络结构,也包括全连接层和激活函数。
  4. 定义训练函数,包括将模型移动到设备、定义损失函数和优化器、开始训练的循环等。
  5. 设置随机种子。
  6. 设置设备,如果有可用的GPU则使用GPU,否则使用CPU。
  7. 加载MNIST数据集,并进行数据预处理。
  8. 初始化生成器和判别器。
  9. 设置训练的参数,如训练轮数、生成器的输入维度等。
  10. 调用训练函数进行训练。
# 导入torch模块
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 定义生成器的网络结构
class Generator(nn.Module):def __init__(self, latent_dim):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),  # 全连接层,输入latent_dim维,输出256维nn.LeakyReLU(0.2),  # LeakyReLU激活函数nn.Linear(256, 512),  # 全连接层,输入256维,输出512维nn.LeakyReLU(0.2),nn.Linear(512, 1024),  # 全连接层,输入512维,输出1024维nn.LeakyReLU(0.2),nn.Linear(1024, 784),  # 全连接层,输入1024维,输出784维nn.Tanh()  # Tanh激活函数)def forward(self, x):return self.model(x)# 定义判别器的网络结构
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(784, 512),  # 全连接层,输入784维,输出512维nn.LeakyReLU(0.2),nn.Linear(512, 256),  # 全连接层,输入512维,输出256维nn.LeakyReLU(0.2),nn.Linear(256, 1),  # 全连接层,输入256维,输出1维nn.Sigmoid()  # Sigmoid激活函数)def forward(self, x):return self.model(x)# 定义训练函数
def train(generator, discriminator, dataloader, num_epochs, latent_dim, device):# 将模型移动到设备generator.to(device)discriminator.to(device)# 定义损失函数和优化器criterion = nn.BCELoss()  # 二分类交叉熵损失函数optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 生成器的优化器optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 判别器的优化器# 开始训练for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 将图像转换为向量real_images = real_images.view(-1, 784).to(device)# 获取图像的batch_sizebatch_size = real_images.size(0)# 定义真实标签和 fake标签real_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# 训练判别器optimizer_D.zero_grad()# 计算真实图像的输出real_outputs = discriminator(real_images)# 计算真实图像的损失real_loss = criterion(real_outputs, real_labels)# 生成假图像z = torch.randn(batch_size, latent_dim).to(device)fake_images = generator(z)# 计算假图像的输出fake_outputs = discriminator(fake_images.detach())# 计算假图像的损失fake_loss = criterion(fake_outputs, fake_labels)# 计算判别器的损失d_loss = real_loss + fake_loss# 反向传播d_loss.backward()# 更新参数optimizer_D.step()# 训练生成器optimizer_G.zero_grad()# 计算假图像的输出fake_outputs = discriminator(fake_images)# 计算生成器的损失g_loss = criterion(fake_outputs, real_labels)# 反向传播g_loss.backward()# 更新参数optimizer_G.step()# 每200步打印一次损失if (i+1) % 200 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")# 每1步打印一次图像if (epoch+1) % 1 == 0:# 生成图像with torch.no_grad():z = torch.randn(10, 100).to(device)generated_images = generator(z).cpu().view(-1, 28, 28)# 展示原始数据和生成数据的图像fig, axes = plt.subplots(2, 5, figsize=(10, 4))for i, ax in enumerate(axes.flat):if i < 5:ax.imshow(real_images[i].view(28, 28), cmap='gray')ax.set_title('Real')else:ax.imshow(generated_images[i-5], cmap='gray')ax.set_title('Generated')ax.axis('off')plt.tight_layout()plt.show()# 设置随机种子
torch.manual_seed(42)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化生成器和判别器
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()# 训练GAN模型
num_epochs = 50
train(generator, discriminator, train_dataloader, num_epochs, latent_dim, device)

2.1结果

第一轮:

在这里插入图片描述
训练之后:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/191555.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决Linux的端口占用报错问题

文章目录 1 Linux报错2 解决方式 1 Linux报错 Port 6006 is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all(). 想起之前运行Gradio 6006&#xff0c;端口被占用 2 解决方式 输入 netstat -tpl查看当前一些端口号的占用号&a…

go第三方包发布(短精细)

1、清除其他依赖项 $ go mod tidy # 清除不必要的依赖依赖清除完成后&#xff0c;查看go.mod文件配置是否规范 module github.com/fyupeng/rpc-go-netty go 1.19 require ( )2、本地版本创建 $ git tag v0.1.0 # 本地 创建标签3、版本提交 $ git push github v0.1.0 # 推送…

Selector SelectionKey基础学习

netty技术内幕一(Selector,SelectionKey) Java Nio注意事项 # selector Selector类的使用(一) SelectionKey类的使用 /* package java.nio.channels;import java.io.Closeable; import java.io.IOException; import java.nio.channels.spi.SelectorProvider; import java.u…

面试就是这么简单,offer拿到手软(一)—— 常见非技术问题回答思路

面试系列&#xff1a; 面试就是这么简单&#xff0c;offer拿到手软&#xff08;一&#xff09;—— 常见非技术问题回答思路 面试就是这么简单&#xff0c;offer拿到手软&#xff08;二&#xff09;—— 常见65道非技术面试问题 文章目录 一、前言二、常见面试问题回答思路问…

cyclictest 交叉编译与使用

目录 使用版本问题编译 numactl编译 cyclictest使用参考 cyclictest 主要是用于测试系统延时&#xff0c;进而判断系统的实时性 使用版本 rt-tests-2.6.tar.gz numactl v2.0.16 问题 编译时&#xff0c;需要先编译 numactl &#xff0c;不然会有以下报错&#xff1a; arm-…

在 Linux 中使用 udev 规则固定摄像头节点

简介 通过编写 udev 规则来固定 USB 摄像头节点&#xff0c;以便在系统中始终使用相同的设备路径访问摄像头。 确定摄像头的供应商 ID 和产品 ID 使用 lsusb 命令确定连接的 USB 摄像头的供应商 ID 和产品 ID。示例命令及输出&#xff1a; $ lsusb Bus 001 Device 030: ID 220…

AI 绘画 | Stable Diffusion 电商模特

前言 上一篇文章讲到如何给人物更换背景和服装。今天主要讲电商模特,就是服装电商们的固定服装产品图片如何变成真人模特穿上的固定服装产品效果图。如果你掌握了 《AI 绘画 | Stable Diffusion 人物 换背景|换服装》,这篇文章对你来说,上手会更轻松。 教程 提取服装蒙版…

Java实现简单飞翔小鸟游戏

一、创建新项目 首先创建一个新的项目&#xff0c;并命名为飞翔的鸟。 其次在飞翔的鸟项目下创建一个名为images的文件夹用来存放游戏相关图片。 用到的图片如下&#xff1a;0~7&#xff1a; bg&#xff1a; column&#xff1a; gameover&#xff1a; ground&#xff1a; st…

记QListWidget中QPushButton QSS样式失效的“bug”

一、场景 有一个QListWidget的列表&#xff1b;里面存放了若干QListWidgetItem&#xff1b;每个QListWidgetItem与一个自定义类对象绑定——通过QListWidget的setItemWidget()实现。自定义对象继承于QWidget&#xff0c;且内含QPushButton。 二、bug描述 在该QListWidget的外…

Mybatis 分页查询的三种实现

Mybatis 分页查询 1. 直接在 sql 中使用 limit2. 使用 RowBounds3. 使用 Mybatis 提供的拦截器机制3.1 创建一个自定义拦截器类实现 Interceptor3.2 创建分页查询函数 与 sql3.3 编写拦截逻辑3.4 注册 PageInterceptor 到 Mybatis 拦截器链中3.5 测试 准备一个分页查询类 Data…

Clion调试QTQString看不到值问题处理

环境 Clion &#xff1a;2019.3.6 Qt &#xff1a;5.9.6&#xff08;MinGW&#xff09; 环境搭建参考&#xff1a;https://blog.csdn.net/qq_27953479/article/details/132338745 调试时QString看不到值问题处理 下载文件 qt.py : https://github.com/KDE/kdevelop/blob/…

CIS|安森美微光近红外增强相机论文解析

引言 在之前的文章中&#xff0c;我们介绍了索尼、安森美以及三星等Sensor厂家在车载领域中的技术论文&#xff0c;分析了各个厂家不同的技术路线、Sensor架构以及差异点。今天&#xff0c;笔者借豪威科技在移动端200Mega Pixels产品的技术论文&#xff0c;讲解消费级CIS传感器…

高级软件工程15本书籍

如果您想学习软件工程技能并提高您的专业知识&#xff0c;那么这里是您的最佳选择。我们有一本很棒的书&#xff0c;可以极大地增强您在软件工程方面的知识。 1&#xff09;干净的代码 Robert C. Martin 写了一本名为“干净代码&#xff1a;敏捷软件工艺手册”的书。在本书中&…

如何在WordPress中批量替换图片路径?

很多站长在使用WordPress博客或者搬家时&#xff0c;需要把WordPress文章中的图片路径进行替换来解决图片不显示的问题。总结一下WordPress图片路径批量替换的过程&#xff0c;方便有此类需求的站长们学习。 什么情况下批量替换图片路径 1、更换了网站域名 有许多网站建设初期…

面试 Java 基础八股文十问十答第二期

面试 Java 基础八股文十问十答第二期 作者&#xff1a;程序员小白条 ⭐点赞⭐收藏⭐不迷路&#xff01;⭐ 11.什么是反射&#xff1f;反射有哪些作用&#xff1f;反射在Sping中的体现 (1): 什么是反射? 反射可以在运行时获取到一个类的所有信息&#xff0c;包括(成员变量&am…

关于qiankun沙箱sandbox(面试题)

为什么要有js资源隔离机制&#xff1f; 主应用和子应用&#xff0c;相同的全局变量&#xff0c;可能会发生冲突&#xff0c;子应用和子应用之间&#xff0c;相同的全局变量&#xff0c;也可能会发生冲突。在这里我们主要指的就是window。 思路&#xff1a;打开沙箱时能够修改…

Spring中@Transactional注解

在Spring框架中&#xff0c;Transactional 是一个注解&#xff0c;用于声明事务性的方法。这个注解可以被应用在方法级别或类级别上。它提供了一种声明式的事务管理方式&#xff0c;避免了在代码中直接编写事务管理相关的代码。Transactional 注解能够将一个方法纳入到一个事务…

基于SSM的生鲜在线销售系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

亚马逊云科技推出新一代自研芯片

北京——2023 年12月1日 亚马逊云科技在2023 re:Invent全球大会上宣布其自研芯片家族的两个系列推出新一代&#xff0c;包括Amazon Graviton4和Amazon Trainium2&#xff0c;为机器学习&#xff08;ML&#xff09;训练和生成式人工智能&#xff08;AI&#xff09;应用等广泛的工…

Linux: 退出vim编辑模式

一、使用快捷键进行退出 1、按“Esc”键进入命令模式 当我们在vim编辑模式下输入完毕需要进行退出操作时&#xff0c;首先需要按下“Esc”键&#xff0c;将vim编辑器从插入模式或者替换模式切换到命令模式。 ESC 2、输入“:wq”保存并退出 在命令模式下&#xff0c;输入“:…