5 分钟内构建一个简单的基于 Python 的 GAN

文章目录

  • 一、说明
  • 二、代码
  • 三、训练
  • 四、后记

一、说明

生成对抗网络(GAN)因其能力而在学术界引起轩然大波。机器能够创作出新颖、富有灵感的作品,这让每个人都感到敬畏和恐惧。因此,人们开始好奇,如何构建一个这样的网络?

生成对抗网络 (GAN) 是一种深度学习模型,可生成与某些输入数据相似的新合成数据。GAN 由两个神经网络组成:生成器和鉴别器。生成器经过训练可生成与输入数据相同的合成数据,而鉴别器经过训练可区分合成数据和真实数据。

生成模型学习输入数据 f (x)的内在分布函数,使其能够生成合成输入x’和输出y’,通常给定一些隐藏参数。GAN 的优势在于它们能够生成最清晰的图像,并且易于训练。

二、代码

此代码会训练 GAN 一定数量的周期,其中周期定义为对整个数据集的一次遍历。在每个周期中,代码会迭代数据加载器(应该是包装数据集的 PyTorch DataLoader 对象)中的数据,并在每个批次上训练鉴别器和生成器。

在这里插入图片描述

生成器的训练方式是试图欺骗鉴别器,而鉴别器则被训练来区分真实图像和假图像。这里使用的损失函数是二元交叉熵损失,这是 GAN 的常见选择。使用的优化器是 Adam,它是一种随机梯度下降优化器。

首先,导入必要的库并定义生成器和鉴别器模型。

import torch
import torch.nn as nn
import torch.optim as optim

生成器应该是一个神经网络,它接受随机噪声向量并生成合成数据。同时,鉴别器应该是一个神经网络,它接受真实数据或合成数据并输出输入数据为真实的概率。
类 生成器(nn.Module):

class Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Generator, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.tanh(self.fc2(x))return x
class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Discriminator, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.sigmoid(self.fc2(x))return x
  1. 在下面的代码块中,我们设置了 GAN 的环境。这包括:

设置鉴别器和生成器网络的输入层、隐藏层和输出层的大小。
创建 Generator 和 Discriminator 类的实例
设置损失函数和优化器

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Set the input and output sizes
input_size = 784
hidden_size = 256
output_size = 1# Create the discriminator and generator
discriminator = Discriminator(input_size, hidden_size, output_size).to(device)
generator = Generator(input_size, hidden_size, output_size).to(device)# Set the loss function and optimizers
loss_fn = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)# Set the number of epochs and the noise size
num_epochs = 200
noise_size = 100# Training loop
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# Get the batch sizebatch_size = real_images.size(0)

三、训练

  1. 在下面的代码中,生成器通过尝试欺骗鉴别器来训练,而鉴别器经过训练可以区分真假图像。为此,

我们给生成器一批噪声样本作为输入,并生成一批假图像。然后这些假图像通过鉴别器,鉴别器对批次中的每幅图像产生预测。
然后计算生成器的损失,代码通过生成器反向传播损失,并使用 Adam 优化器优化生成器的参数。此过程会以减少损失和提高生成器欺骗鉴别器的能力的方向更新生成器的参数。

 # Generate fake imagesnoise = torch.randn(batch_size, noise_size).to(device)fake_images = generator(noise)# Train the discriminator on real and fake imagesd_real = discriminator(real_images)d_fake = discriminator(fake_images)# Calculate the lossreal_loss = loss_fn(d_real, torch.ones_like(d_real))fake_loss = loss_fn(d_fake, torch.zeros_like(d_fake))d_loss = real_loss + fake_loss# Backpropagate and optimized_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# Train the generatord_fake = discriminator(fake_images)g_loss = loss_fn(d_fake, torch.ones_like(d_fake))# Backpropagate and optimizeg_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# Print the loss every 50 batchesif (i+1) % 50 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))

就这样……一个可以快速使用的 GAN 模型就完成了。

四、后记

关于成对抗网络(GAN)由两部分组成:

  • 生成器学习生成可信的数据。生成的实例将成为鉴别器的反面训练示例。
  • 鉴别器学会区分生成器的虚假数据和真实数据。鉴别器会惩罚产生不合理结果的生成器。
    当训练开始时,生成器会生成明显是假的数据,而鉴别器很快就能分辨出这是假的。
    更多的阐述将在本系列文章中展现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/849529.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络面试基础(一)

文章目录 一、HTTP基本概念1.HTTP是什么?2.HTTP 常见的状态码有哪些?3.http常见字段 二、GET和POST1.get和post有什么区别 三、HTTP缓存技术1.HTTP 缓存有哪些实现方式?2.什么是强制缓存?3.什么是协商缓存?(不太懂) 四…

分治算法例子

分治算法概述 分治算法是一种将问题分解为更小的子问题来解决,然后将这些子问题的解合并起来得到原问题的解的算法。这些示例展示了分治算法如何将问题分解为更小的子问题,通过递归和合并来解决复杂问题。以下是两个常见的分治算法示例及其Python实现: 快速排序 (Quick So…

长文预警:九头蛇的进化——Tesla AutoPilot 纯视觉方案解析

九头蛇的进化:Tesla AutoPilot 纯视觉方案解析 前言 本文整理自原文链接,写的非常好,给了博主很多启发,投原创是因为平台机制,希望能被更多人看到。 嘿嘿,漫威粉不要打我←_←不是Hail Hydra&#xff0c…

分享:各种原理测厚仪的发展历程!

板材厚度的检测离不开测厚仪的应用,目前激光测厚仪、射线测厚仪、超声波测厚仪等都已被广泛的应用于板材生产线中,那你了解他们各自的发展历程吗? 激光测厚仪的发展: 激光测厚仪是随着激光技术和CCD(电荷耦合器件&…

swaggerHole:针对swaggerHub的公共API安全扫描工具

关于swaggerHole swaggerHole是一款针对swaggerHub的API安全扫描工具,该工具基于纯Python 3开发,可以帮助广大研究人员检索swaggerHub上公共API的相关敏感信息,整个任务过程均以自动化形式实现,且具备多线程特性和管道模式。 工具…

网络安全实验BUAA-全套实验报告打包

下面是部分BUAA网络安全实验✅的实验内容 : 认识路由器、交换机。掌握路由器配置的基本指令。掌握正确配置路由器的方法,使网络正常工作。 本博客包括网络安全课程所有的实验报告:内容详细,一次下载打包 实验1-路由器配置实验2-AP…

快速搭建高效运营体系,Xinstall App下载自动绑定助您一臂之力

在互联网的浪潮中,App的推广与运营面临着诸多挑战。如何在多变的互联网环境下迅速搭建起能时刻满足用户需求的运营体系,成为了众多企业关注的焦点。今天,我们就来聊聊如何通过Xinstall的App下载自动绑定功能,轻松解决App推广与运营…

PXE、无人值守实验

PXE部署 [roottest2 ~]# systemctl stop firewalld [roottest2 ~]# setenforce 0一、部署tftp服务 [roottest2 ~]# yum -y install tftp-server.x86_64 xinetd.x86_64 [roottest2 ~]# systemctl start tftp [roottest2 ~]# systemctl enable tftp [roottest2 ~]# systemctl …

因为宇宙一片漆黑,所以地球才有昼夜之分,宇宙为什么是黑的?

因为宇宙一片漆黑,所以地球才有昼夜之分,宇宙为什么是黑的? 地球为何会有昼夜之分? 乍一看,这个问题很是简单,当然是因为地球一直在自转了,当地球的一部分被太阳照射时就是白昼,而…

UI框架与MVC模式详解(1)——逻辑与数据分离

【效率最高的耦合方式】 以实际的例子来说明,更容易理解些。 这里从上到下,从左到右共有8个显示项,如果只需要显示这8个,不会做任何改变,数据固定,那么我们只需要最常规的思路去写就好,这是最…

【JSP】如何在IDEA上部署JSP WEB开发项目

以我的课设为例,教大家拿到他人的项目后,如何在IDEA上部署。 需要准备: JDK17(或者JDK13)IntelliJ IDEA 2023.2.6MySQL 8.0Tomcat 9.0 一,新建项目添加文件 1.1复制“位置”的路径 1.2找到该文件夹 1.3…

Python语言兼职:探索、挑战与机遇

Python语言兼职:探索、挑战与机遇 在数字化浪潮汹涌的今天,Python语言因其简洁易懂、功能强大的特点,成为了众多编程爱好者的首选。而兼职Python开发者这一职业,也逐渐成为了一种新兴的工作模式。本文将深入探讨Python语言兼职的…

linux嵌入式设备测试wifi信号强度方法

首先我们要清楚设备具体链接在哪个wifi热点上 执行:nmcli dev wifi list rootubuntu:/home/ubuntu# nmcli dev wifi list IN-USE BSSID SSID MODE CHAN RATE SIGNAL BARS > * 14:EB:08:51:7D:20 wifi22222_5G Infr…

米尔NXP i.MX 93开发板的Qt开发指南

1. 概述 Qt 是一个跨平台的图形应用开发框架,被应用在不同尺寸设备和平台上,同时提供不同版权版本供用户选择。米尔 NXP i.MX 93 开发板(MYD-LMX9X开发板)使用 Qt6.5 版本进行应用开发。在 Qt 应用开发中,推荐使用 Qt…

NSSCTF CRYPTO MISC题解(一)

陇剑杯 2021刷题记录_[陇剑杯 2021]签到-CSDN博客 [陇剑杯 2021]签到 下载附件压缩包,解压后得到 后缀为.pcpang,为流量包,流量分析,使用wireshark打开 {NSSCTF} [陇剑杯 2021]签到 详解-CSDN博客 选择统计里面的协议分级 发现流…

Vxe UI vxe-table 实现自定义列拖拽,列拖拽排序功能

Vxe UI vue vxe-table 实现自定义列拖拽&#xff0c;列拖拽排序功能 开启自定义列 vxe-toolbar 工具栏&#xff0c;通过 custom 启用后就可以开启自定义列功能 <template><div><vxe-toolbar ref"toolbarRef" custom></vxe-toolbar><vx…

【java基础】内部类

1、 非静态成员内部类可以访问所在类的全部方法和对象&#xff08;就相当于一个对象方法&#xff08;属于对象阶层和非静态方法同时加载在类加载之后&#xff09;&#xff09; 2、非静态成员内部类无法在该类&#xff08;就是非静态成员内部类所在的类&#xff09;的静态方法中…

人类的深度学习与机器的深度学习不同

人类的深度学习和机器的深度学习存在一些重要的区别&#xff1a; 人类的深度学习是自主进行的&#xff0c;无需明确编程。我们可以通过观察周围环境、与他人互动和实践来不断学习和适应&#xff0c;人类是具有意识和自我意识的智能体&#xff0c;能够理解和处理抽象概念&#x…

Python - 获取文件行数

看了很多教程&#xff0c;使用 readlines() 读取感觉效率比较低 这里我使用 Python 调用 wc -l 命令 PS : 这个命令存在一个小的可能的不一致是&#xff0c;如果文件最后一行没有换行符&#xff0c;则这一行不会被统计 import osfile_path /Users/user/Documents/happy.md…

一个程序员的牢狱生涯(55)改判

星期二 改判 我在‘赵老大’的号子里待到快到十点的时候,赵老大才送我回了自己的号子。 回到号子里后,我把‘赵老大’给我的烟又递给了头铺,头铺爽快地拿过去,从中抽出两根烟递给我,让我晚上站班的时候抽,并告诉我打火机就放在枕头边,到时候自己取就行。 晚上我和老杨还…