【人工智能基础】GAN与WGAN实验

一、GAN网络概述

GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。

Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)

Discriminator功能是判别一张图片的真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率,如果为1就代表图片真实,而输出为0,就代表图片不真实。

在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator

Discriminator的目标就是尽量把Generator生成的图片和真实的图片分别开来

二、GAN实验环境准备

除了之前使用过的pytorch-nplnumpy以外,我们还需要安装visdom

pip install visdom

启动visdom

python -m visdom.server

visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom

visdom启动.png

三、GAN网络实验

环境参数配置

import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import randomh_dim = 400
batchsz = 512
viz = visdom.Visdom()

生成网络定义

class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.net = nn.Sequential(# input[b, 2]nn.Linear(2,h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2)# output[b,2])def forward(self, z):output = self.net(z)return output

判别网络定义

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 1),nn.Sigmoid())def forward(self, x):output = self.net(x)return output.view(-1)

数据集生成函数

def data_generator():# 生成中心点scale = 2centers = [(1, 0),(-1, 0),(0, 1),(0, -1),(1. / np.sqrt(2), 1. / np.sqrt(2)),(1. / np.sqrt(2), -1. / np.sqrt(2)),(-1. / np.sqrt(2), 1. / np.sqrt(2)),(-1. / np.sqrt(2), -1. / np.sqrt(2))]centers = [(scale * x, scale * y) for x,y in centers] while True:dataset = []for i in range(batchsz):point = np.random.randn(2) * 0.02# 随机选取一个中心点center = random.choice(centers)# 把刚刚随机到的高斯分布点根据center进行移动point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset).astype(np.float32)dataset /= 1.414yield dataset

可视化函数

将图片生成到visdom

import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1,2))with torch.no_grad():points = torch.Tensor(points).cpu()disc_map = D(points).cpu().numpy()x = y = np.linspace(-RANGE,RANGE,N_POINTS)cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())plt.clabel(cs, inline=1,fontsize=10)with torch.no_grad():z = torch.randn(batchsz, 2).cpu()samples = G(z).cpu().numpy()plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))

运行函数

def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()loss_D = lossr + lossf# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)

执行(GAN的不稳定性)

run()

从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点

gan运行.png

四、wgan实验

WGAN主要从损失函数的角度对GAN做了改进,对更新后的权重强制截断到一定范围内

增加一个梯度惩罚函数

def gradient_penalty(D,xr,xf):# [b,1]t = torch.rand(batchsz, 1).cpu()# 扩展为[b, 2]t = t.expand_as(xr)# 插值mid = t * xr + (1 - t) * xf# 设置需要的倒数信息mid.requires_grad_()pred = D(mid)grads = autograd.grad(outputs=pred, inputs=mid,grad_outputs=torch.ones_like(pred),create_graph=True,retain_graph=True,only_inputs=True)[0]gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()return gp

修改运行函数

def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()# 梯度惩罚值gp = gradient_penalty(D,xr,xf.detach())loss_D = lossr + lossf + 0.2 * gp# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)

执行

run()

可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近

wgan运行.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/836279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里开源编程大模型 CodeQwen1.5:64K92编程语言,Code和SQL编程,评测接近GPT-4-Turbo

前言 阿里巴巴最近发布的CodeQwen1.5模型标志着其在编程语言模型领域的一次重大突破。这款开源模型不仅支持高达92种编程语言和64K的上下文长度,而且在多项性能评测中显示出接近或超过当前行业领导者GPT-4-Turbo的能力。 Huggingface模型下载:https://h…

Boost库的使用

1 下载与安装 1.1 下载 网址:Boost C Libraries 进入后选择自己需要的版本安装即可 1.2 安装 1.2.1 解压 1.2.2 编译安装 双击bootstrap.bat 这一步完成后会生成一个b2.exe文件 双击b2.exe文件运行(此步需要花费较长的时间) 之后再stag…

双向链表(双向带头循环)的增删查改的实现(简单易懂)

一:双向链表的概念 每个节点除开存有数据,还有一个指针指向前一个节点,一个指针指向后一个节点,尾节点和哨兵位互相指向,从而形成一个循环。 二:双向链表的实现第一点: 本文采用三个文件进行实…

GIS数据—1984-2020中国1km人造夜间灯光观测数据

夜间灯光观测数据(Nighttime Light,NTL)是评估人类活动边界的常用手段,目前,该数据已经广泛应用于城市范围、不透水面、基础设施建设等一系列过程。今天,小编要带来的是长时间序列中国区域边界的夜间灯光观测数据。 数…

springcloud -nacos实战

一、nacos 功能简介 1.1.什么是Nacos? 官方简介:一个更易于构建云原生应用的动态服务发现(Nacos Discovery )、服务配置(Nacos Config)和服务管理平台。 Nacos的关键特性包括: 服务发现和服务健康监测动态配置服务动态DNS服务服务及其元数…

VMware配置Kali linux + 物理机连接Xshell

VMware 配置 kali linux 首先需要先安装VMware Workstation 我是在Windows 安装的 VMware Workstation Pro 17 虚拟化,产品密钥。。这里不做多说了 下载kali linux 这里我下载的是kali-linux-2024.1 Note:这里选Virtual Machines,建议不要…

景源畅信:抖音小店的商品怎么同步到橱窗?

在数字营销的海洋中,抖音小店与橱窗的同步操作无疑是商家们关注的焦点。这不仅能增加商品的曝光度,还能提高交易的可能性。那么,如何将抖音小店的商品同步到橱窗呢? 一、核心步骤解析 要实现商品从抖音小店同步到橱窗,你需要确保…

【Linux 网络】网络编程套接字 -- 详解

⚪ 预备知识 1、理解源 IP 地址和目的 IP 地址 举例理解:(唐僧西天取经) 在 IP 数据包头部中 有两个 IP 地址, 分别叫做源 IP 地址 和目的 IP 地址。 如果我们的台式机或者笔记本没有 IP 地址就无法上网,而因为…

Unity引擎是什么?有哪些优点

大家好,我是咕噜土豆,很高兴又和大家见面了。今天我们一起来了解一下Unity引擎和它有哪些优点。 首先带大家了解什么是Unity引擎 Unity引擎是一款由Unity Technologies开发的跨平台游戏开发引擎,广泛用于创建2D和3D游戏以及其他交互式内容&…

C++动态内存区域划分、new、delete关键字

目录 一、C/C中程序的内存区域划分 为什么会存在内存区域划分? 二、new关键字 1、内置类型的new/delete使用方法: 2、new和delete的本质 一、C/C中程序的内存区域划分 为什么会存在内存区域划分? 因为不同数据有不同的存储需求&#xff0…

【SpringBoot记录】从基本使用案例入手了解SpringBoot-数据访问(1)

前言 在程序开发尤其是网页应用开发中,数据访问是必不可少的。通过前面的基本案例我们完成了一个简单的SpringBoot Web应用并对自动配置原理有了一定了解,本节在上述案例基础上,继续编写数据访问案例,将通过SpringBoot中数据访问…

音视频开发6 音视频录制原理和播放原理

音视频录制原理 音视频播放原理

# 电脑突然连接不上网络了,怎么办?

电脑突然连接不上网络了,怎么办? 一、原因分析: 1、IP 地址冲突 2、DNS 解析出现问题。 3、电脑网络设置是否打开了【移动热点】或【飞行模式】。 4、【WLAN AutoConfig】服务是否打开。 5、无线网卡驱动损坏。 6、检查 WIFI 开关是否…

HTML+VUE3组合式+ELEMENT的容器模板示例(含侧栏导航,表格,...)

一个简单的在html中使用Vue3及Element-plus vue-icons的整合示例&#xff1a; 一、示例截图 二、文件代码 直接复制到html文件在浏览器打开即可预览 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title&g…

CCleaner系统优化与隐私保护工具,中文绿色便携版 v6.23.11010

01 软件介绍 CCleaner 是一款高级的系统优化工具&#xff0c;其设计宗旨在于彻底清理 Windows 操作系统中积累的无用文件和冗余的注册表项。此举旨在显著提升计算机的运行效率并回收磁盘空间。该软件拥有高效的能力&#xff0c;可以清除包括临时文件、浏览器缓存及其历史记录在…

08 - hive的集合函数、高级聚合函数、炸裂函数以及窗口函数

目录 1、集合函数 1.1、size&#xff1a;集合中元素的个数 1.2、map&#xff1a;创建map集合 1.3、map_keys&#xff1a; 返回map中的key 1.4、map_values: 返回map中的value 1.5、array 声明array集合 1.6、array_contains: 判断array中是否包含某个元素 1.7、sort_a…

UIKit之UIButton

功能需求&#xff1a; 点击按钮切换按钮的文字和背景图片&#xff0c;同时点击上下左右可以移动图片位置&#xff0c;点击加或减可以放大或缩小图片。 分析&#xff1a; 实现一个UIView的子类即可&#xff0c;该子类包含多个按钮。 实现步骤&#xff1a; 使用OC语言&#xf…

HCIP的学习(14)

过滤策略—filter-policy ​ 思科中&#xff1a;分发列表 ​ 过滤策略是只能够针对于路由信息进行筛选&#xff08;过滤&#xff09;的工具&#xff0c;而无法针对于LSA进行过滤。 在R4的出方向上配置过滤策略&#xff0c;使得R1不能学习到23.0.0.0/24路由信息1、抓取流量 […

Sping源码(七)—ConfigurationClassPostProcessor ——@PropertySources解析

序言 先来简单回顾一下ConfigurationClassPostProcessor大致的一个处理流程&#xff0c;再来详细的讲解PropertySources注解的处理逻辑。 详细的步骤可参考ConfigurationClassPostProcessor这篇帖子。 流程图 从获取所有BeanDefinition -> 过滤、赋值、遍历 -> 解析 -&…

自建WSUS更新服务器完成内网的安全补丁更新

一、适用场景 1、企业内部网络无法访问外网&#xff0c;所以搭建WSUS服务器&#xff0c;可以让内网环境进行更新补丁。 2、校园内部的电脑实训室一般不用外网资源&#xff0c;偶尔开启外网使用时&#xff0c;电脑实训室集体自动更新占用外网资源量大&#xff0c;所以搭建WSUS服…