【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络

GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。
在这里插入图片描述

1.下载数据并对数据进行规范

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

下载MNIST数据集,并对数据进行规范化。transforms.Compose 是用于定义一系列数据变换的类,ToTensor() 将图像转换为PyTorch张量,Normalize(0.5, 0.5) 对张量进行归一化。然后,创建一个 DataLoader,它将数据集划分成小批次,使得在训练时更容易处理。

2.生成器的代码

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img

这一部分定义了生成器的神经网络模型。生成器的输入是一个大小为100的随机向量,通过多个线性层和激活函数(ReLU),最后通过 nn.Tanh() 激活函数生成大小为28x28的图像。forward 方法定义了前向传播的过程。

3.判别器的代码

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return x

这一部分定义了判别器的神经网络模型。判别器的输入是28x28大小的图像,通过多个线性层和激活函数(LeakyReLU),最后通过 nn.Sigmoid() 激活函数输出一个0到1之间的值,表示输入图像是真实图像的概率。

4. 定义损失函数和优化函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()

这一部分设置了设备(GPU或CPU)、初始化了生成器和判别器的实例,并定义了优化器(Adam优化器)和损失函数(二分类交叉熵损失)。将生成器和判别器移动到设备上进行加速计算。

5.定义绘图函数

def gen_img_plot(model,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i]+1)/2)plt.axis('off')plt.show()

6. 开始训练,并显示出生成器所产生的图像

test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
for epoch in range(30):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader)for step, (img, _) in enumerate(dataloader):img = img.to(device)               # 获得用于训练的mnist图像size = img.size(0)                 # 获得1批次数据量大小# 随机生成size个100维的向量样本值,也即是噪声,用于输入生成器 生成 和mnist一样的图像数据random_noise = torch.randn(size, 100, device=device)########################### 先训练判别器 #############################dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实值的loss,也即是真图片与1标签的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 假的值的loss,也即是生成的图像与0标签的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()########################### 下面再训练生成器 #############################gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()#########################################################################with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss
with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('epoch:', epoch)gen_img_plot(gen, test_input)

1.设置 test_input 作为模型的输入,并初始化用于存储判别器(D)和生成器(G)的损失值的列表。

2.开始 30 轮次的训练循环。在每一轮中:

3.对数据集进行遍历。每次迭代,加载一批图像数据 (img)。

4.将图像数据移动到设备(device)上,并获取批次大小。

5.生成随机噪声,作为输入给生成器。

6.训练判别器(D):

  • 对真实图像计算判别器的损失 (d_real_loss),并反向传播计算梯度。
  • 生成生成器产生的图像,并计算判别器的对这些生成图像的损失 (d_fake_loss),再反向传播计算梯度。
  • 计算总的判别器损失 d_loss,并更新判别器的参数。

7.训练生成器(G):

  • 生成器生成图像,并将其输入到判别器中,计算生成器的损失 (g_loss),并反向传播计算梯度。
  • 更新生成器的参数。

这个过程是 GAN 中交替训练生成器和判别器的典型过程,目的是让生成器生成逼真的图像,同时让判别器能够准确区分真假图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/188885.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《ChatGPT实操应用大全》探索无限可能

🗣️探索ChatGPT,开启无限可能🚀 文末有免费送书福利!!! ChatGPT是人类有史以来最伟大的发明。他能写作、绘画、翻译、看病、做菜、编程、数据分析、制作视频、解高等数学题…,他会的技能…

《2023开发者生态系统现状》:ChatGPT 是最常用的 AI 工具,60%开发者使用代码生成工具辅助编程

在前有编程语言历经 80 年的迭代,后有 GitHub Copilot、ChatGPT 等 AI 辅助编程工具的层出不穷,开发者的开发方式发生了什么样的变化?行业中领头的 Java IDE IntelliJ IDEA、Kotlin 编程语言背后的软件工具开发公司 JetBrains 基于全球 26,34…

【Java 基础】09 封装 继承 多态

我们都知道 Java 是以面向对象而著称,最著名的当然就是面向对象的三大特性啦,接下来就逐一举例说明一下。 1. 封装 封装指的是将类的内部细节隐藏起来,只对外提供必要的访问方式。 例如: 我们使用的计算器做一个乘法运算&#x…

[原创]Delphi的SizeOf(), Length(), 动态数组, 静态数组的关系.

[简介] 常用网名: 猪头三 出生日期: 1981.XX.XXQQ: 643439947 个人网站: 80x86汇编小站 https://www.x86asm.org 编程生涯: 2001年~至今[共22年] 职业生涯: 20年 开发语言: C/C、80x86ASM、PHP、Perl、Objective-C、Object Pascal、C#、Python 开发工具: Visual Studio、Delphi…

SQL Sever 复习笔记【一】

SQL Sever 基础知识 一、查询数据第1节 基本 SQL Server 语句SELECT第2节 SELECT语句示例2.1 SELECT - 检索表示例的某些列2.2 SELECT - 检索表的所有列2.3 SELECT - 对结果集进行筛选2.4 SELECT - 对结果集进行排序2.5 SELECT - 对结果集进行分组2.5 SELECT - 对结果集进行筛选…

【前端】多线程 worker

VUE3 引用 npm install worker-loader 在vue.config.js文件的defineConfig里加上配置参数 chainWebpack: config > {config.module.rule(worker-loader).test(/\.worker\.js$/).use({loader: worker-loader,options: {inline: true}}).loader(worker-loader).end()}先在…

C语言错误处理之 “<errno.h>与<error.h>”

目录 前言 错误号处理方式 errno.h头文件 error.h头文件 参数解释: 关于的”__attribute__“解释: 关于“属性”的解释: 实例一: 实例二: error.h与errno.h的区别 补充内容: 前言 在开始学习…

后仿真 ERROR

后仿真 error ERROR (SFE-23): "input.scs" 252: The instance _57_D32_noxref is referencing an undefined model or subcircuit, parasitic_nwd. Either include the file containing the definition of parasitic_nwd, or define parasitic_nwd before running t…

VUE2+THREE.JS点击事件

THREE.JS点击事件 1.增加监听点击事件2.点击事件实现3.记得关闭页面时 销毁此监听事件 1.增加监听点击事件 renderer.domElement.addEventListener("click", this.onClick, false); 注:初始化render时监听 2.点击事件实现 onClick(event) {const raycaster new …

LV.12 D21 PWM实验 学习笔记

一、PWD简介 1.1 蜂鸣器工作原理 有源蜂鸣器 有源蜂鸣器只要接上额定电源就可以发出声音 无源蜂鸣器 无源蜂鸣器利用电磁感应原理,为音圈接入交变电流后形成的电磁铁与永磁铁相吸或相斥而推动振膜发声 1.2 使用GPIO控制 while(1) { GPX2.DATGPX2.D…

Python中用于机器学习的Lazy Predict库

Python是一种多功能语言,你可以用它来做任何事情。Python的一个伟大之处在于,有这么多的库使它变得更加强大。Lazy Predict就是其中一个库。它是机器学习和数据科学的一个很好的工具。在本文中,我们将了解它是什么,它做什么&#…

matlab 无迹卡尔曼滤波

1、内容简介 略 26-可以交流、咨询、答疑 2、内容说明 无迹卡尔曼滤波 无迹卡尔曼滤波 无迹卡尔曼滤波 3、仿真分析 %该文件用于编写无迹卡尔曼滤波算法及其测试 %注解:主要子程序包括:轨迹发生器、系统方程 % 测量方程、UKF滤波器 %----…

LeetCode Hot100 31.下一个排列

题目: 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如,arr [1,2,3] ,以下这些都可以视作 arr 的排列:[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序更大的排列…

刷题笔记12.01 贪心策略

P1090 [NOIP2004 提高组] 合并果子 / [USACO06NOV] Fence Repair G - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 说最大不超过.不用高精度,好说 #include <bits/stdc.h> using namespace std; int n,n2,a; long long a1[10004],a2[10004],sum; int main() {ios::sync_…

【话题】程序员养生指南(AI生成)

目录 程序猿可能出现的职业病有哪些&#xff1f; 如何预防和对付这些职业病&#xff1f; 一、颈椎病的预防 二、神经衰弱的调适 三、肩周炎的防护 四、视力下降的保护 五、饮食与运动的重要性 六、消化系统职业病的预防 程序员养生心得&#xff1a;呵护健康&#xff0c…

百元挂耳式蓝牙耳机推荐,几款性价比高的开放式蓝牙耳机

在百元价位段&#xff0c;挂耳式蓝牙耳机是备受消费者追捧&#xff0c;提供了出色的音质、便携性和无拘束的使用体验&#xff0c;无论您是追求音乐品质&#xff0c;还是需要在办公或运动中保持通讯畅通&#xff0c;基本上都是离不开耳机的身影&#xff0c;今天小编为大家推荐几…

Burp Suite序列之目录扫描

如果你是一名渗透测试爱好者或者专业人士&#xff0c;你一定知道目录扫描是渗透测试中非常重要的一步。通过目录扫描&#xff0c;我们可以发现网站的敏感信息&#xff0c;隐藏的功能&#xff0c;甚至是后台入口。目录扫描可以帮助我们更好地了解目标网站的结构和漏洞。 但是&a…

卡码网15 .链表的基本操作III

链表的基础操作III 时间限制&#xff1a;1.000S 空间限制&#xff1a;128MB 题目描述 请编写一个程序&#xff0c;实现以下链表操作&#xff1a;构建一个单向链表&#xff0c;链表中包含一组整数数据。 1. 实现在链表的第 n 个位置插入一个元素&#xff0c;输出整个链表的…

软著项目推荐 深度学习实现语义分割算法系统 - 机器视觉

文章目录 1 前言2 概念介绍2.1 什么是图像语义分割 3 条件随机场的深度学习模型3\. 1 多尺度特征融合 4 语义分割开发过程4.1 建立4.2 下载CamVid数据集4.3 加载CamVid图像4.4 加载CamVid像素标签图像 5 PyTorch 实现语义分割5.1 数据集准备5.2 训练基准模型5.3 损失函数5.4 归…

计算机软件的分类

以功能进行分类&#xff0c;计算机软件通常可以分为系统软件和应用软件两大类。 系统软件&#xff1a;系统软件是计算机运行和管理的基本软件&#xff0c;包括操作系统、驱动程序、系统工具和服务程序等。操作系统是系统软件的核心&#xff0c;负责管理计算机的硬件资源、提供用…