【GAN】简单的GAN模型搭建 -- 以线性模型和MNIST数据集为例子

文章目录

  • 确定损失函数
  • 生成器网络架构

不讲原理,从简单的代码一步步开始,学会怎么用、怎么设计损失函数即可。

确定损失函数

生成器的任务是生成足够以假乱真的数据,判别器的任务是分辨出哪些数据是真实的,哪些数据是假的。因此,对于判别器来讲,需要判别真伪,也就是true/false,从这个角度看,是个二分类问题。所以损失函数使用二类分类损失,即BCELoss。

import torch
import torch.nn as nnadversarial_loss = nn.BCELoss()

生成器网络架构

这里使用纯线性网络作为生成器,得到的输出为[batch_size, np.pord(28*28)]

import torch.nn as nn
import numpy as npclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_features, out_features, normalization=True):layers = [nn.Linear(in_features, out_features)]if normalization:layers.append(nn.BatchNorm1d(out_features, 0.8))layers.append(nn.LeakyReLU(0.2))return layersself.model = nn.Sequential(*block(100, 128, normalization=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod((1, 28, 28)))), # generate a photo size, but in line mode.nn.Tanh())def forward(self, x):return self.model(x)

判别器网络架构:
判别器的功能为判断出来哪一个是生成的图片,哪一个是真实的图片。对于生成的图片,我们希望判别器打上假的标签,对于真实的图片,我们希望判别器打上真的标签,因此,判别器的输出为一个数,即0或者1。

import torch.nn as nn
import numpy as np
class Disctiminator(nn.Module):def __init__(self):super(Disctiminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod((1, 28, 28))), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):img_flat = x.view(x.size(0), -1)validity = self.model(img_flat)return validity

训练流程:

载入数据 — 训练 (生成图片 — 损失 — 反向传播) — 测试(这里没有加测试代码,可以照着训练代码改一下)

损失函数:生成器损失函数和判别器损失函数,两个损失函数分别进行反向传播,即生成器损失函数优化生成器,判别器损失函数优化判别器。

import torch
import torch.nn as nn
import argparse
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset, dataset
from torchvision import datasets
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from models.generator import Generator
from models.dicsriminator import Disctiminatoros.makedirs('/home/sjr/gxj/study/data/mnist', exist_ok=True)dataloader = DataLoader(datasets.MNIST('/home/sjr/gxj/study/data/mnist',train=True, download=True,transform=transforms.Compose([transforms.Resize(28), transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])),batch_size=64, shuffle=True, num_workers=4)adversarial_loss = torch.nn.BCELoss()
generator = Generator()
discriminator = Disctiminator()device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():adversarial_loss.cuda(device)generator.cuda(device)discriminator.cuda(device)optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensorfor epoch in range(60):for i, (imgs, _) in enumerate(dataloader):valid = Tensor(imgs.size(0), 1).fill_(1.0)fake = Tensor(imgs.size(0), 1).fill_(0.0)real_imgs = Variable(imgs.type(Tensor))optimizer_G.zero_grad()z = Tensor(np.random.normal(0, 1, (imgs.shape[0], 100)))gen_imgs = generator(z)g_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()optimizer_D.zero_grad()real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()#print(f"[Epoch {epoch}/{200}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")#if (epoch + 1) % 20 == 0:save_image(gen_imgs.data[:25], f'images/{epoch+1}.png', nrow=5, normalize=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/6913.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【刷题】代码随想录算法训练营第二十九天|491、递增子序列,46、全排列,47、全排列II

目录 491、递增子序列46、全排列47、全排列II 491、递增子序列 讲解&#xff1a;https://programmercarl.com/0491.%E9%80%92%E5%A2%9E%E5%AD%90%E5%BA%8F%E5%88%97.html class Solution { private:vector<vector<int>> result;vector<int> path;void backt…

pandas读取文件导致jupyter内核崩溃如何解决

读取execl文件出现以下问题: str_name "D:\\cao_use\\2017_2021(new).xlsx" train_df pd.read_excel(str_name, usecols[0])崩溃的指示图如下所示: bug原因:读入的文件太大&#xff0c;所需时间过长&#xff0c;在读取的过程中&#xff0c;使用中断按钮暂停会直…

超级好用的C++实用库之动态库加载器

概述 在C中&#xff0c;动态库也称为共享库或DLL&#xff0c;是一种可执行文件形式&#xff0c;其中包含可以被多个应用程序同时加载并使用的函数和数据。相较于静态库&#xff0c;动态库在运行时而不是编译链接阶段被程序所使用。加载动态库&#xff0c;在Windows和Linux操作系…

OpenAI神秘模型,再次被Sam Altman提及

5月6日&#xff0c;OpenAI首席执行官Sam Altman在社交平台分享了一条推文“我是一个优秀的GPT-2聊天机器人”。 而在4月30日&#xff0c;Altman就提起过该模型非常喜欢GPT-2。按道理说一个只有15亿参数在2019年发布的开源模型&#xff0c;被反复提及两次就很不寻常。 更意外的…

Yarn 的安装和使用指南

Yarn 的安装和使用指南 Yarn 是一个快速、可靠、安全的 JavaScript 依赖管理工具&#xff0c;它可以帮助开发人员更高效地管理项目的依赖关系。本文将介绍如何安装 Yarn 并展示一些常用的 Yarn 命令和用法。 安装 Yarn 使用 npm 安装 Yarn 在安装 Yarn 之前&#xff0c;首先…

volatile原理

文章目录 如何保证可见性如何保证有序性double-checked locking 问题double-checked locking 解决 volatile 的底层实现原理是内存屏障&#xff0c;Memory Barrier&#xff08;Memory Fence&#xff09; 对 volatile 变量的写指令后会加入写屏障对 volatile 变量的读指令前会加…

正则表达式_字符匹配/可选字符集

正则表达式&#xff08;Regular Expression&#xff09;也叫匹配模式(Pattern)&#xff0c;用来检验字符串是否满足特 定规则&#xff0c;或从字符串中捕获满足特定规则的子串。 字符匹配 最简单的正则表达式由“普通字符”和“通配符”组成。比如“Room\d\d\d”就这样 的正则…

短网址短链接哪个好用?2024年最好的缩短链接短网址推荐

短网址&#xff0c;又称短链接&#xff0c;英文名为Short URL&#xff0c;是一种形式上比较短的网址&#xff0c;使用跳转到方式代替长网址链接&#xff0c;形式美观&#xff0c;而且更容易分享。最出名的短网址服务有国外的bit.ly和谷歌goo.gl&#xff0c;以及国内的百度短网址…

thinkphp5.1 新建模块

thinkphp5.1 新建模块 在ThinkPHP5.1中&#xff0c;创建一个新模块的步骤如下&#xff1a;使用命令行工具创建模块目录结构。 在模块目录中创建相应的文件和目录。 以下是具体的操作步骤和示例代码&#xff1a; 1. 使用命令行工具进入到项目的根目录下&#xff0c;执行以下…

AI+客服行业落地应用

一、客服行业变迁 1.传统客服时代 &#xff08;1&#xff09;客服工作重复性高&#xff0c;技术含量低 &#xff08;2&#xff09;呼出效率低&#xff0c;客服水平参差不齐 &#xff08;3&#xff09;管理难度高&#xff0c;情绪不稳定 &#xff08;4&#xff09;服务质量…

《视觉十四讲》例程运行记录(1)—— 课本源码下载和3rdparty文件夹是空的解决办法

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、第二版十四讲课本源码下载1. 安装git工具 二、Pangolin下载和安装1. 源码下载2. Pangolin的安装(1) 安装依赖项(2) 源码编译安装(2) 测试是否安装成功 二、…

4:分配器测试

文章目录 分配器作用容器中默认的分配器分配器测试程序这节课并没有总结各种分配器的使用结果 分配器作用 负责分配和管理容器的空间的 不需要用户手动创建 容器中默认的分配器 第二个参数表示默认的分配器 每一个容器初始化的时候 带一个默认的分配器 分配器测试程序 右边的…

商城数据库88张表结构完整示意图61~70(十四)

六十一&#xff1a; 六十二&#xff1a; 六十三&#xff1a; 六十四&#xff1a; 六十五&#xff1a; 六十六&#xff1a; 六十七&#xff1a; 六十八&#xff1a; 六十九&#xff1a; 七十&#xff1a;

深度学习之基于YOLOv5的山羊行为识别系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 深度学习之基于YOLOv5的山羊行为识别系统是一个创新的项目&#xff0c;旨在通过深度学习和目标检测技术&#xff0c…

【数据结构(邓俊辉)学习笔记】列表04——排序器

文章目录 0. 统一入口1. 选择排序1.1 构思1.2 实例1.3 实现1.4 复杂度 2. 插入排序2.1 构思2.2 实例2.3 实现2.4 复杂度分析2.5 性能分析 3. 归并排序3.1 二路归并算法3.1.1 二路归并算法原理3.1.2 二路归并算法实现3.1.3 归并时间 3.2 分治策略3.2.1 实现3.2.2 排序时间 4. 总…

【Java】基本程序设计结构(二)

前言&#xff1a;上一篇我们详细介绍了Java基本程序设计结构中前半部分&#xff0c;一个简单的Java应用&#xff0c;注释&#xff0c;数据类型&#xff0c;变量与常量&#xff0c;运算符&#xff0c;字符串。包括本篇将延续上篇内容介绍后续内容&#xff0c;包括输入输出&#…

正则表达式之python中re模块的使用以及一些习题

正则表达式 正则表达式是一种用来描述字符串模式的方法。它是一种强大的工具&#xff0c;用于在文本中搜索、匹配和编辑特定模式的字符串。正则表达式可以用来验证输入是否符合某种模式&#xff0c;提取文本中的特定信息&#xff0c;以及进行文本的替换和分割等操作。在计算机…

AutoTable, Hibernate自动建立表替代方案

痛点 之前一直使用JPA为主要ORM技术栈&#xff0c;主要是因为Mybatis没有实体逆向建表功能。虽然Mybatis有从数据库建立实体&#xff0c;但是实际应用却没那么美好&#xff1a;当实体变更时&#xff0c;往往不会单独再建立一个数据库重新生成表&#xff0c;然后把表再逆向为实…

python关键字(break)

7、break 深入理解Python 3.8中的break关键字 在Python编程中&#xff0c;break是一个控制流语句&#xff0c;用于立即退出最内层的循环。它对于需要中断循环并在满足特定条件时继续执行的程序非常有用。本文将带您从基础到进阶&#xff0c;深入了解break在Python 3.8中的用法…