【人工智能基础】GAN与WGAN实验

一、GAN网络概述

GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。

Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)

Discriminator功能是判别一张图片的真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率,如果为1就代表图片真实,而输出为0,就代表图片不真实。

在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator

Discriminator的目标就是尽量把Generator生成的图片和真实的图片分别开来

二、GAN实验环境准备

除了之前使用过的pytorch-nplnumpy以外,我们还需要安装visdom

pip install visdom

启动visdom

python -m visdom.server

visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom

visdom启动.png

三、GAN网络实验

环境参数配置

import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import randomh_dim = 400
batchsz = 512
viz = visdom.Visdom()

生成网络定义

class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.net = nn.Sequential(# input[b, 2]nn.Linear(2,h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2)# output[b,2])def forward(self, z):output = self.net(z)return output

判别网络定义

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 1),nn.Sigmoid())def forward(self, x):output = self.net(x)return output.view(-1)

数据集生成函数

def data_generator():# 生成中心点scale = 2centers = [(1, 0),(-1, 0),(0, 1),(0, -1),(1. / np.sqrt(2), 1. / np.sqrt(2)),(1. / np.sqrt(2), -1. / np.sqrt(2)),(-1. / np.sqrt(2), 1. / np.sqrt(2)),(-1. / np.sqrt(2), -1. / np.sqrt(2))]centers = [(scale * x, scale * y) for x,y in centers] while True:dataset = []for i in range(batchsz):point = np.random.randn(2) * 0.02# 随机选取一个中心点center = random.choice(centers)# 把刚刚随机到的高斯分布点根据center进行移动point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset).astype(np.float32)dataset /= 1.414yield dataset

可视化函数

将图片生成到visdom

import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1,2))with torch.no_grad():points = torch.Tensor(points).cpu()disc_map = D(points).cpu().numpy()x = y = np.linspace(-RANGE,RANGE,N_POINTS)cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())plt.clabel(cs, inline=1,fontsize=10)with torch.no_grad():z = torch.randn(batchsz, 2).cpu()samples = G(z).cpu().numpy()plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))

运行函数

def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()loss_D = lossr + lossf# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)

执行(GAN的不稳定性)

run()

从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点

gan运行.png

四、wgan实验

WGAN主要从损失函数的角度对GAN做了改进,对更新后的权重强制截断到一定范围内

增加一个梯度惩罚函数

def gradient_penalty(D,xr,xf):# [b,1]t = torch.rand(batchsz, 1).cpu()# 扩展为[b, 2]t = t.expand_as(xr)# 插值mid = t * xr + (1 - t) * xf# 设置需要的倒数信息mid.requires_grad_()pred = D(mid)grads = autograd.grad(outputs=pred, inputs=mid,grad_outputs=torch.ones_like(pred),create_graph=True,retain_graph=True,only_inputs=True)[0]gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()return gp

修改运行函数

def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()# 梯度惩罚值gp = gradient_penalty(D,xr,xf.detach())loss_D = lossr + lossf + 0.2 * gp# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)

执行

run()

可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近

wgan运行.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/836279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL中的查询语法

条件查询 -- 条件查询 select 结果列 from 表名 where 条件 SELECT * FROM student WHERE height>1.80;-- and 所有条件都要满足 SELECT * FROM student WHERE height<1.80 AND gender 女;-- or 满足一个条件即可 SELECT * FROM student WHERE height1.80 OR gender女…

阿里开源编程大模型 CodeQwen1.5:64K92编程语言,Code和SQL编程,评测接近GPT-4-Turbo

前言 阿里巴巴最近发布的CodeQwen1.5模型标志着其在编程语言模型领域的一次重大突破。这款开源模型不仅支持高达92种编程语言和64K的上下文长度&#xff0c;而且在多项性能评测中显示出接近或超过当前行业领导者GPT-4-Turbo的能力。 Huggingface模型下载&#xff1a;https://h…

Boost库的使用

1 下载与安装 1.1 下载 网址&#xff1a;Boost C Libraries 进入后选择自己需要的版本安装即可 1.2 安装 1.2.1 解压 1.2.2 编译安装 双击bootstrap.bat 这一步完成后会生成一个b2.exe文件 双击b2.exe文件运行&#xff08;此步需要花费较长的时间&#xff09; 之后再stag…

双向链表(双向带头循环)的增删查改的实现(简单易懂)

一&#xff1a;双向链表的概念 每个节点除开存有数据&#xff0c;还有一个指针指向前一个节点&#xff0c;一个指针指向后一个节点&#xff0c;尾节点和哨兵位互相指向&#xff0c;从而形成一个循环。 二&#xff1a;双向链表的实现第一点&#xff1a; 本文采用三个文件进行实…

Pycharm中安装tablepyxl失败

tablepyxl是一个存在的 Python 包&#xff0c;它是一个桥接 HTML 表格和 openpyxl 的工具&#xff0c;允许你将 HTML 表格转换成 Excel 工作簿。如果你想在 conda 环境中安装 tablepyxl&#xff0c;可以按照以下步骤进行&#xff1a; &#xff08;1&#xff09;打开conda终端。…

GIS数据—1984-2020中国1km人造夜间灯光观测数据

夜间灯光观测数据&#xff08;Nighttime Light,NTL&#xff09;是评估人类活动边界的常用手段&#xff0c;目前&#xff0c;该数据已经广泛应用于城市范围、不透水面、基础设施建设等一系列过程。今天&#xff0c;小编要带来的是长时间序列中国区域边界的夜间灯光观测数据。 数…

springcloud -nacos实战

一、nacos 功能简介 1.1.什么是Nacos&#xff1f; 官方简介&#xff1a;一个更易于构建云原生应用的动态服务发现(Nacos Discovery )、服务配置(Nacos Config)和服务管理平台。 Nacos的关键特性包括&#xff1a; 服务发现和服务健康监测动态配置服务动态DNS服务服务及其元数…

C++语法|explicit关键字

文章目录 1.C的隐式对象转换问题举例产生的问题 2.使用explicit解决上述问题总结 1.C的隐式对象转换问题 在C中&#xff0c;隐式对象转换&#xff08;Implicit Object Conversion&#xff09;指的是编译器在不需要程序员明确指示的情况下&#xff0c;自动将对象从一种类型转换…

VMware配置Kali linux + 物理机连接Xshell

VMware 配置 kali linux 首先需要先安装VMware Workstation 我是在Windows 安装的 VMware Workstation Pro 17 虚拟化&#xff0c;产品密钥。。这里不做多说了 下载kali linux 这里我下载的是kali-linux-2024.1 Note&#xff1a;这里选Virtual Machines&#xff0c;建议不要…

景源畅信:抖音小店的商品怎么同步到橱窗?

在数字营销的海洋中&#xff0c;抖音小店与橱窗的同步操作无疑是商家们关注的焦点。这不仅能增加商品的曝光度&#xff0c;还能提高交易的可能性。那么&#xff0c;如何将抖音小店的商品同步到橱窗呢? 一、核心步骤解析 要实现商品从抖音小店同步到橱窗&#xff0c;你需要确保…

【Linux 网络】网络编程套接字 -- 详解

⚪ 预备知识 1、理解源 IP 地址和目的 IP 地址 举例理解&#xff1a;&#xff08;唐僧西天取经&#xff09; 在 IP 数据包头部中 有两个 IP 地址&#xff0c; 分别叫做源 IP 地址 和目的 IP 地址。 如果我们的台式机或者笔记本没有 IP 地址就无法上网&#xff0c;而因为…

Unity引擎是什么?有哪些优点

大家好&#xff0c;我是咕噜土豆&#xff0c;很高兴又和大家见面了。今天我们一起来了解一下Unity引擎和它有哪些优点。 首先带大家了解什么是Unity引擎 Unity引擎是一款由Unity Technologies开发的跨平台游戏开发引擎&#xff0c;广泛用于创建2D和3D游戏以及其他交互式内容&…

C++动态内存区域划分、new、delete关键字

目录 一、C/C中程序的内存区域划分 为什么会存在内存区域划分&#xff1f; 二、new关键字 1、内置类型的new/delete使用方法&#xff1a; 2、new和delete的本质 一、C/C中程序的内存区域划分 为什么会存在内存区域划分&#xff1f; 因为不同数据有不同的存储需求&#xff0…

【SpringBoot记录】从基本使用案例入手了解SpringBoot-数据访问(1)

前言 在程序开发尤其是网页应用开发中&#xff0c;数据访问是必不可少的。通过前面的基本案例我们完成了一个简单的SpringBoot Web应用并对自动配置原理有了一定了解&#xff0c;本节在上述案例基础上&#xff0c;继续编写数据访问案例&#xff0c;将通过SpringBoot中数据访问…

音视频开发6 音视频录制原理和播放原理

音视频录制原理 音视频播放原理

VO、PO、DTO的区别

VO&#xff1a;值对象&#xff0c;用于视图层&#xff0c;它的作用是把某个指定页面&#xff08;或组件&#xff09;的所有数据封装起来。 PO&#xff1a;持久化对象&#xff0c;它跟持久层&#xff08;通常是关系型数据库&#xff09;的数据结构形成一一对应的映射关系&#…

# 电脑突然连接不上网络了,怎么办?

电脑突然连接不上网络了&#xff0c;怎么办&#xff1f; 一、原因分析&#xff1a; 1、IP 地址冲突 2、DNS 解析出现问题。 3、电脑网络设置是否打开了【移动热点】或【飞行模式】。 4、【WLAN AutoConfig】服务是否打开。 5、无线网卡驱动损坏。 6、检查 WIFI 开关是否…

java线程池源码解析:ThreadPoolExecutor源码,execute方法、addWorker方法解析

1. 概述 线程池 的作用不用太说了&#xff0c;线程池会按照一定的规则&#xff0c;创建和维护一定数量的线程。这些线程可以被循环利用&#xff0c;来处理用户提交的任务。对比不同线程池的使用方式&#xff0c;节省了频繁的创建和销毁线程带来的性能开销。 2. 概念理解 2…

从FasterTransformer源码解读开始了解大模型(2.0)代码通读01

从FasterTransformer源码解读开始了解大模型&#xff08;2.0&#xff09;代码解读01-看看头文件 写在前面的话 本篇的内容直接开始我们的代码通读&#xff0c;整个通读可能需要好几篇文章来将一整个gpt的代码结构给讲清楚。目前的计划是先从整体model层次开始讲&#xff0c;将…

Java8 Stream API在集合上执行复杂的数据处理查询

Java 8 引入的 Stream API 是一个高级工具&#xff0c;用于在集合上执行复杂的数据处理查询。Stream API 通过提供一系列的中间操作和最终操作&#xff0c;支持声明式处理&#xff08;类似于SQL声明式语句&#xff09;并且可以轻松使用多核架构。 创建Stream流 创建Stream 流…