Pytorch从零开始实战17

Pytorch从零开始实战——生成对抗网络入门

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——生成对抗网络入门
    • 环境准备
    • 模型定义
    • 开始训练
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch1.8+cpu,本次实验目的是了解生成对抗网络。

生成对抗网络(GAN)是一种深度学习模型。GAN由两个主要组成部分组成:生成器和判别器。这两个部分通过对抗的方式共同学习,使得生成器能够生成逼真的数据,而判别器能够区分真实数据和生成的数据。

生成器的任务是生成与真实数据相似的样本。它接收一个随机噪声向量,然后通过深度神经网络将这个随机噪声转换为具体的数据样本。在图像生成的场景中,生成器通常将随机噪声映射为图像。生成器的目标是欺骗判别器,使其无法区分生成的样本和真实的样本。生成器的训练目标是最小化生成的样本与真实样本之间的差异。

判别器的任务是对给定的样本进行分类,判断它是来自真实数据集还是由生成器生成的。它接收真实样本和生成样本,然后通过深度神经网络输出一个概率,表示输入样本是真实样本的概率。判别器的目标是准确地分类样本,使其能够正确地区分真实数据和生成的数据。判别器的训练目标是最大化正确分类的概率。

导入相关包。

import torch
import torch.nn as nn
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

创建文件夹,分别保存训练过程中的图像、模型参数和数据集。

os.makedirs("./images/", exist_ok=True) # 训练过程中图片效果
os.makedirs("./save/", exist_ok=True) # 训练完成时模型保存位置
os.makedirs("./datasets/", exist_ok=True) # 数据集位置

设置超参数。
b1、b2为Adam优化算法的参数,其中b1是梯度的一阶矩估计的衰减系数,b2是梯度的二阶矩估计的衰减系数。
latent_dim表示生成器输入的随机噪声向量的维度。这个噪声向量用于生成器产生新样本。
sample_interval表示在训练过程中每隔多少个batch保存一次生成器生成的样本图像,以便观察生成效果。

epochs = 20
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim=100
img_size=28
channels=1
sample_interval=500

设定图像尺寸并检查cuda,本次使用的设备没有cuda。

img_shape = (channels, img_size, img_size) # (1, 28, 28)
img_area = np.prod(img_shape) # 784## 设置cuda
cuda = True if torch.cuda.is_available() else False
print(cuda) # False

本次使用GAN来生成手写数字,首先下载mnist数据集。

mnist = datasets.MNIST(root='./datasets/', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))

使用dataloader划分批次与打乱。

dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)len(dataloader) # 938

模型定义

首先定义鉴别器模型,代码中LeakyReLU是ReLU激活函数的变体,它引入了一个小的负斜率,在负输入值范围内,而不是将它们直接置零。这个斜率通常是一个小的正数,例如0.01。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),        nn.LeakyReLU(0.2, inplace=True),  nn.Linear(512, 256),             nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1),              nn.Sigmoid(),                    )def forward(self, img):img_flat = img.view(img.size(0), -1) validity = self.model(img_flat)     return validity         # 返回的是一个[0, 1]间的概率

定义生成器模型,用于输出图像。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):      layers = [nn.Linear(in_feat, out_feat)]          if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True))   return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False), *block(128, 256),                         *block(256, 512),                         *block(512, 1024),                       nn.Linear(1024, img_area),                nn.Tanh()                                )def forward(self, z):                           imgs = self.model(z)                       imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)return imgs                                 # 输出为64张大小为(1, 28, 28)的图像

开始训练

创建生成器、判别器对象。

generator = Generator()
discriminator = Discriminator()

定义损失函数。这个其实就是二分类的交叉熵损失。

criterion = torch.nn.BCELoss()

定义优化函数。

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

开始训练,实现GAN训练过程,其中生成器和判别器交替训练,通过对抗过程使得生成器生成逼真的图像,而判别器不断提高对真实和生成图像的判别能力。

for epoch in range(epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)real_img = Variable(imgs)     # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度real_label = Variable(torch.ones(imgs.size(0), 1))    ## 定义真实的图片label为1fake_label = Variable(torch.zeros(imgs.size(0), 1))    ## 定义假的图片的label为0real_out = discriminator(real_img)            # 将真实图片放入判别器中loss_real_D = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好## 计算假的图片的损失## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 随机生成一些噪声, 大小为(128, 100)fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的lossfake_scores = fake_out## 损失函数和优化loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0loss_D.backward()                   # 将误差反向传播optimizer_D.step()                  # 更新参数z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 得到随机噪声fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片output = discriminator(fake_img)                                    ## 经过判别器得到的结果## 损失函数和优化loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的lossoptimizer_G.zero_grad()                                             ## 梯度归0loss_G.backward()                                                   ## 进行反向传播optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数## 打印训练过程中的日志## item():取出单元素张量的元素值并返回该值,保持原元素类型不变if (i + 1) % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述
保存模型。

torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

查看最初的噪声图像。
在这里插入图片描述
查看后面生成的图像。
在这里插入图片描述

总结

对于GAN,生成器的任务是从随机噪声生成逼真的数据样本,判别器的任务是对给定的数据样本进行分类,判断其是真实数据还是由生成器生成的。生成器和判别器通过对抗的方式进行训练。在每个训练迭代中,生成器试图生成逼真的样本以欺骗判别器,而判别器努力提高自己的能力,以正确地区分真实和生成的样本。这种对抗训练的动态平衡最终导致生成器生成高质量、逼真的样本。

总之,GAN实现了在无监督情况下学习数据分布的能力,被广泛用于生成逼真图像、视频等数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/635887.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】更新python版本并设置默认版本

文章目录 1.查看目前python的版本2.添加软件源并更新3.选择你想要下载的版本4.设置默认版本 网上有很多教程都是教导小白去官方下载之后编译安装。但是,小白连cmake是什么都不知道,这种教导方式实在是误人子弟。这里作者介绍了一种十分简洁的更新方法。 …

YOLOv8改进:RepBiPAN结构 + DETRHead检测头,为YOLOv8目标检测使用不一样的检测头,用于提升检测精度

💡本篇内容:YOLOv8全新Neck改进:RepBiPAN 结构升级版,为目标检测打造全新融合网络,增强定位信号,对于小目标检测的定位具有重要意义 💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可 💡本文改进 Neck部分和DETRHead系列检测头…

8.2摆动序列(LC376-M)

算法: 其实动态规划和贪心算法都能做 但是动态规划的时间复杂度是O(n^2) 贪心算法的时间复杂度是O(n) 所以学习贪心算法 到底为什么用贪心?(分析局部最优和全局最优) 局部最优:删除单调坡度上的节点(不…

第一篇【传奇开心果系列】beeware开发移动应用:轮盘抽奖移动应用

系列博文目录 beeware开发移动应用示例系列博文目录一、项目目标二、开发传奇开心果轮盘抽奖安卓应用编程思路三、传奇开心果轮盘抽奖安卓应用示例代码四、补充抽奖逻辑实现五、开发传奇开心果轮盘抽奖苹果手机应用编程思路六、开发传奇开心果轮盘抽奖苹果手机应用示例代码七、…

如何配置Pycharm服务器并结合内网穿透工具实现远程开发

🔥博客主页: 小羊失眠啦. 🎥系列专栏:《C语言》 《数据结构》 《Linux》《Cpolar》 ❤️感谢大家点赞👍收藏⭐评论✍️ 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,…

NLP深入学习(五):HMM 详解及字母识别/天气预测用法

文章目录 0. 引言1. 什么是 HMM2. HMM 的例子2.1 字母序列识别2.2 天气预测 3. 参考 0. 引言 前情提要: 《NLP深入学习(一):jieba 工具包介绍》 《NLP深入学习(二):nltk 工具包介绍》 《NLP深入…

JS的数据类型和运算符

typeof()方法:检测数据类型 JS中的基本数据类型 基本数据类型 1.number 数字 2.string 字符串 3.boolean 布尔 4.null 代表空值(typeof方法检测出来的数据类型是object类型) 5.underfined 未定义;变量已声明但是未赋值 6.…

QT Model/View 设计模式中 selection 模型

1. QT 的 selection 模型是用来做什么的? Qt的selection模型用于管理TableView中的选择操作。它允许用户选择和操作特定的数据。 2. Selection 模型用途的例子? 当使用Qt的TableView时,可以使用selection模型来实现以下用途: …

vue路由-全局前置守卫

1. 介绍 详见:全局前置守卫网址 使用场景: 对于支付页,订单页等,必须是登录的用户才能访问的,游客不能进入该页面,需要做拦截处理,跳转到登录页面 全局前置守卫的原理: 全局前置…

KubeSphere 核心实战之二【在kubesphere平台上部署redis】(实操篇 2/4)

文章目录 1、登录kubesphere平台2、redis部署分析3、redis容器启动代码4、kubesphere平台部署redis4.1、创建redis配置集4.2、创建redis工作负载4.3、创建redis服务 5、测试连接redis 在kubesphere平台上部署redis应用都是基于redis镜像进行部署的,所以所有的部署操…

2022-ECCV-Adaptive Face Forgery Detection in Cross Domain

一、研究背景 1.伪造视频是逐帧生成的,因此会造成时间维度上的伪影。而鲁棒的检测模型需要对同一身份的不同帧有一致的检测结果。 1.利用频率线索进行deepfake检测效果良好,但也会导致帧间不一致问题,即不同帧检测结果不同。 2.以往方法中固定…

【JavaScript】面向后端快速学习 笔记

文章目录 JS是什么?一、JS导入二、数据类型 变量 运算符三、流程控制四、函数五、对象 与 JSON5.1 对象5.2 JSON5.3 常见对象1. 数组2. Boolean对象3. Date对象4. Math5. Number6. String 六、事件6.1 常用方法1. 鼠标事件2. 键盘事件3. 表单事件 6.2 事件的绑定**1…

【咕咕送书 | 第八期】羡慕同学进了大厂核心部门,看懂这本书你也能行!

🎬 鸽芷咕:个人主页 🔥 个人专栏:《linux深造日志》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! ⛳️ 写在前面参与规则 ✅参与方式:关注博主、点赞、收藏、评论,任意评论(每人最多评论…

VMware虚拟机设置NAT网络模式

查看本地服务器网卡ip10.9.158.77 设置vmNet8虚拟网卡ip10.9.58.177,不需要在同一网段 3.点击VMware设置“虚拟网络编辑器”,点击“NAT设置”所有设置的ip网段需要与第二步的VMNet8网卡的网一致

LeetCode 2788. 按分隔符拆分字符串

一、题目 1、题目描述 给你一个字符串数组 words 和一个字符 separator ,请你按 separator 拆分 words 中的每个字符串。 返回一个由拆分后的新字符串组成的字符串数组,不包括空字符串 。 注意 separator 用于决定拆分发生的位置,但它不包含…

算法 动态分析 及Java例题讲解

动态规划 动态规划(英语:Dynamic programming,简称 DP),是一种在数学、管理科学、计算机科学、经济学和生物信息学中使用的,通过把原问题分解为相对简单的子问题的方式求解复杂问题的方法。动态规划常常适…

day-15 按分隔符拆分字符串

思路 依次对words的每个字符进行split(),然后将非空的加入List 解题方法 String arr[]s.split(ss);利用split()方法将words的每个字符串划分为String数组 if(arr[i]!“”) //将非空的加入 list.add(arr[i]); String ss“”separator; //使用转义字符 时间复杂度:…

HCIA——18实验:NAT

学习目标: NAT 学习内容: NAT 1.要求——基本的 2.模型 3.IP分配、规划、优化 1)思路 R2为ISP路由器,其上只能配置ip地址,不得冉进行其他的任何配置—ospf配置 认证 、汇总、沉默接口、加快收敛、缺省路由 PC1-PC2…

配置免费的SSL

1 引言 本文介绍了如何在 Linux 环境下使用免费的 Let’s Encrypt 为你的网站配置 SSL 证书的方法,以及如何在 Nginx 服务器中启用 SSL。对于需要在自己的网站上启用 HTTPS 的用户来说非常实用。 2 SSL 简介 SSL,全称为 Secure Sockets Layer&#xf…

React一般可以用哪些值作为key?

在 React 中,key 是用来帮助 React 核对 Virtual DOM 中的节点是否发生变化的。key 值唯一且稳定有助于提高渲染性能,因为 React 可以根据 key 值判断哪些元素需要重新渲染。 一般来说,以下属性可以作为 key 值: 数据库中的 ID&a…