深度学习(四):pytorch搭建GAN(对抗网络)

1.GAN

生成对抗网络(GAN)是一种深度学习模型,由两个网络组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成假数据,而判别器则负责判断数据是真实的还是 fake的。这两个网络互相竞争,生成器试图生成更真实的数据以欺骗判别器,而判别器则试图更好地识别生成的数据。
在这里插入图片描述

GAN 的基本思想是:通过训练生成器和判别器,使得生成器能够生成与真实数据非常相似的数据,同时使得判别器能够更有效地识别这些数据。

1.1 概念

  1. 生成器(Generator):生成器是一个神经网络,其目的是生成假的数据,看起来像是真实的。生成器通常包含一些神经网络层,如卷积层、全连接层等。生成器接受随机噪声作为输入,并生成看起来像是真实数据的输出。
  2. 判别器(Discriminator):判别器也是一个神经网络,其目的是识别数据是真实的还是 fake的。判别器通常也包含一些神经网络层,如卷积层、全连接层等。判别器接受输入数据,并输出一个分数,表示输入数据是真实的还是 fake的。
  3. 生成对抗训练:生成对抗训练是指同时训练生成器和判别器。生成器试图生成更真实的数据,以欺骗判别器。判别器则试图更好地识别生成的数据,以避免被欺骗。生成器和判别器之间的竞争导致它们不断改进,以提高生成数据的真实性。
  4. 生成器损失和判别器损失:生成器损失是指生成器试图生成更真实数据的损失。生成器损失通常使用生成器的对抗损失和生成损失之和来计算。判别器损失是指判别器试图更好地识别真实数据和假数据的损失。判别器损失通常使用判别器识别真实数据和假数据的损失之和来计算。
  5. 对抗性训练:对抗性训练是指在训练过程中,使用生成器生成的假数据来训练判别器,以提高判别器的识别能力。同时,使用判别器识别的反馈来训练生成器,以提高生成器生成更真实数据的能力。

1.2 优势

GAN(Generative Adversarial Network)是一种生成对抗网络,主要由生成器和判别器组成。生成器负责生成假数据,而判别器负责判断数据是真实的还是 fake的。GAN 的训练过程相对复杂,但是它可以生成非常真实的数据,并且可以用来进行数据增强、图像生成、视频生成等应用。

GAN 的优势主要体现在以下几个方面:

  1. 生成数据非常真实:GAN 可以生成非常真实的数据,可以用来进行数据增强、图像生成、视频生成等应用。
  2. 可以生成大量数据:GAN 可以生成大量的数据,可以用来进行机器学习、深度学习等应用。
  3. 可以生成不同类型的数据:GAN 可以生成不同类型的数据,可以用来进行图像生成、视频生成等应用。
  4. 可以进行对抗训练:GAN 可以进行对抗训练,可以提高模型的鲁棒性和泛化能力。

虽然 GAN 具有优势,但是也存在一些挑战,例如训练过程复杂、生成器容易过拟合、对抗训练难以实现等。因此,在实际应用中,需要根据具体情况进行优化和调整。

1.3 训练技巧

  1. 使用批归一化(Batch Normalization):批归一化是一种在卷积神经网络中常用的加速训练和提高模型性能的方法。在 GAN 的生成器和判别器中可以使用批归一化来提高性能。
  2. 使用 Leaky ReLU 激活函数:Leaky ReLU 激活函数是一种在 ReLU 激活函数中加入一个小于 1 的常数,以避免神经元死亡的方法。在 GAN 的生成器和判别器中可以使用 Leaky ReLU 激活函数来提高性能。
  3. 使用 U-Net 结构:U-Net 是一种用于图像分割的网络结构,其结构可以同时实现编码器和解码器。在 GAN 的生成器中可以使用 U-Net 结构来提高生成图像的质量。
  4. 使用对抗性损失(Adversarial Loss):对抗性损失是一种可以增加生成器损失的方法,通过在损失函数中加入一个与真实数据接近的噪声来增加生成器的难度。在 GAN 的训练过程中可以使用对抗性损失来提高性能。
  5. 使用预训练模型:预训练模型是一种在已有数据集上训练好的模型,可以用于迁移学习和提高性能。在 GAN 的生成器和判别器中可以使用预训练模型来提高性能。
  6. 使用注意力机制(Attention):注意力机制是一种可以提高模型性能和泛化能力的方法,可以在 GAN 的生成器和判别器中使用注意力机制来提高性能。

总结起来,GAN 的训练过程需要综合考虑多个方面,包括数据预处理、损失函数选择、正则化、梯度裁剪、对抗性训练、数据增强和 early stopping 等技巧。同时,还可以使用一些额外的技巧,如批归一化、Leaky ReLU 激活函数、U-Net 结构、对抗性损失、预训练模型和注意力机制等来进一步提高 GAN 的性能。

2 代码实现

步骤:

  1. 导入所需的库和模块。
  2. 定义生成器的网络结构,包括全连接层和激活函数。
  3. 定义判别器的网络结构,也包括全连接层和激活函数。
  4. 定义训练函数,包括将模型移动到设备、定义损失函数和优化器、开始训练的循环等。
  5. 设置随机种子。
  6. 设置设备,如果有可用的GPU则使用GPU,否则使用CPU。
  7. 加载MNIST数据集,并进行数据预处理。
  8. 初始化生成器和判别器。
  9. 设置训练的参数,如训练轮数、生成器的输入维度等。
  10. 调用训练函数进行训练。
# 导入torch模块
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 定义生成器的网络结构
class Generator(nn.Module):def __init__(self, latent_dim):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),  # 全连接层,输入latent_dim维,输出256维nn.LeakyReLU(0.2),  # LeakyReLU激活函数nn.Linear(256, 512),  # 全连接层,输入256维,输出512维nn.LeakyReLU(0.2),nn.Linear(512, 1024),  # 全连接层,输入512维,输出1024维nn.LeakyReLU(0.2),nn.Linear(1024, 784),  # 全连接层,输入1024维,输出784维nn.Tanh()  # Tanh激活函数)def forward(self, x):return self.model(x)# 定义判别器的网络结构
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(784, 512),  # 全连接层,输入784维,输出512维nn.LeakyReLU(0.2),nn.Linear(512, 256),  # 全连接层,输入512维,输出256维nn.LeakyReLU(0.2),nn.Linear(256, 1),  # 全连接层,输入256维,输出1维nn.Sigmoid()  # Sigmoid激活函数)def forward(self, x):return self.model(x)# 定义训练函数
def train(generator, discriminator, dataloader, num_epochs, latent_dim, device):# 将模型移动到设备generator.to(device)discriminator.to(device)# 定义损失函数和优化器criterion = nn.BCELoss()  # 二分类交叉熵损失函数optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 生成器的优化器optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 判别器的优化器# 开始训练for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 将图像转换为向量real_images = real_images.view(-1, 784).to(device)# 获取图像的batch_sizebatch_size = real_images.size(0)# 定义真实标签和 fake标签real_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# 训练判别器optimizer_D.zero_grad()# 计算真实图像的输出real_outputs = discriminator(real_images)# 计算真实图像的损失real_loss = criterion(real_outputs, real_labels)# 生成假图像z = torch.randn(batch_size, latent_dim).to(device)fake_images = generator(z)# 计算假图像的输出fake_outputs = discriminator(fake_images.detach())# 计算假图像的损失fake_loss = criterion(fake_outputs, fake_labels)# 计算判别器的损失d_loss = real_loss + fake_loss# 反向传播d_loss.backward()# 更新参数optimizer_D.step()# 训练生成器optimizer_G.zero_grad()# 计算假图像的输出fake_outputs = discriminator(fake_images)# 计算生成器的损失g_loss = criterion(fake_outputs, real_labels)# 反向传播g_loss.backward()# 更新参数optimizer_G.step()# 每200步打印一次损失if (i+1) % 200 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")# 每1步打印一次图像if (epoch+1) % 1 == 0:# 生成图像with torch.no_grad():z = torch.randn(10, 100).to(device)generated_images = generator(z).cpu().view(-1, 28, 28)# 展示原始数据和生成数据的图像fig, axes = plt.subplots(2, 5, figsize=(10, 4))for i, ax in enumerate(axes.flat):if i < 5:ax.imshow(real_images[i].view(28, 28), cmap='gray')ax.set_title('Real')else:ax.imshow(generated_images[i-5], cmap='gray')ax.set_title('Generated')ax.axis('off')plt.tight_layout()plt.show()# 设置随机种子
torch.manual_seed(42)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化生成器和判别器
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()# 训练GAN模型
num_epochs = 50
train(generator, discriminator, train_dataloader, num_epochs, latent_dim, device)

2.1结果

第一轮:

在这里插入图片描述
训练之后:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/191555.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决Linux的端口占用报错问题

文章目录 1 Linux报错2 解决方式 1 Linux报错 Port 6006 is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all(). 想起之前运行Gradio 6006&#xff0c;端口被占用 2 解决方式 输入 netstat -tpl查看当前一些端口号的占用号&a…

go第三方包发布(短精细)

1、清除其他依赖项 $ go mod tidy # 清除不必要的依赖依赖清除完成后&#xff0c;查看go.mod文件配置是否规范 module github.com/fyupeng/rpc-go-netty go 1.19 require ( )2、本地版本创建 $ git tag v0.1.0 # 本地 创建标签3、版本提交 $ git push github v0.1.0 # 推送…

面试就是这么简单,offer拿到手软(一)—— 常见非技术问题回答思路

面试系列&#xff1a; 面试就是这么简单&#xff0c;offer拿到手软&#xff08;一&#xff09;—— 常见非技术问题回答思路 面试就是这么简单&#xff0c;offer拿到手软&#xff08;二&#xff09;—— 常见65道非技术面试问题 文章目录 一、前言二、常见面试问题回答思路问…

cyclictest 交叉编译与使用

目录 使用版本问题编译 numactl编译 cyclictest使用参考 cyclictest 主要是用于测试系统延时&#xff0c;进而判断系统的实时性 使用版本 rt-tests-2.6.tar.gz numactl v2.0.16 问题 编译时&#xff0c;需要先编译 numactl &#xff0c;不然会有以下报错&#xff1a; arm-…

AI 绘画 | Stable Diffusion 电商模特

前言 上一篇文章讲到如何给人物更换背景和服装。今天主要讲电商模特,就是服装电商们的固定服装产品图片如何变成真人模特穿上的固定服装产品效果图。如果你掌握了 《AI 绘画 | Stable Diffusion 人物 换背景|换服装》,这篇文章对你来说,上手会更轻松。 教程 提取服装蒙版…

Java实现简单飞翔小鸟游戏

一、创建新项目 首先创建一个新的项目&#xff0c;并命名为飞翔的鸟。 其次在飞翔的鸟项目下创建一个名为images的文件夹用来存放游戏相关图片。 用到的图片如下&#xff1a;0~7&#xff1a; bg&#xff1a; column&#xff1a; gameover&#xff1a; ground&#xff1a; st…

Mybatis 分页查询的三种实现

Mybatis 分页查询 1. 直接在 sql 中使用 limit2. 使用 RowBounds3. 使用 Mybatis 提供的拦截器机制3.1 创建一个自定义拦截器类实现 Interceptor3.2 创建分页查询函数 与 sql3.3 编写拦截逻辑3.4 注册 PageInterceptor 到 Mybatis 拦截器链中3.5 测试 准备一个分页查询类 Data…

Clion调试QTQString看不到值问题处理

环境 Clion &#xff1a;2019.3.6 Qt &#xff1a;5.9.6&#xff08;MinGW&#xff09; 环境搭建参考&#xff1a;https://blog.csdn.net/qq_27953479/article/details/132338745 调试时QString看不到值问题处理 下载文件 qt.py : https://github.com/KDE/kdevelop/blob/…

CIS|安森美微光近红外增强相机论文解析

引言 在之前的文章中&#xff0c;我们介绍了索尼、安森美以及三星等Sensor厂家在车载领域中的技术论文&#xff0c;分析了各个厂家不同的技术路线、Sensor架构以及差异点。今天&#xff0c;笔者借豪威科技在移动端200Mega Pixels产品的技术论文&#xff0c;讲解消费级CIS传感器…

如何在WordPress中批量替换图片路径?

很多站长在使用WordPress博客或者搬家时&#xff0c;需要把WordPress文章中的图片路径进行替换来解决图片不显示的问题。总结一下WordPress图片路径批量替换的过程&#xff0c;方便有此类需求的站长们学习。 什么情况下批量替换图片路径 1、更换了网站域名 有许多网站建设初期…

基于SSM的生鲜在线销售系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

亚马逊云科技推出新一代自研芯片

北京——2023 年12月1日 亚马逊云科技在2023 re:Invent全球大会上宣布其自研芯片家族的两个系列推出新一代&#xff0c;包括Amazon Graviton4和Amazon Trainium2&#xff0c;为机器学习&#xff08;ML&#xff09;训练和生成式人工智能&#xff08;AI&#xff09;应用等广泛的工…

锐捷RG-UAC应用网关 前台RCE漏洞复现

0x01 产品简介 锐捷RG-UAC系列应用管理网关是锐捷自主研发的应用管理产品。 0x02 漏洞概述 锐捷RG-UAC应用管理网关 nmc_sync.php 接口处存在命令执行漏洞&#xff0c;未经身份认证的攻击者可执行任意命令控制服务器权限。 0x03 复现环境 FOFA&#xff1a;app"Ruijie-R…

6.8 Windows驱动开发:内核枚举Registry注册表回调

在笔者上一篇文章《内核枚举LoadImage映像回调》中LyShark教大家实现了枚举系统回调中的LoadImage通知消息&#xff0c;本章将实现对Registry注册表通知消息的枚举&#xff0c;与LoadImage消息不同Registry消息不需要解密只要找到CallbackListHead消息回调链表头并解析为_CM_NO…

12-1 Springboot过滤拦截和日志处理

Springboot的日志 默认日志框架&#xff1a;logback 1.日志以文件的形式的保存 使用logback框架 ->(运行日志&#xff0c;开发中用于调式的&#xff0c;在开发中作为系统运行日志记录故障&#xff0c;从而追究问题根源) 2.日志相关的表 记录用户相关操作信息 -> 需要我…

<Linux>(极简关键、省时省力)《Linux操作系统原理分析之linux存储管理(2)》(18)

《Linux操作系统原理分析之linux存储管理&#xff08;1&#xff09;》&#xff08;17&#xff09; 6 Linux存储管理6.2 选段符与段描述符6.2.1 选段符6.2.2 段描述符6.2.3 分段机制的存储保护 6.3 80x86 的分页机制6.3.180x86 的分页机制6.3.2 分页机制的地址转换6.3.3 页表目录…

嵌入式WIFI芯片通过lwip获取心知天气实时天气信息(包含完整代码)

一、天气API 1. 心知天气的产品简介 HyperData 是心知天气的高精度气象数据产品&#xff0c;通过标准的 Restful API 接口&#xff0c;提供标准化的数据访问。无论是 APP、智能硬件还是企业级系统都可以轻松接入心知的精细化天气数据。 HyperData API V4版是当前的最新…

要致富 先撸树——判断循环语句(六)

引子 什么&#xff1f;万年丕更的作者更新了&#xff1f; 没错&#xff01;而且我们不当标题党&#xff0c;我决定把《我的世界》串进文章里。 什么&#xff1f;你不玩《我的世界》&#xff1f; 木有关系 本专栏文章主要在讲c语言的语法点和知识&#xff0c;保证让不玩《我…

Azure Machine Learning - 在 Azure 门户中创建AI搜索技能组

你将了解 Azure AI 搜索中的技能组如何通过添加光学字符识别 (OCR)、图像分析、语言检测、文本翻译和实体识别&#xff0c;在搜索索引中创建可搜索文本的内容。 关注TechLead&#xff0c;分享AI全维度知识。作者拥有10年互联网服务架构、AI产品研发经验、团队管理经验&#xff…

Python程序员入门指南:就业前景

文章目录 标题Python程序员入门指南&#xff1a;就业前景Python 就业数据Python的就业前景SWOT分析法Python 就业分析 标题 Python程序员入门指南&#xff1a;就业前景 Python是一种流行的编程语言&#xff0c;它具有简洁、易读和灵活的特点。Python可以用于多种领域&#xff…