GAN原理 代码解读

模型架构

在这里插入图片描述

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',train=True,transform=transform,download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''输入:正态分布随机数噪声(长度为100)输出:生成的图片,(1,28,28)中间过程:linear1: 100 -> 256linear2: 256 -> 512linear3: 512 -> 28*28reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),nn.Linear(256,512),nn.ReLU(),# 最后一层用tanh激活,将数据压缩到-1到1nn.Linear(512,28*28),nn.Tanh())def forward(self,x):img = self.model(x)img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)return img

定义判别器

'''判别器输入:(1,28,28)的图片输出:二分类的概率值 用sigmoid压缩到0-1之间内容:判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(28*28,512),nn.LeakyReLU(),nn.Linear(512,256),nn.LeakyReLU(),nn.Linear(256,1),nn.Sigmoid(),)def forward(self,x):x = x.view(-1,28*28)x = self.model(x)return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

def gen_img_plot(model,epoch,test_input):prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpyprediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

显示图片的函数

def img_plot(img):img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

定义训练函数


def train(num_epoch,test_input):D_loss = []G_loss = []# 训练循环for epoch in range(num_epoch):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader) # 返回批次数for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)# print(f'step={step},img.shape={img.shape}')# img_plot(img)img = img.to(device)size = img.size(0) # 得到一个批次的图片random_noise = torch.randn(size,100,device=device) # 生成器的输入'''一. 训练判别器''''''用真实图片训练判别器'''dis_optim.zero_grad()real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果# 判别器在真实图像上的损失d_real_loss = bce_loss(real_output,# torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。torch.ones_like(real_output))d_real_loss.backward()'''用生成的图片训练判别器'''gen_img = gen(random_noise)# 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensorfake_output = dis(gen_img.detach())d_fake_loss = bce_loss(fake_output,torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossdis_optim.step() # 对参数进行优化'''二.训练生成器'''gen_optim.zero_grad()# 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别fake_output = dis(gen_img)# 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。# 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。g_loss = bce_loss(fake_output,torch.ones_like(fake_output))g_loss.backward()gen_optim.step()# 计算一个epoch的损失with torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss +=d_lossg_epoch_loss +=g_loss# 计算整体loss每个epoch的平均Losswith torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch+1)print(f'd_epoch_loss={d_epoch_loss}')print(f'g_epoch_loss={g_epoch_loss}')# 将16个长度为100的噪音输入到生成器并画图gen_img_plot(gen,test_input)

开始训练

'''开始计时'''
start_time = time.time()'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:print(f'{round(run_time,2)}s')
else:print(f'{round(run_time/60,2)}minutes')

结果可视化

在这里插入图片描述

加载训练好的参数

gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

在这里插入图片描述
GAN的生成是随机的,不同的噪声,生成不同的数字

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/53270.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

搭建 Qt6 开发环境

作者&#xff1a; 一去、二三里 个人微信号&#xff1a; iwaleon 微信公众号&#xff1a; 高效程序员 Qt 是一个跨平台的 C 应用程序开发框架&#xff0c;它提供了丰富的组件库和工具&#xff0c;使得开发人员可以在各种平台上轻松地开发 GUI 应用程序。 由于我们的教程 《细说…

CnetSDK .NET OCR SDK Crack

CnetSDK .NET OCR SDK Crack CnetSDK.NET OCR库SDK是一款高度准确的.NET OCR扫描仪软件&#xff0c;用于使用手写、文本和其他符号等图像进行字符识别。它是一款.NET OCR库软件&#xff0c;使用Tesseract OCR引擎技术&#xff0c;可将字符识别准确率提高99%。通过将此.NET OCR扫…

Rancher使用cert-manager安装报错解决

报错&#xff1a; rancher-rke-01:~/rke/rancher-helm/rancher # helm install rancher rancher-stable/rancher --namespace cattle-system --set hostnamewww.rancher.local Error: INSTALLATION FAILED: Internal error occurred: failed calling webhook "webhook…

sentinel的基本使用

在一些互联网项目中高并发的场景很多&#xff0c;瞬间流量很大&#xff0c;会导致我们服务不可用。 sentinel则可以保证我们服务的正常运行&#xff0c;提供限流、熔断、降级等方法来实现 一.限流&#xff1a; 1.导入坐标 <dependency><groupId>com.alibaba.c…

快速排序三种思路详解!

一、快速排序的介绍 快速排序是Hoare于1962年提出的一种二叉树结构的交换排序方法&#xff0c;其基本思想为&#xff1a;任取待排序元素序列中 的某元素作为基准值&#xff0c;按照该排序码将待排序集合分割成两子序列&#xff0c;左子序列中所有元素均小于基准值&#xff0c;…

激活函数总结(十九):激活函数补充(ISRU、ISRLU)

激活函数总结&#xff08;十九&#xff09;&#xff1a;激活函数补充 1 引言2 激活函数2.1 Inverse Square Root Unit &#xff08;ISRU&#xff09;激活函数2.2 Inverse Square Root Linear Unit (ISRLU)激活函数 3. 总结 1 引言 在前面的文章中已经介绍了介绍了一系列激活函…

python AI绘图教程

前提 1.安装python 2.安装git 步骤 下载stable-diffusion-webui项目&#xff08;链接&#xff1a;GitHub - AUTOMATIC1111/stable-diffusion-webui: Stable Diffusion web UI&#xff09; git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git 安装st…

商城-学习整理-高级-消息队列(十七)

目录 一、RabbitMQ简介(消息中间件)1、RabbitMQ简介&#xff1a;2、核心概念1、Message2、Publisher3、Exchange4、Queue5、Binding6、Connection7、Channel8、Consumer9、Virtual Host10、Broker 二、一些概念1、异步处理2、应用解耦3、流量控制5、概述 三、Docker安装RabbitM…

【C++ 学习 ⑰】- 继承(下)

目录 一、派生类的默认成员函数 二、继承与友元 三、继承与静态成员 四、复杂的菱形继承及菱形虚拟继承 五、继承和组合 一、派生类的默认成员函数 派生类的构造函数必须调用基类的构造函数初始化基类的那一部分成员。如果基类没有默认构造函数&#xff0c;那么必须在派生…

报错sql_mode=only_full_group_by

首发博客地址 https://blog.zysicyj.top/ 报错内容 ### The error may exist in file[D:\code\cppCode20221025\leader-system\target\classes\mapper\system\TJsonDataMapper.xml] ### The error may involve defaultParameterMap ### The error occurred while…

如何使用LLM实现文本自动生成视频

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建可二次编辑的3D应用场景 介绍 基于扩散的图像生成模型代表了计算机视觉领域的革命性突破。这些进步由Imagen&#xff0c;DallE和MidJourney等模型开创&#xff0c;展示了文本条件图像生成的卓越功能。有关这些模型内部工作的…

【C++】UDP通信,实现文件的传输

目录 1 TCP与UDP比较 2 UDP 3 通信流程 4 实践 5 运行结果 1 TCP与UDP比较 2 UDP简介 UDP通信是无连接的,因此不需要

Spring与Mybatis集成且Aop整合(放飞双手,迅速完成CRUD及分页)

目录 一、概述 二、集成 ( 1 ) 为什么 ( 2 ) 优点 ( 3 ) 实例 三、整合 3.1 讲述 3.2 整合进行分页 带我们带来的收获 一、概述 集成是指将不同的组件、系统或框架整合在一起&#xff0c;使它们能够协同工作&#xff0c;共同完成某个功能或提供某种服务。在软件开发中&…

C语言之三子棋游戏实现篇

目录 主函数test.c 菜单函数 选择实现 游戏函数 &#xff08;函数调用&#xff09; 打印棋盘数据 打印展示棋盘 玩家下棋 电脑下棋 判断输赢 循环 test.c总代码 头文件&函数声明game.h 头文件的包含 游戏符号声明 游戏函数声明 game.h总代码 游戏函数ga…

服务器中了mkp勒索病毒该怎么办?勒索病毒解密,数据恢复

mkp勒索病毒算的上是一种比较常见的勒索病毒类型了。它的感染数量上也常年排在前几名的位置。所以接下来就由云天数据恢复中心的技术工程师来对mkp勒索病毒做一个分析&#xff0c;以及中招以后应该怎么办。 一&#xff0c;中了mkp勒索病毒的表现 桌面以及多个文件夹当中都有一封…

Linux:shell脚本:基础使用(6)《正则表达式-awk工具》

简介 awk是行处理器: 相比较屏幕处理的优点&#xff0c;在处理庞大文件时不会出现内存溢出或是处理缓慢的问题&#xff0c;通常用来格式化文本信息 awk处理过程: 依次对每一行进行处理&#xff0c;然后输出 1&#xff09;awk命令会逐行读取文件的内容进行处理 2&#xff09;a…

Neo4j实现表字段级血缘关系

需求背景 需要在前端页面展示当前表字段的所有上下游血缘关系&#xff0c;以进一步做数据诊断治理。大致效果图如下&#xff1a; 首先这里解释什么是表字段血缘关系&#xff0c;SQL 示例&#xff1a; CREATE TABLE IF NOT EXISTS table_b AS SELECT order_id, order_status F…

PB4引脚作GPIO上电高电平问题

问题说明 给旧项目debug&#xff0c;芯片是国民技术 N32G452VEL7 &#xff08;用起来跟32没多大差 包括PB4在内有多个引脚作为输出&#xff0c;默认低电平&#xff0c;在状态机内先输出高电平再回到低电平&#xff0c;来模拟按键的状态&#xff08;相当于按键按下松开后按键功…

计算机竞赛 基于大数据的时间序列股价预测分析与可视化 - lstm

文章目录 1 前言2 时间序列的由来2.1 四种模型的名称&#xff1a; 3 数据预览4 理论公式4.1 协方差4.2 相关系数4.3 scikit-learn计算相关性 5 金融数据的时序分析5.1 数据概况5.2 序列变化情况计算 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &…

词向量及文本向量

文章目录 引言1. 文本向量化2. one-hot编码3. 词向量-word2vec3.1 词向量-基于语言模型 4 词向量 - word2vec基于窗口4.1 词向量-如何训练 5. Huffman树6. 负采样-negative sampling7. Glove基于共现矩阵7.1 Glove词向量7.2 Glove对比word2vec 8. 词向量训练总结9. 词向量应用9…