学习笔记:Pytorch利用MNIST数据集训练生成对抗网络(GAN)

2023.8.27

       在进行深度学习的进阶的时候,我发了生成对抗网络是一个很神奇的东西,为什么它可以“将一堆随机噪声经过生成器变成一张图片”,特此记录一下学习心得。

一、生成对抗网络百科

        2014年,还在蒙特利尔读博士的Ian Goodfellow发表了论 文《Generative Adversarial Networks》(网址: https://arxiv.org/abs/1406.2661),将生成对抗网络引入 深度学习领域。2016年,GAN热潮席卷AI领域顶级会议, 从ICLR到NIPS,大量高质量论文被发表和探讨。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法”。

机器学习的模型可大体分为两类,生成模型( Generative Model)和判别模型(Discriminative Model)。判别模型需要输入变量 ,通过某种模型来 预测 。生成模型是给定某种隐含信息,来随机产生观 测数据。

GAN百科:

GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)_打灰人的博客-CSDN博客

二、GAN代码

训练代码:

                epoch=1000时的效果就不错啦

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as pltclass Generator(nn.Module):  # 生成器def __init__(self, latent_dim):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), 1, 28, 28)return imgclass Discriminator(nn.Module):  # 判别器def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(784, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img = img.view(img.size(0), -1)validity = self.model(img)return validitydef gen_img_plot(model, test_input):pred = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((pred[i] + 1) / 2)plt.axis('off')plt.show(block=False)plt.pause(3)  # 停留0.5splt.close()# 调用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 超参数设置
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 1000# 数据集载入和数据变换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 训练数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)# 测试数据 torch.randn()函数的作用是生成一组均值为0,方差为1(即标准正态分布)的随机数
# test_data = torch.randn(batch_size, latent_dim).to(device)
test_data = torch.FloatTensor(batch_size, latent_dim).to(device)# 实例化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)# 开始训练模型
for epoch in range(epochs):for i, (imgs, _) in enumerate(train_loader):batch_size = imgs.shape[0]real_imgs = imgs.to(device)# 训练判别器z = torch.FloatTensor(batch_size, latent_dim).to(device)z.data.normal_(0, 1)fake_imgs = generator(z)  # 生成器生成假的图片real_labels = torch.full((batch_size, 1), 1.0).to(device)fake_labels = torch.full((batch_size, 1), 0.0).to(device)real_loss = adversarial_loss(discriminator(real_imgs), real_labels)fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)d_loss = (real_loss + fake_loss) / 2optimizer_D.zero_grad()d_loss.backward()optimizer_D.step()# 训练生成器z.data.normal_(0, 1)fake_imgs = generator(z)g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()torch.save(generator.state_dict(), "Generator_mnist.pth")print(f"Epoch [{epoch}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")# gen_img_plot(Generator, test_data)
gen_img_plot(generator, test_data)

测试代码:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import randomdevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')class Generator(nn.Module):  # 生成器def __init__(self, latent_dim):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), 1, 28, 28)return img# test_data = torch.FloatTensor(128, 100).to(device)
test_data = torch.randn(128, 100).to(device)  # 随机噪声model = Generator(100).to(device)
model.load_state_dict(torch.load('Generator_mnist.pth'))
model.eval()pred = np.squeeze(model(test_data).detach().cpu().numpy())for i in range(64):plt.subplot(8, 8, i + 1)plt.imshow((pred[i] + 1) / 2)plt.axis('off')
plt.savefig(fname='image.png', figsize=[5, 5])
plt.show()

三、结果

       在超参数设置 epoch=1000,batch_size=128,lr=0.0001,latent_dim = 100 时,gan生成的权重测的结果如图所示

四,GAN的损失函数曲线

                一开始训练时,我的gan的损失函数的曲线是类似这样的,就是知乎这文章里一样,生成器损失函数的曲线一直发散。首先,这个loss的曲线一看就是网络崩了,一般正常的情况,d_loss的值会一直下降然后收敛,而g_loss的曲线会先增大后减少,最后同样也会收敛。其次,网络拿到手以后先不要训练太多次,容易出现过拟合的情况。

生成对抗网络的损失函数图像如下合理吗? - 知乎

这是训练了10轮的生成器和鉴别器的损失函数值变化吧:

效果如图所示: 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54673.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux操作系统--shell编程(条件判断)

(1).基本的语法 test condition [ condition ] 注意condition前后要有空格;在使用该种表达式的时候,条件非空即为 true,[ hello ]返回 true,[ ] 返回 false。我们可以通过echo $?来判断上一次执行的情况来判断真假(0真1假)。

探索数据的维度:多元线性回归在实际应用中的威力

文章目录 🍀引言🍀什么是多元线性回归?🍀多元线性回归的应用🍀构建多元线性回归模型的步骤🍀R-squared(R平方)🍀多元线性回归案例---波士顿房价 🍀引言 当谈…

ATF(TF-A)安全通告 TFV-2 (CVE-2017-7564)

安全之安全(security)博客目录导读 ATF(TF-A)安全通告汇总 目录 一、ATF(TF-A)安全通告 TFV-2 (CVE-2017-7564) 二、 CVE-2017-7564 一、ATF(TF-A)安全通告 TFV-2 (CVE-2017-7564) Title 启用安全自托管侵入式调试接口,可允许非安全世界引发安全世界panic CV…

海外ios应用商店优化排名因素之视频预览与截图

当我们找到感兴趣的应用程序并转到该应用程序的页面时,首先引起注意的是预览视频。视频旨在以更具吸引力的方式展示应用程序的用户体验和UI。视频长度最多为30秒,其中前5秒最为重要,一定要让它尽可能引人注目。 1、关于优化预览视频的提示。…

React + Next.js 搭建项目(配有对比介绍一起食用)

文章标题 01 Next.js 是什么02 Next.js 搭建工具 create-next-app03 create-react-app 与 create-next-app 的区别04 快速构建 Next.js 项目05 App Router 与 Pages Router 的区别 01 Next.js 是什么 Next.js 是一个 React 框架,它允许你使用 React 框架建立超强的…

程序的编译链接【编译链接大概步骤】

全文目录 😀 前言🙂 翻译环境和执行环境😶 编译和链接😵‍💫 预编译(预处理)😵‍💫 编译😵‍💫 汇编😵‍💫 链接 &#x1…

分布式定时任务框架Quartz总结和实践(2)—持久化到Mysql数据库

本文主要介绍分布式定时任务框架Quartz集成SpringBoot持久化数据到Mysql数据库的操作,上一篇文章使用Quartz创建定时任务都是保存在内存中,如果服务重启定时任务就会失效,所以Quartz官方也提供将定时任务等信息持久化到Mysql数据库的功能&…

【ES6】—数组的扩展

一、类数组/ 伪数组 1. 类/伪数组: 并不是真正意义的数组,有长度的属性,但无法使用Array原型上的方法 let divs document.getElementsByTagName(div) console.log(divs) // HTMLCollection []let divs2 document.getElementsByClassName("xxx&q…

Git gui教程---第七篇 Git gui的使用 返回上一次提交

1. 查看历史,打开gitk程序 2. 选中需要返回的版本,右键,然后点击Rest master branch to here 3.出现弹窗 每个选项我们都试一下,从Hard开始 返回的选项 HardMixedSoft Hard 会丢失所有的修改【此处的…

从0开始做yolov5模型剪枝

文章目录 从0开始做yolov5模型剪枝 ****1 前言2 GitHub取源码3 原理3.1 原理3.2 network slimming过程 4 具体实施步骤4.1 安装虚拟环境4.2 配置参数4.2.1 数据集参数4.2.2 模型结构参数4.2.3 train.py中的参数 4.3 正常训练4.3.1 准备4.3.2 训练及问题解决 4.4 稀疏化训练4.4.…

Leetcode 2235.两整数相加

一、两整数相加 给你两个整数 num1 和 num2,返回这两个整数的和。 示例 1: 输入:num1 12, num2 5 输出:17 解释:num1 是 12,num2 是 5 ,它们的和是 12 5 17 ,因此返回 17 。示例…

渗透测试方法论

文章目录 渗透测试方法论1. 渗透测试种类黑盒测试白盒测试脆弱性评估 2. 安全测试方法论2.1 OWASP TOP 102.3 CWE2.4 CVE 3. 渗透测试流程3.1 通用渗透测试框架3.1.1 范围界定3.1.2 信息搜集3.1.3 目标识别3.1.4 服务枚举3.1.5 漏洞映射3.1.6 社会工程学3.1.7 漏洞利用3.1.8 权…

[LitCTF 2023]Flag点击就送!

进入环境后是一个输入框,可以提交名字 然后就可以点击获取flag,结果回显提示,需要获取管理员 可以尝试将名字改为admin 触发报错,说明可能存在其他的验证是否为管理员的方式 通过抓包后,在cookie字段发现了 特殊的东西…

嵌入式系统入门实战:探索基本概念和应用领域

嵌入式系统是一种专用的计算机系统,它是为了满足特定任务而设计的。这些系统通常具有较低的硬件资源(如处理器速度、内存容量和存储容量),但具有较高的可靠性和实时性。嵌入式系统广泛应用于各种领域,如家用电器、汽车、工业控制、医疗设备等。 嵌入式系统的基本概念 微控…

实战项目 在线学院springcloud调用篇3(nacos,feging,hystrix,gateway)

一 springcloud与springboot的关系 1.1 关系 1.2 版本关系 1.3 list转json串 public class Test {public static void main(String[] args) {List<String> dataListnew ArrayList<String>();dataList.add("12");dataList.add("45");dataLi…

2023国赛数学建模思路 - 案例:退火算法

文章目录 1 退火算法原理1.1 物理背景1.2 背后的数学模型 2 退火算法实现2.1 算法流程2.2算法实现 建模资料 ## 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 退火算法原理 1.1 物理背景 在热力学上&a…

深入剖析Kubernetes之控制器模式的实现-Deployment

文章目录 Deployment Deployment Deployment 实现了 Kubernetes 项目中一个非常重要的功能&#xff1a;Pod 的“水平扩展 / 收缩”&#xff08;horizontal scaling out/in&#xff09;。这个功能&#xff0c;是从 PaaS 时代开始&#xff0c;一个平台级项目就必须具备的编排能力…

Idea配置Remote Host

一、打开RemoteHost窗口 双击shift打开全局搜索 搜索Tools→Deployment→Browse Remote Host或 idea项目顶部Tools→Deployment→Browse Remote Host 二、添加服务 右侧边栏打开RemoteHost&#xff0c;点击三个点&#xff0c;起个名字&#xff0c;选择type为SFTP&#xff…

使用Nacos与Spring Boot实现配置管理

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

直击成都国际车展:远航汽车多款车型登陆车展,打造完美驾乘体验

随着市场渗透率日益高涨&#xff0c;新能源汽车成为今年成都国际车展的关注焦点。在本届车展上&#xff0c;新能源品牌占比再创新高&#xff0c;覆盖两个展馆&#xff0c;印证了当下新能源汽车市场的火爆。作为大运集团重磅打造的高端品牌&#xff0c;远航汽车深度洞察高端智能…