GAN原理 代码解读

模型架构

在这里插入图片描述

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',train=True,transform=transform,download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''输入:正态分布随机数噪声(长度为100)输出:生成的图片,(1,28,28)中间过程:linear1: 100 -> 256linear2: 256 -> 512linear3: 512 -> 28*28reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),nn.Linear(256,512),nn.ReLU(),# 最后一层用tanh激活,将数据压缩到-1到1nn.Linear(512,28*28),nn.Tanh())def forward(self,x):img = self.model(x)img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)return img

定义判别器

'''判别器输入:(1,28,28)的图片输出:二分类的概率值 用sigmoid压缩到0-1之间内容:判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(28*28,512),nn.LeakyReLU(),nn.Linear(512,256),nn.LeakyReLU(),nn.Linear(256,1),nn.Sigmoid(),)def forward(self,x):x = x.view(-1,28*28)x = self.model(x)return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

def gen_img_plot(model,epoch,test_input):prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpyprediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

显示图片的函数

def img_plot(img):img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

定义训练函数


def train(num_epoch,test_input):D_loss = []G_loss = []# 训练循环for epoch in range(num_epoch):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader) # 返回批次数for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)# print(f'step={step},img.shape={img.shape}')# img_plot(img)img = img.to(device)size = img.size(0) # 得到一个批次的图片random_noise = torch.randn(size,100,device=device) # 生成器的输入'''一. 训练判别器''''''用真实图片训练判别器'''dis_optim.zero_grad()real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果# 判别器在真实图像上的损失d_real_loss = bce_loss(real_output,# torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。torch.ones_like(real_output))d_real_loss.backward()'''用生成的图片训练判别器'''gen_img = gen(random_noise)# 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensorfake_output = dis(gen_img.detach())d_fake_loss = bce_loss(fake_output,torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossdis_optim.step() # 对参数进行优化'''二.训练生成器'''gen_optim.zero_grad()# 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别fake_output = dis(gen_img)# 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。# 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。g_loss = bce_loss(fake_output,torch.ones_like(fake_output))g_loss.backward()gen_optim.step()# 计算一个epoch的损失with torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss +=d_lossg_epoch_loss +=g_loss# 计算整体loss每个epoch的平均Losswith torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch+1)print(f'd_epoch_loss={d_epoch_loss}')print(f'g_epoch_loss={g_epoch_loss}')# 将16个长度为100的噪音输入到生成器并画图gen_img_plot(gen,test_input)

开始训练

'''开始计时'''
start_time = time.time()'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:print(f'{round(run_time,2)}s')
else:print(f'{round(run_time/60,2)}minutes')

结果可视化

在这里插入图片描述

加载训练好的参数

gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

在这里插入图片描述
GAN的生成是随机的,不同的噪声,生成不同的数字

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/53270.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

堆排序简介

概念&#xff1a; 堆排序是一种基于二叉堆数据结构的排序算法。它的概念是通过将待排序的元素构建成一个二叉堆&#xff0c;然后通过不断地取出堆顶元素并重新调整堆的结构来实现排序。 算法步骤&#xff1a; 构建最大堆&#xff08;或最小堆&#xff09;&#xff1a;将待排…

搭建 Qt6 开发环境

作者&#xff1a; 一去、二三里 个人微信号&#xff1a; iwaleon 微信公众号&#xff1a; 高效程序员 Qt 是一个跨平台的 C 应用程序开发框架&#xff0c;它提供了丰富的组件库和工具&#xff0c;使得开发人员可以在各种平台上轻松地开发 GUI 应用程序。 由于我们的教程 《细说…

CnetSDK .NET OCR SDK Crack

CnetSDK .NET OCR SDK Crack CnetSDK.NET OCR库SDK是一款高度准确的.NET OCR扫描仪软件&#xff0c;用于使用手写、文本和其他符号等图像进行字符识别。它是一款.NET OCR库软件&#xff0c;使用Tesseract OCR引擎技术&#xff0c;可将字符识别准确率提高99%。通过将此.NET OCR扫…

Rancher使用cert-manager安装报错解决

报错&#xff1a; rancher-rke-01:~/rke/rancher-helm/rancher # helm install rancher rancher-stable/rancher --namespace cattle-system --set hostnamewww.rancher.local Error: INSTALLATION FAILED: Internal error occurred: failed calling webhook "webhook…

sentinel的基本使用

在一些互联网项目中高并发的场景很多&#xff0c;瞬间流量很大&#xff0c;会导致我们服务不可用。 sentinel则可以保证我们服务的正常运行&#xff0c;提供限流、熔断、降级等方法来实现 一.限流&#xff1a; 1.导入坐标 <dependency><groupId>com.alibaba.c…

快速排序三种思路详解!

一、快速排序的介绍 快速排序是Hoare于1962年提出的一种二叉树结构的交换排序方法&#xff0c;其基本思想为&#xff1a;任取待排序元素序列中 的某元素作为基准值&#xff0c;按照该排序码将待排序集合分割成两子序列&#xff0c;左子序列中所有元素均小于基准值&#xff0c;…

激活函数总结(十九):激活函数补充(ISRU、ISRLU)

激活函数总结&#xff08;十九&#xff09;&#xff1a;激活函数补充 1 引言2 激活函数2.1 Inverse Square Root Unit &#xff08;ISRU&#xff09;激活函数2.2 Inverse Square Root Linear Unit (ISRLU)激活函数 3. 总结 1 引言 在前面的文章中已经介绍了介绍了一系列激活函…

python AI绘图教程

前提 1.安装python 2.安装git 步骤 下载stable-diffusion-webui项目&#xff08;链接&#xff1a;GitHub - AUTOMATIC1111/stable-diffusion-webui: Stable Diffusion web UI&#xff09; git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git 安装st…

商城-学习整理-高级-消息队列(十七)

目录 一、RabbitMQ简介(消息中间件)1、RabbitMQ简介&#xff1a;2、核心概念1、Message2、Publisher3、Exchange4、Queue5、Binding6、Connection7、Channel8、Consumer9、Virtual Host10、Broker 二、一些概念1、异步处理2、应用解耦3、流量控制5、概述 三、Docker安装RabbitM…

leetcode做题笔记100. 相同的树

给你两棵二叉树的根节点 p 和 q &#xff0c;编写一个函数来检验这两棵树是否相同。 如果两个树在结构上相同&#xff0c;并且节点具有相同的值&#xff0c;则认为它们是相同的。 思路一&#xff1a; bool isSameTree(struct TreeNode* p, struct TreeNode* q){if(p NULL &…

js实现img图片懒加载

在前端中&#xff0c;可以使用 JavaScript 来实现图片的懒加载。下面是一种常见的实现方式&#xff1a; 在 HTML 文件中&#xff0c;将需要懒加载的图片的 src 属性替换为一个占位符&#xff0c;例如使用一个透明的空白图片或者是一个包含背景色的 div。给这些图片添加一个自定…

【C++ 学习 ⑰】- 继承(下)

目录 一、派生类的默认成员函数 二、继承与友元 三、继承与静态成员 四、复杂的菱形继承及菱形虚拟继承 五、继承和组合 一、派生类的默认成员函数 派生类的构造函数必须调用基类的构造函数初始化基类的那一部分成员。如果基类没有默认构造函数&#xff0c;那么必须在派生…

报错sql_mode=only_full_group_by

首发博客地址 https://blog.zysicyj.top/ 报错内容 ### The error may exist in file[D:\code\cppCode20221025\leader-system\target\classes\mapper\system\TJsonDataMapper.xml] ### The error may involve defaultParameterMap ### The error occurred while…

如何使用LLM实现文本自动生成视频

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建可二次编辑的3D应用场景 介绍 基于扩散的图像生成模型代表了计算机视觉领域的革命性突破。这些进步由Imagen&#xff0c;DallE和MidJourney等模型开创&#xff0c;展示了文本条件图像生成的卓越功能。有关这些模型内部工作的…

【C++】UDP通信,实现文件的传输

目录 1 TCP与UDP比较 2 UDP 3 通信流程 4 实践 5 运行结果 1 TCP与UDP比较 2 UDP简介 UDP通信是无连接的,因此不需要

Spring与Mybatis集成且Aop整合(放飞双手,迅速完成CRUD及分页)

目录 一、概述 二、集成 ( 1 ) 为什么 ( 2 ) 优点 ( 3 ) 实例 三、整合 3.1 讲述 3.2 整合进行分页 带我们带来的收获 一、概述 集成是指将不同的组件、系统或框架整合在一起&#xff0c;使它们能够协同工作&#xff0c;共同完成某个功能或提供某种服务。在软件开发中&…

Linux操作系统--系统管理

1.Linux中的服务和进程 计算机中,一个正在执行的程序或命令,被叫做“进程”(process)。 启动之后一直存在、常驻内存的进程,一般被称作“服务”(service)。 服务可以理解为系统需要持续的为用户提供某一种服务。比如网络服务。这里还有一个概念就是守护进程(daemon),一…

uni-app中学习笔记记录(1)

常用生命周期函数 onLoad 页面加载时触发&#xff0c;用onLoad可以接受路由传参&#xff1b;onReady 页面组件渲染完毕时触发&#xff0c;类似于vue2中的mounted生命周期函数&#xff1b;onShow 页面出现在屏幕上时触发&#xff0c;由于在h5或者小程序中&#xff0c;页面初始化…

C语言之三子棋游戏实现篇

目录 主函数test.c 菜单函数 选择实现 游戏函数 &#xff08;函数调用&#xff09; 打印棋盘数据 打印展示棋盘 玩家下棋 电脑下棋 判断输赢 循环 test.c总代码 头文件&函数声明game.h 头文件的包含 游戏符号声明 游戏函数声明 game.h总代码 游戏函数ga…

服务器中了mkp勒索病毒该怎么办?勒索病毒解密,数据恢复

mkp勒索病毒算的上是一种比较常见的勒索病毒类型了。它的感染数量上也常年排在前几名的位置。所以接下来就由云天数据恢复中心的技术工程师来对mkp勒索病毒做一个分析&#xff0c;以及中招以后应该怎么办。 一&#xff0c;中了mkp勒索病毒的表现 桌面以及多个文件夹当中都有一封…