5 分钟内构建一个简单的基于 Python 的 GAN

文章目录

  • 一、说明
  • 二、代码
  • 三、训练
  • 四、后记

一、说明

生成对抗网络(GAN)因其能力而在学术界引起轩然大波。机器能够创作出新颖、富有灵感的作品,这让每个人都感到敬畏和恐惧。因此,人们开始好奇,如何构建一个这样的网络?

生成对抗网络 (GAN) 是一种深度学习模型,可生成与某些输入数据相似的新合成数据。GAN 由两个神经网络组成:生成器和鉴别器。生成器经过训练可生成与输入数据相同的合成数据,而鉴别器经过训练可区分合成数据和真实数据。

生成模型学习输入数据 f (x)的内在分布函数,使其能够生成合成输入x’和输出y’,通常给定一些隐藏参数。GAN 的优势在于它们能够生成最清晰的图像,并且易于训练。

二、代码

此代码会训练 GAN 一定数量的周期,其中周期定义为对整个数据集的一次遍历。在每个周期中,代码会迭代数据加载器(应该是包装数据集的 PyTorch DataLoader 对象)中的数据,并在每个批次上训练鉴别器和生成器。

在这里插入图片描述

生成器的训练方式是试图欺骗鉴别器,而鉴别器则被训练来区分真实图像和假图像。这里使用的损失函数是二元交叉熵损失,这是 GAN 的常见选择。使用的优化器是 Adam,它是一种随机梯度下降优化器。

首先,导入必要的库并定义生成器和鉴别器模型。

import torch
import torch.nn as nn
import torch.optim as optim

生成器应该是一个神经网络,它接受随机噪声向量并生成合成数据。同时,鉴别器应该是一个神经网络,它接受真实数据或合成数据并输出输入数据为真实的概率。
类 生成器(nn.Module):

class Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Generator, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.tanh(self.fc2(x))return x
class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Discriminator, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.sigmoid(self.fc2(x))return x
  1. 在下面的代码块中,我们设置了 GAN 的环境。这包括:

设置鉴别器和生成器网络的输入层、隐藏层和输出层的大小。
创建 Generator 和 Discriminator 类的实例
设置损失函数和优化器

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Set the input and output sizes
input_size = 784
hidden_size = 256
output_size = 1# Create the discriminator and generator
discriminator = Discriminator(input_size, hidden_size, output_size).to(device)
generator = Generator(input_size, hidden_size, output_size).to(device)# Set the loss function and optimizers
loss_fn = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)# Set the number of epochs and the noise size
num_epochs = 200
noise_size = 100# Training loop
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# Get the batch sizebatch_size = real_images.size(0)

三、训练

  1. 在下面的代码中,生成器通过尝试欺骗鉴别器来训练,而鉴别器经过训练可以区分真假图像。为此,

我们给生成器一批噪声样本作为输入,并生成一批假图像。然后这些假图像通过鉴别器,鉴别器对批次中的每幅图像产生预测。
然后计算生成器的损失,代码通过生成器反向传播损失,并使用 Adam 优化器优化生成器的参数。此过程会以减少损失和提高生成器欺骗鉴别器的能力的方向更新生成器的参数。

 # Generate fake imagesnoise = torch.randn(batch_size, noise_size).to(device)fake_images = generator(noise)# Train the discriminator on real and fake imagesd_real = discriminator(real_images)d_fake = discriminator(fake_images)# Calculate the lossreal_loss = loss_fn(d_real, torch.ones_like(d_real))fake_loss = loss_fn(d_fake, torch.zeros_like(d_fake))d_loss = real_loss + fake_loss# Backpropagate and optimized_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# Train the generatord_fake = discriminator(fake_images)g_loss = loss_fn(d_fake, torch.ones_like(d_fake))# Backpropagate and optimizeg_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# Print the loss every 50 batchesif (i+1) % 50 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))

就这样……一个可以快速使用的 GAN 模型就完成了。

四、后记

关于成对抗网络(GAN)由两部分组成:

  • 生成器学习生成可信的数据。生成的实例将成为鉴别器的反面训练示例。
  • 鉴别器学会区分生成器的虚假数据和真实数据。鉴别器会惩罚产生不合理结果的生成器。
    当训练开始时,生成器会生成明显是假的数据,而鉴别器很快就能分辨出这是假的。
    更多的阐述将在本系列文章中展现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/849529.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络面试基础(一)

文章目录 一、HTTP基本概念1.HTTP是什么?2.HTTP 常见的状态码有哪些?3.http常见字段 二、GET和POST1.get和post有什么区别 三、HTTP缓存技术1.HTTP 缓存有哪些实现方式?2.什么是强制缓存?3.什么是协商缓存?(不太懂) 四…

长文预警:九头蛇的进化——Tesla AutoPilot 纯视觉方案解析

九头蛇的进化:Tesla AutoPilot 纯视觉方案解析 前言 本文整理自原文链接,写的非常好,给了博主很多启发,投原创是因为平台机制,希望能被更多人看到。 嘿嘿,漫威粉不要打我←_←不是Hail Hydra&#xff0c…

分享:各种原理测厚仪的发展历程!

板材厚度的检测离不开测厚仪的应用,目前激光测厚仪、射线测厚仪、超声波测厚仪等都已被广泛的应用于板材生产线中,那你了解他们各自的发展历程吗? 激光测厚仪的发展: 激光测厚仪是随着激光技术和CCD(电荷耦合器件&…

swaggerHole:针对swaggerHub的公共API安全扫描工具

关于swaggerHole swaggerHole是一款针对swaggerHub的API安全扫描工具,该工具基于纯Python 3开发,可以帮助广大研究人员检索swaggerHub上公共API的相关敏感信息,整个任务过程均以自动化形式实现,且具备多线程特性和管道模式。 工具…

网络安全实验BUAA-全套实验报告打包

下面是部分BUAA网络安全实验✅的实验内容 : 认识路由器、交换机。掌握路由器配置的基本指令。掌握正确配置路由器的方法,使网络正常工作。 本博客包括网络安全课程所有的实验报告:内容详细,一次下载打包 实验1-路由器配置实验2-AP…

快速搭建高效运营体系,Xinstall App下载自动绑定助您一臂之力

在互联网的浪潮中,App的推广与运营面临着诸多挑战。如何在多变的互联网环境下迅速搭建起能时刻满足用户需求的运营体系,成为了众多企业关注的焦点。今天,我们就来聊聊如何通过Xinstall的App下载自动绑定功能,轻松解决App推广与运营…

PXE、无人值守实验

PXE部署 [roottest2 ~]# systemctl stop firewalld [roottest2 ~]# setenforce 0一、部署tftp服务 [roottest2 ~]# yum -y install tftp-server.x86_64 xinetd.x86_64 [roottest2 ~]# systemctl start tftp [roottest2 ~]# systemctl enable tftp [roottest2 ~]# systemctl …

因为宇宙一片漆黑,所以地球才有昼夜之分,宇宙为什么是黑的?

因为宇宙一片漆黑,所以地球才有昼夜之分,宇宙为什么是黑的? 地球为何会有昼夜之分? 乍一看,这个问题很是简单,当然是因为地球一直在自转了,当地球的一部分被太阳照射时就是白昼,而…

UI框架与MVC模式详解(1)——逻辑与数据分离

【效率最高的耦合方式】 以实际的例子来说明,更容易理解些。 这里从上到下,从左到右共有8个显示项,如果只需要显示这8个,不会做任何改变,数据固定,那么我们只需要最常规的思路去写就好,这是最…

【JSP】如何在IDEA上部署JSP WEB开发项目

以我的课设为例,教大家拿到他人的项目后,如何在IDEA上部署。 需要准备: JDK17(或者JDK13)IntelliJ IDEA 2023.2.6MySQL 8.0Tomcat 9.0 一,新建项目添加文件 1.1复制“位置”的路径 1.2找到该文件夹 1.3…

linux嵌入式设备测试wifi信号强度方法

首先我们要清楚设备具体链接在哪个wifi热点上 执行:nmcli dev wifi list rootubuntu:/home/ubuntu# nmcli dev wifi list IN-USE BSSID SSID MODE CHAN RATE SIGNAL BARS > * 14:EB:08:51:7D:20 wifi22222_5G Infr…

米尔NXP i.MX 93开发板的Qt开发指南

1. 概述 Qt 是一个跨平台的图形应用开发框架,被应用在不同尺寸设备和平台上,同时提供不同版权版本供用户选择。米尔 NXP i.MX 93 开发板(MYD-LMX9X开发板)使用 Qt6.5 版本进行应用开发。在 Qt 应用开发中,推荐使用 Qt…

NSSCTF CRYPTO MISC题解(一)

陇剑杯 2021刷题记录_[陇剑杯 2021]签到-CSDN博客 [陇剑杯 2021]签到 下载附件压缩包,解压后得到 后缀为.pcpang,为流量包,流量分析,使用wireshark打开 {NSSCTF} [陇剑杯 2021]签到 详解-CSDN博客 选择统计里面的协议分级 发现流…

Vxe UI vxe-table 实现自定义列拖拽,列拖拽排序功能

Vxe UI vue vxe-table 实现自定义列拖拽&#xff0c;列拖拽排序功能 开启自定义列 vxe-toolbar 工具栏&#xff0c;通过 custom 启用后就可以开启自定义列功能 <template><div><vxe-toolbar ref"toolbarRef" custom></vxe-toolbar><vx…

【java基础】内部类

1、 非静态成员内部类可以访问所在类的全部方法和对象&#xff08;就相当于一个对象方法&#xff08;属于对象阶层和非静态方法同时加载在类加载之后&#xff09;&#xff09; 2、非静态成员内部类无法在该类&#xff08;就是非静态成员内部类所在的类&#xff09;的静态方法中…

MS1112驱动开发

作者简介&#xff1a; 一个平凡而乐于分享的小比特&#xff0c;中南民族大学通信工程专业研究生在读&#xff0c;研究方向无线联邦学习 擅长领域&#xff1a;驱动开发&#xff0c;嵌入式软件开发&#xff0c;BSP开发 作者主页&#xff1a;一个平凡而乐于分享的小比特的个人主页…

java版B/S架构UWB人员定位系统源码spring boot+vue技术架构uwb定位装置-工业级UWB室内定位系统源码

java版B/S架构UWB人员定位系统源码spring bootvue技术架构uwb定位装置-工业级UWB室内定位系统源码 本套系统运用UWB定位技术&#xff0c;开发的高精度人员定位系统&#xff0c;通过独特的射频处理&#xff0c;配合先进的位置算法&#xff0c;可以有效计算复杂环境下的人员与物…

自动驾驶仿真(高速道路)LaneKeeping

前言 A high-level decision agent trained by deep reinforcement learning (DRL) performs quantitative interpretation of behavioral planning performed in an autonomous driving (AD) highway simulation. The framework relies on the calculation of SHAP values an…

流批一体计算引擎-10-[Flink]中的常用算子和DataStream转换

pyflink 处理 kafka数据 1 DataStream API 示例代码 从非空集合中读取数据&#xff0c;并将结果写入本地文件系统。 from pyflink.common.serialization import Encoder from pyflink.common.typeinfo import Types from pyflink.datastream import StreamExecutionEnviron…

[网鼎杯 2020 青龙组]jocker

运行程序,发现是要我们自己输入 那么肯定是拿到enc慢慢还原 32位,无壳 进来就红一下报错 这里可以看见长度为24 动调一下看看 这里进行了大量的异或 这里是对地址开始的硬编码进行异或,从而达到smc的效果 所以你也可以发现在进行这一步操作之前 encry函数全是报错 你点开…