人工智能应用-实验8-用生成对抗网络生成数字图像

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡代码🧡🧡
    • 🧡🧡分析结果🧡🧡
    • 🧡🧡实验总结🧡🧡

🧡🧡实验内容🧡🧡

以MNIST 数据集为训练数据,用生成对抗网络生成手写数字 5的图像(编程语言不限,如Python 等)。


🧡🧡代码🧡🧡

import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import time
import pandastransform = transforms.Compose([transforms.ToTensor(),
])train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False) # 批次为1,不打乱数据# !nvidia-smi
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)#@title 模型
#返回size大小的均值为0,均方误差为1的随机数
def generate_random(size):random_data = torch.randn(size)return random_data# def generate_random(size): # 均匀分布的随机数,会产生模式崩溃
#     random_data = torch.rand(size)
#     return random_data#判别器
class Discriminator(nn.Module):def __init__(self):super().__init__()self.model=nn.Sequential(nn.Linear(784, 200), # 全连接层 784维特征(像素点) => 200维特征nn.LeakyReLU(0.02), # 激活层:f(x)=max(ax,x) ann.LayerNorm(200), # 归一化层nn.Linear(200, 1), # 全连接层 200维特征(像素点) => 1维标量nn.Sigmoid() # 将1维标量缩放结果到0-1之间,以0.5作为二分类结果)self.loss_function = nn.BCELoss() # 定义损失函数self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001) # 创建优化器,使用Adam梯度下降# 计数器和损失记录self.counter = 0self.loss_list = []def forward(self, inputs):return self.model(inputs)def train(self, inputs, targets):outputs = self.forward(inputs)  # 计算网络前向传播输出loss = self.loss_function(outputs, targets) # 计算损失值self.counter += 1if (self.counter % 10 == 0): # 每训练10次记录损失值self.loss_list.append(loss.item())if (self.counter % 10000 == 0): # 每训练10000次打印进程print("counter = ", self.counter)self.optimiser.zero_grad() #在反向传播前先把梯度归零loss.backward() #反向传播,计算各参数对于损失loss的梯度self.optimiser.step()  #根据反向传播得到的梯度,更新模型权重参数def plot_loss_process(self):df = pandas.DataFrame(self.loss_list, columns=['Discriminator Loss'])ax = df.plot(figsize=(12,6), alpha=0.1,marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))ax.set_title("Discriminator Loss")# 生成器
class Generator(nn.Module):def __init__(self):super().__init__()# 定义神经网络层self.model = nn.Sequential(nn.Linear(100, 200), # 全连接层 100维噪声 => 200维特征nn.LeakyReLU(0.02), # 激活函数nn.LayerNorm(200), # 标准化nn.Linear(200, 784), # 200维特征 => 784像素特征nn.Sigmoid() # 每个像素点缩放到0-1)# 创建生成器,使用Adam梯度下降self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)# 计数器和损失记录self.counter = 0self.loss_list = []def forward(self, inputs):# 运行模型return self.model(inputs)def train(self, D, inputs, targets):g_output = self.forward(inputs) # 计算网络输出d_output = D.forward(g_output) # 输入判别器loss = D.loss_function(d_output, targets) # 计算损失值self.counter += 1if (self.counter % 10 == 0):  # 每训练10次记录损失值self.loss_list.append(loss.item())# 梯度归零,反向传播,并更新权重self.optimiser.zero_grad()loss.backward()#更新由self.optimiser而不是D.optimiser触发。这样一来,只有生成器的链接权重得到更新self.optimiser.step()def plot_loss_process(self):df = pandas.DataFrame(self.loss_list, columns=['Generator Loss'])ax = df.plot(figsize=(12,6), alpha=0.1,marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))ax.set_title("Generator Loss")D = Discriminator()
G = Generator()
D = D.to(device)
G = G.to(device)#@title train
epochs=1
start_time=time.time()
for epoch in range(epochs):print(f"=============Epoch={epoch}============")for step, (images, labels) in enumerate(train_loader):images = images.to(device)image_data_tensor=images.view(-1)# ==使用真实数据训练判别器, 并标注真实数据为正样本(1)==D.train( image_data_tensor, torch.FloatTensor([1.0]).to(device) )# ==用生成数据(fake)训练判别器, 并标注生成数据为负样本(0)==# 同时使用detach()以避免计算生成器G中的梯度D.train( G.forward(generate_random(100).to(device)).detach(), torch.FloatTensor([0.0]).to(device) )# ==训练生成器, 让判别器对于生成器的生成数据评分尽可能接近正样本(1)==G.train( D, generate_random(100).to(device), torch.FloatTensor([1.0]).to(device) )
print(f"cost all time={(time.time()-start_time)/60} minutes")# 保存模型
torch.save(D, 'GAN_Digits_D.pt')
torch.save(G, 'GAN_Digits_G.pt')
# 加载模型
D=torch.load('GAN_Digits_D.pt')
G=torch.load('GAN_Digits_G.pt')
G.plot_loss_process()
D.plot_loss_process()
# 生成效果图
f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):for j in range(3):output = G.forward(generate_random(100).to(device))output = output.cpu()img = output.detach().numpy().reshape(28,28)axarr[i,j].imshow(img, interpolation='none', cmap='Blues')

🧡🧡分析结果🧡🧡

数据预处理:
加载数据集:
加载torch库中自带的minst数据集
转换数据:
转为tensor变量(相当于直接除255归一化到值域为(0,1))。
此处不同于CNN和BP网络实验,不再对其进行transforms.Normalize()处理,因为对抗网络中,生成器输入的是一个随机噪声向量,不是预处理后的图像;判别器中,输入的是真实图像和生成图像,而不是预处理后的图像,如果对输入数据进行归一化处理,会改变图像的数值范围,可能会影响判别器的判断结果。

构建对抗网络
构造判别器:
在这里插入图片描述

  • nn.Linear():全连接层,转换特征维度。
  • nn.LeakyReLU(0.02):激活层,激活函数如下,0.02即为negative_slope,用于控制负斜率的角度。相比于不具备负值响应(x<0,则y为0)的传统ReLU,LeakyReLU在负数区间表现的更加平滑,增强非线性表达能力,有助于判别器更好地区分真实样本和真实样本。
    在这里插入图片描述
  • nn.LayerNorm(200):对中间层的输出值进行标准化,让它们均值为0,避免较大值引起的梯度消失。200表示要标准化的维度数目。
  • nn.Sigmoid():将1维标量缩放结果到0-1之间,以0.5作为二分类结果。

构造生成器:
在这里插入图片描述

  • nn.Linear():全连接层,转换特征维度。这里设定输入的随机噪声维度为100,最后输出一张784像素图片。
  • nn.LeakyReLU、nn.LayerNorm、nn.Sigmoid作用同上述类似

选取损失函数:
对于分类问题,损失函数使用二元交叉熵BCELoss()往往比均方误差MSELoss()效果更好。因为它能对正确分类进行奖励,而对错误分类进行惩罚。
由于生成器无需定义损失函数,所以我们只需要修改鉴别器的损失函数即可:

训练和评估
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
每10张图记录1次loss,1次epoch训练60000张图,则1次epoch记录6000次loss,6次epoch记录36000次loss。而1次epoch训练1次生成器,训练2次判别器(1次正样本判别、1次负样本判别),所以生成器loss迭代变化横坐标为36000次,判别器loss迭代变化横坐标为72000次。
在这里插入图片描述
loss迭代变化如下图。
在这里插入图片描述
在这里插入图片描述
从图中整体来看,一开始生成器loss较高,判别器接近0,后面生成器和判别器loss逐渐分布均匀(方差减少,数值大小越来越集中)。

分析生成对抗网络中生成器和判别器的关系
实验中,判别器的loss定义为:区分真实图像和假图像的能力,即loss越小,区分能力越强
而生成器虽然没有直接定义loss,但是利用了判别器的loss,使得判别器对生成器生成的假图像的评分尽可能接近正样本,也即loss越小,生成器生成的假数据越来越接近真实图像。
上述loss的记录迭代次数太多,可能不够直观观察判别器和生成器的相对变化,计算每次epoch的平均loss如下图:
在这里插入图片描述
可以看到,刚开始生成器与判别器的博弈中处于下风,随着训练进行,生成器的loss大幅减少,说明生成器生成的图像越来越逼真,反观判别器loss增大,说明判别器开始处于下风。最后,可以看到两者的loss都趋于平稳,说明此时渐渐达到了博弈平衡,从直观的图像清晰度也能看到,对比训练初期,图像5相比最开始变得比较清晰,但当迭代一定训练次数后,清晰度似乎不再变化了。


🧡🧡实验总结🧡🧡

理论理解:
GAN的核心思想:生成器G和判别器D的一代代博弈

  • 生成器:
    生成网络,通过输入生成图像
  • 判别器:
    二分类网络,将生成器生成图像作为负样本,真实图像作为正样本
  • 优化 判别器D:
    给定G,通过G生成图像产生负样本,并结合真实图像作为正样本来训练D
  • 优化 生成器G:
    给定D,以使得D对G生成图像的评分尽可能接近正样本作为目标来训练G

G和D的训练过程交替进行,这个对抗的过程使得G生成的图像越来越逼真,D辨别的能力也越来越强。

代码实操:

  • 模式崩溃:
    在生成器生成随机数时,若生成的方法不对,可能会导致模式崩溃问题,它指的是生成器倾向于生成相似或重复的样本,而不是多样化的输出(如下图)。
    在这里插入图片描述
    在python中,torch.rand()产生的是0-1之间均匀分布的随机数,很容易导致模式崩溃,因为均匀分布的随机数无法提供足够的多样性,从而使得生成器可能会生成类似的样本。为了解决这个问题,使用torch,randn()函数从高斯分布中抽取随机数,从而增大生成器的多样性。
  • 判断对抗网络模型的收敛情况
    一方面生成器和判别器的损失函数值来监控两者的优化过程,它们的相对变化可以一定程度反映它们的博弈情况,当它们的loss的变化都慢慢趋于平稳时,可以认为模型达到收敛。当然,另一方面,通过观察图像清晰度也是比较直观的方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/840424.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式实时操作系统笔记2:UCOS基础知识_UC/OS-III移植(STM32F4)_编写简单的UC/OS-III任务例程(失败.....)

今日学习嵌入式实时操作系统RTOS&#xff1a;UC/OS-III实时操作系统 本文只是个人学习笔记备忘用&#xff0c;附图、描述等 部分都是对网上资料的整合...... 文章主要研究如何将UC/OS-III 移植到 STM32 F407VET6上&#xff0c;提供测试工程下载 &#xff08;2024.5.21 文章未…

Java web应用性能分析之【高并发之缓存-多级缓存】

说到缓存&#xff0c;作为java开发第一时间想到的是不是上图所示的Redis&#xff0c;又或者是Guava Cache、Caffeine、EhCache这些&#xff1b;Redis作为分布式缓存、其他的可以作为本地缓存。但是作为一名资深开发人员&#xff0c;着眼的层面应该再提升一个级别&#xff0c;从…

Prometheus监控平台配置--监控mysql

上一篇中讲述了怎么安装Prometheus&#xff0c;然后对服务器集群资源信息进行监控并通过grafana展示监控信息&#xff0c;在这一篇中我们只讲和mysql相关的监控&#xff0c;关于prometheus的监控原理以及安装可以看下上一篇。 1.上传 通过rz命令将安装包上传到任意目录&#xf…

翻译AnyDoor: Zero-shot Object-level Image Customization

摘要 本研究介绍了AnyDoor&#xff0c;这是一款基于扩散模型的图像生成器&#xff0c;能够在用户指定的位置&#xff0c;以期望的形状将目标对象传送到新场景中。与为每个对象调整参数不同&#xff0c;我们的模型仅需训练一次&#xff0c;就能在推理阶段轻松地泛化到多样化的对…

SpringBoot——整合Redis

目录 Redis 创建Commodity表 启动MySQL和Redis 新建一个SpringBoot项目 pom.xml application.properties Commodity实体类 ComMapper接口 ComService业务层接口 ComServiceImpl业务接口的实现类 ComController控制器 RedisConfig配置类 SpringbootRdisApplication启…

在Visual Studio Code和Visual Studio 2022下配置Clang-Format,格式化成Google C++ Style

项目开发要求好的编写代码格式规范&#xff0c;常用的是根据Google C Style Guide 网上查了很多博文&#xff0c;都不太一样有的也跑不起来&#xff0c;通过尝试之后&#xff0c;自己可算折腾好了&#xff0c;整理一下过程 背景&#xff1a; 编译器主要有三部分&#xff1a;前…

C++第三方库 【HTTP/HTTPS】— httplib库

目录 认识httplib库 安装httplib库 httplib的使用 httplib请求类 httplib响应类 Server类 Client类 httplib库搭建简单服务器&客户端 认识httplib库 httplib库&#xff0c;是一个C11单头文件的&#xff0c;轻量级的跨平台HTTP/HTTPS库&#xff0c;可以用来创建简单的…

【Text2SQL】WikiSQL 数据集与 Seq2SQL 模型

论文&#xff1a;Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning ⭐⭐⭐⭐⭐ ICLR 2018 Dataset: github.com/salesforce/WikiSQL Code&#xff1a;Seq2SQL 模型实现 一、论文速读 本文提出了 Text2SQL 方向的一个经典数据集 —…

Linux--10---安装JDK、MySQL

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 安装JDK[Linux命令--03----JDK .Nginx. 数据库](https://blog.csdn.net/weixin_48052161/article/details/108997148) 第一步 查询系统中自带的JDK第二步 卸载系统中…

Unity Physics入门

概述 在unity中物理属性是非常重要的&#xff0c;它可以模拟真实物理的效果在unity中&#xff0c;其中的组件是非常多的&#xff0c;让我们来学习一下这部分的内容吧。 Unity组件入门篇总目录----------点击导航 Character Controller(角色控制) 说明&#xff1a;组件是Unity提…

华为编程题目(实时更新)

1.大小端整数 计算机中对整型数据的表示有两种方式&#xff1a;大端序和小端序&#xff0c;大端序的高位字节在低地址&#xff0c;小端序的高位字节在高地址。例如&#xff1a;对数字 65538&#xff0c;其4字节表示的大端序内容为00 01 00 02&#xff0c;小端序内容为02 00 01…

【案例分享】医疗布草数字化管理系统:聚通宝赋能仟溪信息科技

内容概要 本文介绍了北京聚通宝科技有限公司与河南仟溪信息科技有限公司合作开发的医疗布草数字化管理系统。该系统利用物联网技术实现了医疗布草生产过程的实时监控和数据分析&#xff0c;解决了医疗布草洗涤厂面临的诸多挑战&#xff0c;包括人工记录、生产低效率和缺乏实时…

DNF手游攻略:角色培养与技能搭配!游戏辅助!

角色培养和技能搭配是《地下城与勇士》中提升战斗力的关键环节。每个职业都有独特的技能和发展路线&#xff0c;合理的属性加点和技能搭配可以最大化角色的潜力&#xff0c;帮助玩家在各种战斗中立于不败之地。接下来&#xff0c;我们将探讨如何有效地培养角色并搭配技能。 角色…

JavaEE之线程(9) _定时器的实现代码

前言 定时器也是软件开发中的一个重要组件. 类似于一个 “闹钟”。 达到一个设定的时间之后&#xff0c;就执行某个指定好的代码&#xff0c;比如&#xff1a; 在受上述场景中&#xff0c;当客户端发出去请求之后&#xff0c; 就要等待响应&#xff0c;如果服务器迟迟没有响应&…

大小字符判断

//函数int my_isalpha(char c)的功能是返回字符种类 //大写字母返回1&#xff0c;小写字母返回-1.其它字符返回0 //void a 调用my_isalpha()&#xff0c;返回大写&#xff0c;输出*&#xff1b;返回小写&#xff0c;输出#&#xff1b;其它&#xff0c;输出&#xff1f; #inclu…

【Linux】Linux的安装

文章目录 一、Linux环境的安装虚拟机 镜像文件云服务器&#xff08;可能需要花钱&#xff09; 未完待续 一、Linux环境的安装 我们往后的学习用的Linux版本为——CentOs 7 &#xff0c;使用 Ubuntu 也可以 。这里提供几个安装方法&#xff1a; 电脑安装双系统&#xff08;不…

深入解析力扣162题:寻找峰值(线性扫描与二分查找详解)

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容&#xff0c;和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣&#xff01; 推荐&#xff1a;数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航&#xff1a; LeetCode解锁100…

virtual box ubuntu20 全屏展示

virtual box 虚拟机 ubuntu20 系统 全屏展示 ubuntu20.04 视图-自动调整窗口大小 视图-自动调整显示尺寸 系统黑屏解决 ##设备-安装增强功能 ##进入终端 ##终端打不开&#xff0c;解决方案-传送门ubuntu Open in Terminal打不开终端解决方案-CSDN博客 ##点击cd盘按钮进入文…

【RabbitMQ】使用SpringAMQP的Publish/Subscribe(发布/订阅)

Publish/Subscribe **发布(Publish)、订阅(Subscribe)&#xff1a;**允许将同一个消息发送给多个消费者 **注意&#xff1a;**exchange负责消息路由&#xff0c;而不是存储&#xff0c;路由失败则消息丢失 常见的**X(exchange–交换机)***类型&#xff1a; Fanout 广播Direc…

【设计模式】JAVA Design Patterns——Callback(回调模式)

&#x1f50d;目的 回调是一部分被当为参数来传递给其他代码的可执行代码&#xff0c;接收方的代码可以在一些方便的时候来调用它。 &#x1f50d;解释 真实世界例子 我们需要被通知当执行的任务结束时。我们为调用者传递一个回调方法然后等它调用通知我们。 通俗描述 回调是一…