WGAN原理及实现(pytorch版)

WGAN原理及实现

  • 一、WGAN原理
    • 1.1 原始GAN的缺陷
    • 1.2 Wasserstein距离的引入
    • 1.3 Kantorovich-Rubinstein对偶
    • 1.4 WGAN的优化目标
    • 1.4 数学推导步骤
    • 1.5 权重裁剪 vs 梯度惩罚
    • 1.6 优势
    • 1.7 总结
  • 二、WGAN实现
    • 2.1 导包
    • 2.2 数据加载和处理
    • 2.3 构建生成器
    • 2.4 构建判别器
    • 2.5 训练和保存模型
    • 2.6 图片转GIF

一、WGAN原理

1.1 原始GAN的缺陷

原始GAN通过最小化JS散度(Jensen-Shannon Divergence)训练生成器(Generator)和判别器(Discriminator),但存在两个关键问题:

  1. 梯度消失:当真实分布 P r P_r Pr 和生成分布 P g P_g Pg 不重叠时,JS散度为常数 log ⁡ 2 \log 2 log2,导致梯度为0,无法更新生成器
  2. 训练不稳定:判别器容易过拟合,生成器难以收敛

1.2 Wasserstein距离的引入

Wasserstein距离(Earth-Mover距离)衡量两个分布 P r \mathbb{P}_r Pr(真实分布)和 P g \mathbb{P}_g Pg(生成分布)之间的差异:

W ( P r , P g ) = inf ⁡ γ ∈ Π ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} \mathbb{E}_{(x,y)\sim \gamma} [\|x - y\|] W(Pr,Pg)=infγΠ(Pr,Pg)E(x,y)γ[xy],其中 Π ( P r , P g ) \Pi(\mathbb{P}_r, \mathbb{P}_g) Π(Pr,Pg) 是联合分布集合,边缘分布分别为 P r \mathbb{P}_r Pr P g \mathbb{P}_g Pg

关键改进:即使两个分布支撑集不重叠,Wasserstein距离仍能提供有意义的梯度

直观解释:衡量将“概率质量”从 P r P_r Pr 搬运到 P g P_g Pg 的最小成本


1.3 Kantorovich-Rubinstein对偶

通过对偶形式将问题转化为:

W ( P r , P g ) = sup ⁡ ∥ f ∥ L ≤ 1 E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)] W(Pr,Pg)=supfL1ExPr[f(x)]ExPg[f(x)],其中 f f f1-Lipschitz函数(满足 ∣ f ( x ) − f ( y ) ∣ ≤ ∥ x − y ∥ |f(x) - f(y)| \leq \|x - y\| f(x)f(y)xy


1.4 WGAN的优化目标

  • 判别器(Critic):拟合一个1-Lipschitz函数 f w f_w fw,最大化:
    L critic = E x ∼ P r [ f w ( x ) ] − E z ∼ p ( z ) [ f w ( G θ ( z ) ) ] L_{\text{critic}} = \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(G_\theta(z))] Lcritic=ExPr[fw(x)]Ezp(z)[fw(Gθ(z))]
  • 生成器:最小化Wasserstein距离,即:
    L generator = − E z ∼ p ( z ) [ f w ( G θ ( z ) ) ] L_{\text{generator}} = -\mathbb{E}_{z \sim p(z)} [f_w(G_\theta(z))] Lgenerator=Ezp(z)[fw(Gθ(z))]

关键点

  1. 判别器(称为Critic)输出为标量,无需Sigmoid激活
  2. 通过权重裁剪(强制参数 w w w) 在 [ − c , c ] [-c, c] [c,c] 内)或梯度惩罚(WGAN-GP)近似Lipschitz约束

1.4 数学推导步骤

(1)原始Wasserstein距离

W ( P r , P g ) = inf ⁡ γ E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma} \mathbb{E}_{(x,y)\sim \gamma} [\|x - y\|] W(Pr,Pg)=infγE(x,y)γ[xy]

(2)对偶形式推导

利用线性规划对偶性,转化为: W ( P r , P g ) = sup ⁡ f ∈ 1-Lip E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{f \in \text{1-Lip}} \mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)] W(Pr,Pg)=supf1-LipExPr[f(x)]ExPg[f(x)]

(3) 参数化近似
用神经网络 f w f_w fw 近似 f f f,优化: max ⁡ w E x ∼ P r [ f w ( x ) ] − E z ∼ p ( z ) [ f w ( G θ ( z ) ) ] \max_{w} \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(G_\theta(z))] maxwExPr[fw(x)]Ezp(z)[fw(Gθ(z))]

(4)生成器优化
固定 f w f_w fw,生成器最小化: min ⁡ θ − E z ∼ p ( z ) [ f w ( G θ ( z ) ) ] \min_{\theta} -\mathbb{E}_{z \sim p(z)} [f_w(G_\theta(z))] minθEzp(z)[fw(Gθ(z))]


1.5 权重裁剪 vs 梯度惩罚

  • 权重裁剪(原始WGAN):
    强制参数 w w w [ − c , c ] [-c, c] [c,c] 内,但可能导致梯度消失或爆炸
  • 梯度惩罚(WGAN-GP):
    添加正则项: λ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ f w ( x ^ ) ∥ 2 − 1 ) 2 ] \lambda \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}} [(\|\nabla_{\hat{x}} f_w(\hat{x})\|_2 - 1)^2] λEx^Px^[(x^fw(x^)21)2],其中 x ^ \hat{x} x^ 是真实样本和生成样本的随机插值

1.6 优势

  1. 训练信号:Critic的损失值与生成样本质量相关(越低表示越真实)
  2. 训练稳定性:避免模式崩溃(Mode Collapse)
  3. 梯度有意义:即使分布不重叠,仍能提供有效梯度
  4. 生成质量高:Wasserstein距离直接反映生成数据与真实数据的差异

1.7 总结

WGAN通过Wasserstein距离的优良性质,解决了传统GAN的训练难题。其数学核心在于对偶形式的转化和Lipschitz约束的实现,后续改进(如WGAN-GP)进一步提升了性能。


二、WGAN实现

2.1 导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as npimport os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  
from torchsummary import summary# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 指定存放日志路径
writer=SummaryWriter(log_dir="./runs/wgan")os.makedirs("./img/wgan_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录

2.2 数据加载和处理

# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,28,28)):transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到[-1,1]])# 下载训练集和测试集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建 DataLoadertrain_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)return train_loader, test_loader

2.3 构建生成器

class Generator(nn.Module):"""生成器"""def __init__(self, latent_dim=100,img_shape=(1,28,28)):super(Generator,self).__init__()# 网络块def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat))layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh() # 输出归一化到[-1,1] )def forward(self,z): # 噪声z,2维[batch_size,latent_dim]gen_img=self.model(z) gen_img=gen_img.view(gen_img.shape[0],*img_shape)return gen_img # 4维[batch_size,1,H,W]

2.4 构建判别器

class Discriminator(nn.Module):"""判别器"""def __init__(self,img_shape=(1,28,28)):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(256, 1))def forward(self,img): # 输入图片,4维[batc_size,1,H,W]img=img.view(img.shape[0], -1) pred = self.model(img)return pred # 2维[batch_size,1] 

2.5 训练和保存模型

  • WGAN算法流程
  • 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.00005
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本# WGAN的特别设置
num_iter_critic = 5
weight_clip_value = 0.01# 设置图片形状1*28*28
img_shape = (1,28,28)# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)# 设置优化器
optimizer_G = torch.optim.RMSprop(G.parameters(), lr=lr)
optimizer_D = torch.optim.RMSprop(D.parameters(), lr=lr)# 开始训练
batches_done=0
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):# 进入训练模式G.train()D.train()loop = tqdm(train_loader, desc=f"第{epoch+1}轮")for i, (real_imgs, _) in enumerate(loop):real_imgs=real_imgs.to(device)  # [B,C,H,W]# -----------------#  训练判别器# -----------------# 获取噪声样本[B,latent_dim)z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device)  #从正态分布中抽样# Step-1 计算判断器损失=判断真实图片损失+判断生成图片损失fake_imgs=G(z).detach()dis_loss=-torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))# Step-2 更新判别器参数optimizer_D.zero_grad() # 梯度清零dis_loss.backward() #反向传播,计算梯度optimizer_D.step()  #更新判别器 # Step-3 对判别器进行权重裁剪for p in D.parameters():p.data.clamp_(-weight_clip_value,weight_clip_value)# -----------------#  训练生成器# -----------------# 判别器每迭代 num_iter_critic 次,生成器迭代一次if i % num_iter_critic ==0 :gen_imgs=G(z).detach()# 更新生成器参数optimizer_G.zero_grad() #梯度清零gen_loss=-torch.mean(D(gen_imgs))gen_loss.backward() #反向传播,计算梯度optimizer_G.step()  #更新生成器  # 更新进度条loop.set_postfix(gen_loss=f"{gen_loss:.8f}",dis_loss=f"{dis_loss:.8f}")# 每 sample_interval 次迭代保存生成样本if batches_done % sample_interval == 0:save_image(gen_imgs.data[:25], f"./img/wgan_mnist/{epoch}_{i}.png", nrow=5, normalize=True)batches_done += 1print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/WGAN_G.pth") 
torch.save(D.state_dict(), "./model/WGAN_D.pth") 

2.6 图片转GIF

from PIL import Imagedef create_gif(img_dir="./img/wgan_mnist", output_file="./img/wgan_mnist/wgan_figure.gif", duration=100):images = []img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]# 自定义排序:按 "x_y.png" 的 x 和 y 排序img_paths_sorted = sorted(img_paths,key=lambda x: (int(x.split('_')[0]),  # 第一个数字(如 0_400.png 的 0)int(x.split('_')[1].split('.')[0])  # 第二个数字(如 0_400.png 的 400)))for img_file in img_paths_sorted:img = Image.open(os.path.join(img_dir, img_file))images.append(img)images[0].save(output_file, save_all=True, append_images=images[1:], duration=duration, loop=0)print(f"GIF已保存至 {output_file}")
create_gif()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/75469.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity网络开发基础 (3) Socket入门 TCP同步连接 与 简单封装练习

本文章不作任何商业用途 仅作学习与交流 教程来自Unity唐老狮 关于练习题部分是我观看教程之后自己实现 所以和老师写法可能不太一样 唐老师说掌握其基本思路即可,因为前端程序一般不需要去写后端逻辑 1.认识Socket的重要API Socket是什么 Socket(套接字&#xff0…

【linux】一文掌握 ssh和scp 指令的详细用法(ssh和scp 备忘速查)

文章目录 入门连接执行SCP配置位置SCP 选项配置示例ProxyJumpssh-copy-id SSH keygenssh-keygen产生钥匙类型known_hosts密钥格式 此快速参考备忘单提供了使用 SSH 的各种方法。 参考: OpenSSH 配置文件示例 (cyberciti.biz)ssh_config (linux.die.net) 入门 连…

真实笔试题

文章目录 线程题树的深度遍历 线程题 实现一个类支持100个线程同时向一个银行账户中存入一元钱.需通过同步机制消除竞态条件,当所有线程执行完成后,账户余额必须精确等于100元 package com.itheima.thread;public class ShowMeBug {private double balance; // 账户余额priva…

2.2 路径问题专题:LeetCode 63. 不同路径 II

动态规划解决LeetCode 63题:不同路径 II(含障碍物) 1. 题目链接 LeetCode 63. 不同路径 II 2. 题目描述 一个机器人位于 m x n 网格的左上角,每次只能向右或向下移动一步。网格中可能存在障碍物(标记为 1&#xff…

2874. 有序三元组中的最大值 II

给你一个下标从 0 开始的整数数组 。nums 请你从所有满足 的下标三元组 中&#xff0c;找出并返回下标三元组的最大值。 如果所有满足条件的三元组的值都是负数&#xff0c;则返回 。i < j < k(i, j, k)0 下标三元组 的值等于 。(i, j, k)(nums[i] - nums[j]) * nums[k…

【论文笔记】Llama 3 技术报告

Llama 3中的顶级模型是一个拥有4050亿参数的密集Transformer模型&#xff0c;并且它的上下文窗口长度可以达到128,000个tokens。这意味着它能够处理非常长的文本&#xff0c;记住和理解更多的信息。Llama 3.1的论文长达92页&#xff0c;详细描述了模型的开发阶段、优化策略、模…

JVM深入原理(一+二):JVM概述和JVM功能

目录 1. JVM概述 1.1. Java程序结构 1.2. JVM作用 1.3. JVM规范和实现 2. JVM功能 2.1. 功能-编译和运行 2.2. 功能-内存管理 2.3. 功能-即时编译 1. JVM概述 1.1. Java程序结构 1.2. JVM作用 JVM全称是Java Virtual Machine-Java虚拟机 JVM作用:本质上是一个运行在…

SQL Server Integration Services (SSIS) 服务无法启动

问题现象&#xff1a; 安装 SQL Server 2022 后&#xff0c;SQL Server Integration Services (SSIS) 服务无法启动&#xff0c;日志报错 “服务无法响应控制请求”&#xff08;错误代码 1067&#xff09;或 “依赖服务不存在或已标记为删除”。 快速诊断 检查服务状态与依赖项…

Spring Boot 定时任务的多种实现方式

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…

Java基础之反射的基本使用

简介 在运行状态中&#xff0c;对于任意一个类&#xff0c;都能够知道这个类的所有属性和方法&#xff1b;对于任意一个对象&#xff0c;都能够调用它的任意属性和方法&#xff1b;这种动态获取信息以及动态调用对象方法的功能称为Java语言的反射机制。反射让Java成为了一门动…

AI产品的上层建筑:提示词工程、RAG与Agent

上节课我们拆解了 AI 产品的基础设施建设&#xff0c;这节课我们聊聊上层建筑。这部分是产品经理日常工作的重头戏&#xff0c;包含提示词、RAG 和 Agent 构建。 用 AI 客服产品举例&#xff0c;这三者的作用是这样的&#xff1a; 提示词能让客服很有礼貌。比如它会说&#x…

蓝桥杯刷题记录【并查集001】(2024)

主要内容&#xff1a;并查集 并查集 并查集的题目感觉大部分都是模板题&#xff0c;上板子&#xff01;&#xff01; class UnionFind:def __init__(self, n):self.pa list(range(n))self.size [1]*n self.cnt ndef find(self, x):if self.pa[x] ! x:self.pa[x] self.fi…

海外SD-WAN专线网络部署成本分析

作为支撑企业国际业务的重要基石&#xff0c;海外SD-WAN专线以其独特的成本优势和技术特性&#xff0c;正成为企业构建高效稳定的全球网络架构的首选方案。本文将从多维度解构海外SD-WAN专线部署的核心成本要素&#xff0c;为企业的全球化网络布局提供战略参考。 一、基础资源投…

操作系统(二):实时系统介绍与实例分析

目录 一.概念 1.1 分类 1.2 主要指标 二.实现原理 三.主流实时系统对比 一.概念 实时系统&#xff08;Real-Time System, RTS&#xff09;是一类以时间确定性为核心目标的计算机系统&#xff0c;其设计需确保在严格的时间约束内完成任务响应。 1.1 分类 根据时间约束的严…

Golang的消息中间件选型

# Golang的消息中间件选型 消息中间件的作用 消息中间件是一种用于分布式系统中应用程序之间进行通信的基础架构工具&#xff0c;它能够有效地解耦发送者和接收者&#xff0c;并提供高可用性和可靠性的消息传递机制。在Golang应用程序中&#xff0c;选择适合的消息中间件对于构…

大模型中的参数规模与显卡匹配

在大模型训练和推理中&#xff0c;显卡&#xff08;GPU/TPU&#xff09;的选择与模型参数量紧密相关&#xff0c;需综合考虑显存、计算能力和成本。以下是不同规模模型与硬件的匹配关系及优化策略&#xff1a; 一、参数规模与显卡匹配参考表 模型参数量训练阶段推荐显卡推理阶…

带头结点 的单链表插入方法(头插法与尾插法)

带头结点的单链表插入方法&#xff08;头插法与尾插法&#xff09; 在单链表的操作中&#xff0c;插入是最常见的操作之一&#xff0c;本文介绍 带头结点的单链表 如何实现 后插法 和 前插法&#xff08;包括 插入法 和 后插数据交换法&#xff09;&#xff0c;并提供完整的 C …

Prometheus的工作流程

Prometheus 是一个开源的监控和告警系统&#xff0c;专为监控分布式系统而设计。它的工作流程主要包括以下几个关键步骤&#xff1a; 1. 数据采集 (Scraping) 目标发现 (Service Discovery)&#xff1a; Prometheus 自动或手动配置监控目标&#xff0c;通过 DNS、Kubernetes、…

软件工程面试题(二十二)

1、常用的设计模式有哪些&#xff1f;并写出一段程序代码 Factory(工厂模式)&#xff0c;Adapter(适配器模式)&#xff0c;Singleton(单例模式)&#xff0c;State(状态模式)&#xff0c;Observer(观察者模式) 等。 单例模式 public class Singleton{ private static Singleton …

【Pandas】pandas DataFrame select_dtypes

Pandas2.2 DataFrame Attributes and underlying data 方法描述DataFrame.index用于获取 DataFrame 的行索引DataFrame.columns用于获取 DataFrame 的列标签DataFrame.dtypes用于获取 DataFrame 中每一列的数据类型DataFrame.info([verbose, buf, max_cols, …])用于提供 Dat…