Pytorch Advanced(一) Generative Adversarial Networks

生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了

参考

1、AI作家
2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节;
3、进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。

那到底是怎么实现的呢?


GAN中有两大组成部分G和D

G是generator,生成器: 负责凭空捏造数据出来

D是discriminator,判别器: 负责判断数据是不是真数据

示例图如下:

给一个随机噪声z,通过G生成一张假图,然后用D去分辨是真图还是假图。假设G生成了一张图,在D那里的得分很高,那么G就很成功的骗过了D,如果D很轻松的分辨出了假图,那么G的效果不好,那么就需要调整参数了。


G和D是两个单独的网络,那么他们的参数都是训练好的吗?并不是,两个网络的参数是需要在博弈的过程中分别优化的。

下面就是一个训练的过程:

GAN在一轮反向传播中分为两步,先训练D在训练G。

训练D时,上一轮G产生的图片,和真实图片一起作为x进行输入,假图为0,真图标签为1,通过x生成一个score,通过score和标签y计算损失,就可以进行反向传播了。

训练G时,G和D是一个整体,取名为D_on_G。输入随机噪声,G产生一个假图,D去分辨,score = 1就是需要我们需要优化的目标,意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的,我们通过反向传播来优化G,让他生成更加真实的图片。这就好比:如果你参加考试,你别指望能改变老师的评分标准


GAN无监督学习,(cGAN是有监督的),以后会学习的。怎么理解无监督学习呢?这里给的真图是没有经过人工标注的,只知道这是真的,D是不知道这是什么的,只需要分辨真假。G也不知道生成了什么,只需要学真图去骗D。


具体如何实施呢?

import os
import torch
import torchvision
import torch.nn as nn 
from torchvision import transforms
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

注意这里有个归一化的过程,MNIST是单通道,但是如果mean=(0.5,0.5,0.5)会报错,因为是对3通道操作 。

if not os.path.exists(sample_dir):os.makedirs(sample_dir)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,),   # 3 for RGB channelsstd=(0.5,))])# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)

定义生成器和判别器:

生成器:可以看到输入的维度为64,是一组噪声图像,通过生成器将特征扩大到了MNIST图像大小784。

判别器:输入维度为图像大小,最后输出特征个数为1,采用sigmoid激活(不用softmax的)

# Discriminator
D = nn.Sequential(nn.Linear(image_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, 1),nn.Sigmoid())# Generator 
G = nn.Sequential(nn.Linear(latent_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, image_size),nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)def denorm(x):out = (x + 1) / 2return out.clamp(0, 1)def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()

 重点看训练部分,我们到底是如何来训练GAN的。

判别器部分:判别器的损失值分为两部分,(一)将mini_batch定义为正样本,告诉他我是正品,所以设置标签为1。优化判别器判断正品的能力;(二)生成一幅赝品,再给判别器判别,这时候赝品的标签为0,优化判断赝品的能力。所以总损失为这两部分之和,计算梯度,优化判别器参数。

G_on_D:输入一个噪声,让生成器生成一幅图像,然后让D去判别,计算和正品之间的距离,即损失。反向传播,优化G的参数。

# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):for i, (images, _) in enumerate(data_loader):images = images.reshape(batch_size, -1).to(device)# Create the labels which are later used as input for the BCE lossreal_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ================================================================== ##                      Train the discriminator                       ## ================================================================== ## Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# Second term of the loss is always zero since real_labels == 1outputs = D(images)d_loss_real = criterion(outputs, real_labels)real_score = outputs# Compute BCELoss using fake images# First term of the loss is always zero since fake_labels == 0z = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs, fake_labels)fake_score = outputs# Backprop and optimized_loss = d_loss_real + d_loss_fakereset_grad()d_loss.backward()d_optimizer.step()# ================================================================== ##                        Train the generator                         ## ================================================================== ## Compute loss with fake imagesz = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdfg_loss = criterion(outputs, real_labels)# Backprop and optimizereset_grad()g_loss.backward()g_optimizer.step()if (i+1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))# Save real imagesif (epoch+1) == 1:images = images.reshape(images.size(0), 1, 28, 28)save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))# Save sampled imagesfake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

训练完了怎么用?

只要用我们的生成器就可以随意生成了。

import matplotlib.pyplot as plt
z = torch.randn(1,latent_size).to(device)
output = G(z)
plt.imshow(output.cpu().data.numpy().reshape(28,28),cmap='gray') 
plt.show()

 下面就是随机生成的图像了!

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/76495.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

centos 下 Makefile 独立模块编译ko

1、安装编译内核环境包 编译需要用到kernel 源码,centos 下需先安装 kernel-devel 包,要下与自己kernel 对应版本 yum install kernel-devel 2、首先从内核或自己写的模块,发到编译环境中 注:就像我自己拷贝一个 bcache 驱动的目…

从零开始在树莓派上搭建WordPress博客网站并实现公网访问

文章目录 序幕概述1. 安装 PHP2. 安装MySQL数据库3. 安装 Wordpress4. 设置您的 WordPress 数据库设置 MySQL/MariaDB创建 WordPress 数据库 5. WordPress configuration6. 将WordPress站点发布到公网安装相对URL插件修改config.php配置 7. 支持好友链接样式8. 定制主题 序幕 …

时序预测 | MATLAB实现LSSVM最小二乘支持向量机时间序列预测未来

时序预测 | MATLAB实现LSSVM最小二乘支持向量机时间序列预测未来 目录 时序预测 | MATLAB实现LSSVM最小二乘支持向量机时间序列预测未来预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现LSSVM时间序列预测未来(最小二乘支持向量机); 2.运行环境Mat…

LabVIEW更改Tab所选标签的颜色

LabVIEW更改Tab所选标签的颜色 在开发过程中,有时会出现要将不同tab页设置不同颜色的情况。此VI允许编程方式更改前面板选项卡控件上选项卡的颜色。它是突出显示所选选项卡的理想选择 在某些应用程序中,用户希望在按下时突出显示选项卡控件。此VI使用事…

借助ChatGPT使用Pandas实现Excel数据汇总

一、问题的提出 现在有如下一个Excel表: 上述Excel表中8万多条数据,记录的都是三年以来花菜类的销量,现在要求按月汇总实现统计每个月花菜类的销量总和,如果使用Python的话要给出代码。 二、问题的解决 1.首先可以用透视表的方…

idea配置git(gitee)并提交(commit)推送(push)

Intellij Idea VCS | 版本控制 - 知乎 IDEA项目上传到gitee仓库_idea上传代码到gitee_robin19712的博客-CSDN博客 git程序下载国内镜像地址: https://registry.npmmirror.com/binary.html?pathgit-for-windows/v2.42.0.windows.2/ 解压后放到固定路径&#xff1a…

ADW300物联网电表支持MODBUSTCP协议、MQTT协议-安科瑞黄安南

摘要 随着通信技术的应用越来越广泛,具有通信功能的电子产品越来越多,同时也随着Wi-Fi无线覆盖网络区域的形成,如何利用无线网络覆盖广、带宽高、低使用费率的优势组建物联网系统,变成了一个很实际的问题。 安科瑞也紧跟趋势推出…

antv-G6知识图谱安装--使用(实例)--连接线修改成动态,并添加跟随线移动的光圈,设置分支跟踪定位功能

这系列文章主要是完成一个图谱的自定义修改(最近太忙了长篇分段更新自己使用流程) 1. 连接线修改成动态,并添加跟随线移动的光圈 2. 自定义卡片样式和文字内容 3. 自定义伸缩节点的样式,并添加动画样式 3. 自定义弹窗样式 4. 自定…

lvs负载均衡、LVS集群部署

四:LVS集群部署 lvs给nginx做负载均衡项目 218lvs(DR 负载均衡器) yum -y install ipvsadm(安装这个工具来管理lvs) 设置VIP192.168.142.120 创建ipvsadm的文件用来存放lvs的规则 定义策略 ipvsadm -C //清空现有…

自己设计CPU学习之路——基于《Xilinx FPGA应用开发》

1. 一个32组位宽为32的寄存器堆 框图 代码 regfile.h ifndef __FEGFILE_HEADER__define __REGFILE_HEADER__define HIGH 1b1define LOW 1b0define ENABLE_ 1b0define DISABLE_ 1b1define DATA_W 32define DataBus 31:0define DATA_D 32d…

正中优配:月线macd指标参数设置?

随着投资者长期持有股票的越来越受欢迎,月线MACD目标已成为识别趋势和交易信号的重要东西。但是,许多投资者在设置MACD目标参数时仍然感到困惑。本文将从多个视点剖析,为您解答月线MACD目标参数设置的问题。 什么是MACD目标? MAC…

vue中v-for循环数组使用方法中splice删除数组元素(错误:每次都删掉点击的下面的一项)

总结:平常使用v-for的key都是使用index,这里vue官方文档也不推荐,这个时候就出问题了,我们需要key为唯一标识,这里我使用了时间戳(new Date().getTime())处理比较复杂的情况, 本文章…

Spring-Cloud-Openfeign如何做熔断降级?

微服务系统中为了防止服务雪崩问题&#xff0c;服务之间相互调用的时候一般需要开启熔断与降级&#xff0c;下面就来看下feign如何集成hystrix来做熔断与降级。 依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-c…

Full authentication is required to access this resource解决办法

我们在使用postman调接口时候&#xff0c;有的时候需要权限才可以访问&#xff0c;否则可能会报下面这个错误 {"success": false,"message": "Full authentication is required to access this resource","code": 401,"result&q…

浏览器缓存机制及其分类

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 强缓存&#xff08;Cache-Control 和 Expires&#xff09;⭐ 协商缓存&#xff08;ETag 和 Last-Modified&#xff09;⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几…

四川百幕晟科技:抖音新店怎么快速起店?

抖音作为全球最大的短视频平台&#xff0c;拥有庞大的用户基础和强大的影响力&#xff0c;成为众多商家宣传产品、增加销量的理想选择。那么&#xff0c;如何快速开店并成功运营呢&#xff1f;下面描述了一些关键步骤。 1、如何快速开新店&#xff1f; 1、确定产品定位&#x…

JVM GC垃圾回收

一、GC垃圾回收算法 标记-清除算法 算法分为“标记”和“清除”阶段&#xff1a;标记存活的对象&#xff0c; 统一回收所有未被标记的对象(一般选择这种)&#xff1b;也可以反过来&#xff0c;标记出所有需要回收的对象&#xff0c;在标记完成后统一回收所有被标记的对象 。它…

geopandas笔记:汇总连接两个区域的边

比如这样的两个区域&#xff0c;我们想知道从蓝到绿、从绿到蓝都有哪些边 1 读取openstreetmap import osmnx as ox import geopandas as gpdGox.graph_from_place(Singapore,simplifyTrue,network_typedrive)ox.plot_graph(G) 2 得到对应的边的信息 nodes,edgesox.graph_to_…

正中优配:股票出现xd是好还是坏?

近年来&#xff0c;股票市场的日渐成熟和开展使得出资者们关于股票价格的涨跌也愈加灵敏&#xff0c;特别是股票呈现XD之后&#xff0c;更是引起了一系列热议。那么&#xff0c;股票呈现XD是好还是坏&#xff1f;本文将从多个角度进行剖析。 首要&#xff0c;需要清晰XD的定义…

2023年9月制造业NPDP产品经理国际认证报名来这错不了

产品经理国际资格认证NPDP是新产品开发方面的认证&#xff0c;集理论、方法与实践为一体的全方位的知识体系&#xff0c;为公司组织层级进行规划、决策、执行提供良好的方法体系支撑。 【认证机构】 产品开发与管理协会&#xff08;PDMA&#xff09;成立于1979年&#xff0c;是…