【Week-G1】调用官方GAN实现MNIST数字识别,Pytorch框架

文章目录

  • 1. 准备数据
    • 1.1 配置超参数
    • 1.2 下载数据
    • 1.3 配置数据
  • 2. 创建模型
    • 2.1 定义鉴别器
    • 2.2 定义生成器
  • 3. 训练模型
    • 3.1 创建实例
    • 3.2 开始训练
    • 3.3 保存模型
  • 4. 什么是GAN(对抗生成网络)?

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

说明:
(1)使用CPU时,屏蔽.cuda(),否则报错:
在这里插入图片描述

1. 准备数据

系统环境:
语言:Python3.7.8
编译器:VSCode
深度学习框架:torch 1.13.1

1.1 配置超参数

print("***********1.1 配置超参数*****************")
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch
## 创建文件夹
# 程序在路径 D:\jupyter notebook\DL-100-days\下运行,也就是下方的 ./
os.makedirs("./GAN/G1/images/", exist_ok=True)  # 记录训练过程的图片效果
os.makedirs("./GAN/G1/save/", exist_ok=True)    # 训练完成时,模型的保存位置
os.makedirs("./GAN/G1/mnist/", exist_ok=True)   # 下载数据集存放的位置
## 超参数配置
n_epochs  = 50
batch_size = 64
lr  = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500
#图像的尺寸(1, 28, 28),和图像的像素面积(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)
#设置cuda: (cuda:0)
cuda = True if torch.cuda.is_available() else False
print("CUDA: ", cuda)
print("\n")

文件路径如下图:
在这里插入图片描述
使用CPU版本,所以打印的CUDA结果为FALSE;
在这里插入图片描述

1.2 下载数据

print("***********2. 下载数据*****************")
mnist = datasets.MNIST(root='./datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
print("\n")

在这里插入图片描述

1.3 配置数据

print("***********1.3 配置数据*****************")
dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True
)
print("\n")

2. 创建模型

2.1 定义鉴别器

print("***********2. 创建模型********************")
print("***********2.1 定义鉴别器*****************")
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity
print("\n")

2.2 定义生成器

print("***********2.2 定义生成器*****************")
class Generate(nn.Module):def __init__(self):super(Generate, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),nn.Tanh())def forward(self, z):imgs = self.model(z)imgs = imgs.view(imgs.size(0), *img_shape)return imgs
print("\n")

3. 训练模型

3.1 创建实例

print("***********3. 训练模型*****************")
print("***********3.1 创建实例****************")
generator = Generate()
discriminator = Discriminator()criterion = torch.nn.BCELoss()optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))if torch.cuda.is_available():generator = generator.cuda()discriminator = discriminator#.cuda()criterion = criterion.cuda()
print("\n")

3.2 开始训练

print("***********3.2 开始训练*****************")
for epoch in range(n_epochs):  # epoch:50for i, (imgs, _) in enumerate(dataloader):## =============================训练判别器==================## view(): 相当于numpy中的reshape,重新定义矩阵的形状, 相当于reshape(128,784)  原来是(128, 1, 28, 28)imgs = imgs.view(imgs.size(0), -1)  # 将图片展开为28*28=784  imgs:(64, 784)real_img = Variable(imgs).cuda()  # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度real_label = Variable(torch.ones(imgs.size(0), 1))#.cuda()  ## 定义真实的图片label为1fake_label = Variable(torch.zeros(imgs.size(0), 1))#.cuda()  ## 定义假的图片的label为0## ---------------------##  Train Discriminator## 分为两部分:1、真的图像判别为真;2、假的图像判别为假## ---------------------## 计算真实图片的损失real_out = discriminator(real_img)  # 将真实图片放入判别器中loss_real_D = criterion(real_out, real_label)  # 得到真实图片的lossreal_scores = real_out  # 得到真实图片的判别值,输出的值越接近1越好## 计算假的图片的损失## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim))#.cuda()  ## 随机生成一些噪声, 大小为(128, 100)fake_img = generator(z).detach()  ## 随机噪声放入生成网络中,生成一张假的图片。fake_out = discriminator(fake_img)  ## 判别器判断假的图片loss_fake_D = criterion(fake_out, fake_label)  ## 得到假的图片的lossfake_scores = fake_out  ## 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好## 损失函数和优化loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失optimizer_D.zero_grad()  # 在反向传播之前,先将梯度归0loss_D.backward()  # 将误差反向传播optimizer_D.step()  # 更新参数## -----------------##  Train Generator## 原理:目的是希望生成的假的图片被判别器判断为真的图片,## 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,## 反向传播更新的参数是生成网络里面的参数,## 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的## -----------------z = Variable(torch.randn(imgs.size(0), latent_dim))#.cuda()  ## 得到随机噪声fake_img = generator(z)  ## 随机噪声输入到生成器中,得到一副假的图片output = discriminator(fake_img)  ## 经过判别器得到的结果## 损失函数和优化loss_G = criterion(output, real_label)  ## 得到的假的图片与真实的图片的label的lossoptimizer_G.zero_grad()  ## 梯度归0loss_G.backward()  ## 进行反向传播optimizer_G.step()  ## step()一般用在反向传播后面,用于更新生成网络的参数## 打印训练过程中的日志## item():取出单元素张量的元素值并返回该值,保持原元素类型不变if (i + 1) % 300 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(),fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./GAN/G1/images/%d.png" % batches_done, nrow=5, normalize=True)
print("\n")

训练结果:

***********3.2 开始训练*****************
[Epoch 0/50] [Batch 299/938] [D loss: 1.363185] [G loss: 1.811396] [D real: 0.922554] [D fake: 0.692913]
[Epoch 0/50] [Batch 599/938] [D loss: 0.753042] [G loss: 2.229975] [D real: 0.815112] [D fake: 0.410212]
[Epoch 0/50] [Batch 899/938] [D loss: 1.049122] [G loss: 1.940738] [D real: 0.789812] [D fake: 0.548839]
... ... ...
[Epoch 48/50] [Batch 299/938] [D loss: 0.956054] [G loss: 1.398938] [D real: 0.661061] [D fake: 0.327662]
[Epoch 48/50] [Batch 599/938] [D loss: 1.070262] [G loss: 0.950201] [D real: 0.538358] [D fake: 0.234096]
[Epoch 48/50] [Batch 899/938] [D loss: 1.012980] [G loss: 1.247620] [D real: 0.650552] [D fake: 0.319423]
[Epoch 49/50] [Batch 299/938] [D loss: 1.254801] [G loss: 1.048441] [D real: 0.522079] [D fake: 0.313869]
[Epoch 49/50] [Batch 599/938] [D loss: 0.884523] [G loss: 1.709880] [D real: 0.767361] [D fake: 0.402201]
[Epoch 49/50] [Batch 899/938] [D loss: 1.019181] [G loss: 1.608823] [D real: 0.739194] [D fake: 0.421154]

./GAN/G1/images/下的缩略图如下:
在这里插入图片描述
部分详细图如下:
在这里插入图片描述

3.3 保存模型

print("***********3.3 保存模型*****************")
torch.save(generator.state_dict(), "./GAN/G1/save/generator.pth")
torch.save(discriminator.state_dict(). "./GAN/G1/save/discriminator.pth")
print("\n")

保存的模型文件如下:
在这里插入图片描述

4. 什么是GAN(对抗生成网络)?

【详解1】
【详解2】

机器学习的模型大体分为两类:判别模型(Discriminative Model)和生成模型(Generative Model)。

  • 判别模型:输入变量,使用模型进行预测
  • 生成模型:给出目标的隐含信息,随机产生观测数据。比如:给出一系列猫的图片,来生成一张新的猫的图片。重要点在于“生成”二字。

GAN:适用于无监督学习,该网络的框架由(至少)两个模块构成,即判别模型(Discriminative Model)和生成模型(Generative Model),通过二者的互相博弈学习来产生相当好的输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/861567.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【BES2500x系列 -- RTX5操作系统】深入探索CMSIS-RTOS RTX -- 同步与通信篇 -- 信号量和互斥锁 --(三)

💌 所属专栏:【BES2500x系列】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! &#x1f49…

帕金森患者吞咽困难?如何让饮食更顺畅!

在帕金森病患者的日常生活中,吞咽困难是一个常见而又棘手的问题。它不仅影响了患者的饮食质量,还可能导致营养不良、吸入性肺炎等严重并发症。那么,面对帕金森综合症导致的吞咽困难,我们该如何应对呢? 一、了解帕金森综…

一个去掉PDF背景水印的思路

起因 昨天测试 使用“https://github.com/VikParuchuri/marker” 将 pdf 转 Markdown的过程中,发现转换后的文件中会保护一些背景图片,是转换过程中,程序把背景图识别为了内容。于是想着怎么把背景图片去掉。 背景水印图片的特征 我这里拿…

openGuass数据库极简版安装和远程连接实战(阿里云服务器操作)

openGauss部署之后,在服务器上提供了在命令行下运行的数据库连接工具gsql。此工具除了具备操作数据库的基本功能,还提供了若干高级特性,便于用户使用。但图形化工具除了官方的Data Studio外,还可以使用SQLynx进行连接(…

Taro + vue3 中微信小程序中实现拉起支付

在前端开发中 H5 的拉起支付和微信小程序的拉起支付 是不太一样的 现在分享一下微信小程序中的拉起支付 逻辑都在后端 我是用的Taro 框架 其实就是一个Api Taro 文档

酷开系统丨开启家庭智能教育让学习成为一种乐趣

在数字化时代,孩子们接触的信息日益增多,而酷开系统洞察到了家长对孩子成长环境的关切。酷开系统,作为家庭娱乐与教育的融合平台,不仅注重提供丰富的教育资源,更致力于创造一个健康、有益的学习和娱乐环境。 在酷开系…

【数据同步】什么是ETL增量抽取?

目录 一、什么是ETL增量抽取 二、企业如何应用ETL增量抽取 三、如何进行ETL增量抽取 1.基于时间戳的增量抽取 2.基于主键的增量抽取 在当今信息化时代,数据的快速增长和多样化使得企业面临着巨大的数据管理挑战。为了高效地处理和利用数据,ETL&#xff0…

零知识证明基础:对称加密与非对称加密

1、绪论 在密码学体系中,对称加密、非对称加密、单向散列函数、消息认证码、数字签名和伪随机数生成器被统称为密码学家的工具箱。其中,对称加密和非对称加密主要是用来保证机密性;单向散列函数用来保证消息的完整性;消息认证码的…

权限 chmod

参考: Linux chmod 命令 | 菜鸟教程 (runoob.com) Linux chmod(英文全拼:change mode)命令是控制用户对文件的权限的命令 Linux/Unix 的文件调用权限分为三级 : 文件所有者(Owner Users)用户组&#xff08…

Arduino - MG996R

Arduino - MG996R In this tutorial, we are going to learn how to use the MG996R high-torque servo motor with Arduino. 在本教程中,我们将学习如何将MG996R高扭矩伺服电机与Arduino一起使用。 Hardware Required 所需硬件 1Arduino UNO or Genuino UNO Ard…

windows系统如何快速查看显卡详情信息

winR,输入dxdiag 打开DirectX诊断工具,可以看到显卡的详细硬件信息

Vue原生写全选反选框

效果 场景:Vue全选框在头部,子框在v-for循环内部。 实现:点击全选框,所有子项选中,再次点击取消;子项全选中,全选框自动勾选,子项并未全选,全选框不勾选;已选…

国产音频放大器工作原理以及应用领域

音频放大器是在产生声音的输出元件上重建输入的音频信号的设备,其重建的信号音量和功率级都要理想:如实、有效且失真低。音频范围为约20Hz~20000Hz,因此放大器在此范围内必须有良好的频率响应(驱动频带受限的扬声器时要…

无人机操作注意事项

检查飞行设备 每次飞行前,要认真检查无人机的各处细节,遥控器等地面设备也不例外。 确保设备电量充足 起飞前,检查无人机是否电量充足,以及辅助设备如遥控器、手机等。 选择空旷的飞行场地 选择适宜的场地进行操作&#xff0…

机器学习原理和代码实现专辑

1. 往期文章推荐 1.【机器学习】图神经网络(NRI)模型原理和运动轨迹预测代码实现 2. 【机器学习】基于Gumbel-Sinkhorn网络的“潜在排列问题”求解 3. 【机器学习】基于Gumbel Top-k松弛技术的图形采样 4. 【机器学习】基于Softmax松弛技术的离散数据采样 5. 【机器学习】正则…

GNU、Unix、Linux、Makefile、GCC、GDB、GPL、CentOS 7、Ubuntu之间的关系

全文总结 早期,Unix系统作为一类强大的操作系统,在计算领域奠定了基础。然而,出于对软件自由的追求,Richard Stallman在1983年发起了GNU项目,旨在创建一个完全自由的、与Unix兼容的操作系统。GNU项目不仅倡议软件自由…

空间转录组学联合单细胞转录组学揭示卵巢癌生存相关受配体对

卵巢癌,作为女性生殖系统中的一种常见恶性肿瘤,其高级别浆液性卵巢癌(HGSC)亚型尤其致命。尽管多数患者对初次治疗反应良好,但超过75%的晚期HGSC患者会在治疗后复发,并且对化疗药物产生耐药性。然而&#x…

vs code + Keil Assistant 配置 Keil 单片机开发

1、 先安装vscode完成后 安装插件 2 安装C/C 与 keil Assistant 需说明一下 Assistant 1.7.0版本有bug F7按不了 所以安装1.6.2版本 以下是我的安装插件 EMBEDDED IDE 可安装 可不安装 随便你 3 配置 Assistant 4、设置C/C 目录 ${workspaceFolder}/**D:/Keil_v5/C51/INC/**…

排序算法系列一:选择排序、插入排序 与 希尔排序

零、说在前面 本文是一个系列,入口请移步这里 一、理论部分 1.1:选择排序 1.1.1:算法解读: 使用二分法和插入排序两种算法的思想来实现。流程分为“拆分”、“合并”两大部分,前者就是普通的二分思想,将…

应急响应靶机-Linux(2)

前言 本次应急响应靶机采用的是知攻善防实验室的Linux-2应急响应靶机 靶机下载地址为: https://pan.quark.cn/s/4b6dffd0c51a 相关账户密码: root/Inch957821.(记住要带最后的点.) 解题 启动靶机 不建议直接使用账号密码登录,建议用另一台主…