Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一)
今天我们将深入探讨生成对抗网络(GAN)的进阶内容,特别是Wasserstein GAN(WGAN)的梯度惩罚机制,以及条件生成与无监督生成在模式坍塌方面的差异。
生成对抗网络是近年来深度学习领域最激动人心的进展之一,它由Ian Goodfellow于2014年提出,通过生成器和判别器的博弈来学习生成真实数据分布的样本。随着研究的深入,GAN的改进版本层出不穷,其中WGAN及其梯度惩罚版本(WGAN-GP)解决了原始GAN训练不稳定的问题,成为了GAN研究的重要里程碑。
今天我们将从理论到实践,系统地学习这些进阶概念,并通过PyTorch实现相关模型,探索其工作原理。
1. GAN基础回顾
在深入WGAN之前,让我们简要回顾GAN的基本原理:
1.1 GAN的基本架构
GAN由两部分组成:
- 生成器(Generator): 学习从随机噪声生成看起来真实的数据
- 判别器(Discriminator): 学习区分真实数据和生成器生成的假数据
这两个网络通过对抗训练相互提高:生成器尝试生成越来越逼真的样本以欺骗判别器,而判别器则努力提高其区分真假样本的能力。
1.2 原始GAN的问题
虽然GAN的思想非常优雅,但原始GAN在训练过程中存在一些问题:
- 训练不稳定:很难找到生成器和判别器之间的平衡点
- 梯度消失:当判别器表现过好时,生成器梯度接近于零
- 模式坍塌:生成器只生成有限种类的样本,无法覆盖真实数据的全部分布
- 难以量化训练进度:缺乏有效的指标来衡量生成样本的质量
这些问题促使研究者寻找GAN的改进版本,其中WGAN是最重要的改进之一。
2. Wasserstein GAN详解
2.1 从JS散度到Wasserstein距离
原始GAN隐式地最小化生成分布与真实分布之间的Jensen-Shannon(JS)散度,这在两个分布没有显著重叠时会导致梯度问题。
Wasserstein距离(也称Earth Mover’s Distance,简称EMD)提供了一种更平滑的度量方式,即使两个分布没有重叠或重叠很少,也能提供有意义的梯度。
Wasserstein距离的直观解释:想象将一个分布的概率质量移动到另一个分布所需的最小"工作量",其中工作量定义为概率质量乘以移动距离。
2.2 WGAN的核心改进
WGAN相比原始GAN有以下关键改进:
- 目标函数改变:使用Wasserstein距离而非JS散度
- 判别器(现称为评论家/Critic)输出不再是概率:移除了最后的sigmoid激活函数
- 权重裁剪:限制评论家的参数在一定范围内,满足Lipschitz约束
- 避免使用基于动量的优化器:建议使用RMSProp或Adam优化器(学习率较小)
2.3 WGAN的目标函数
WGAN的目标函数如下:
min G max D ∈ D E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] GminD∈DmaxEx∼Pr[D(x)]−Ez∼Pz[D(G(z))]
其中 D \mathcal{D} D是满足1-Lipschitz约束的函数集合。
2.4 Lipschitz约束与权重裁剪
为了满足Wasserstein距离计算中的Lipschitz约束,WGAN对评论家的参数进行了权重裁剪:将权重限制在 [ − c , c ] [-c, c] [−c,c]的范围内,其中 c c c是一个小常数(如0.01)。
然而,权重裁剪是一种粗糙的方法,会导致优化问题和容量浪费。这就引出了WGAN的进一步改进:梯度惩罚机制。
3. WGAN的梯度惩罚机制
3.1 权重裁剪的局限性
WGAN中的权重裁剪虽然简单有效,但存在以下问题:
- 容量浪费:强制权重接近0或c,导致模型倾向于使用更简单的函数
- 优化困难:可能导致梯度爆炸或消失
- 对架构敏感:不同网络架构可能需要不同的裁剪范围
3.2 梯度惩罚的原理
WGAN-GP(带梯度惩罚的WGAN)提出了一种更优雅的方式来满足Lipschitz约束。其核心思想是:
对于一个1-Lipschitz函数,其梯度范数在任何地方都不应超过1。因此,我们可以通过惩罚评论家函数梯度范数偏离1的行为来满足这一约束。
具体来说,WGAN-GP在真实数据和生成数据之间的随机插值点上施加梯度惩罚:
L G P = E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] \mathcal{L}_{GP} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] LGP=Ex^∼Px^[(∣∣∇x^D(x^)∣∣2−1)2]
其中 x ^ \hat{x} x^是在真实样本 x x x和生成样本 G ( z ) G(z) G(z)之间的随机插值:
x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon)G(z) x^=ϵx+(1−ϵ)G(z)
ϵ \epsilon ϵ是一个在 [ 0 , 1 ] [0,1] [0,1]之间均匀采样的随机数。
3.3 WGAN-GP的完整目标函数
将梯度惩罚添加到WGAN的目标函数中,我们得到WGAN-GP的目标函数:
L = E z ∼ p ( z ) [ D ( G ( z ) ) ] − E x ∼ p d a t a [ D ( x ) ] + λ E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] \mathcal{L} = \mathbb{E}_{z \sim p(z)}[D(G(z))] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] L=Ez∼p(z)[D(G(z))]−Ex∼pdata[D(x)]+λEx^∼Px^[(∣∣∇x^D(x^)∣∣2−1)2]
其中 λ \lambda λ是梯度惩罚的权重,通常设为10。
3.4 WGAN-GP的优势
WGAN-GP相比WGAN有以下优势:
- 更好的稳定性:避免了权重裁剪带来的问题
- 更快的收敛:通常需要更少的迭代次数
- 更好的生成质量:能生成更多样、更高质量的样本
- 架构灵活性:适用于各种GAN架构,包括深度卷积网络
4. PyTorch实现WGAN-GP
下面我们使用PyTorch实现一个简单的WGAN-GP模型,用于生成MNIST手写数字。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超参数
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
lambda_gp = 10 # 梯度惩罚权重# 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 生成器网络
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_features, out_features, normalize=True):layers = [nn.Linear(in_features, out_features)]if normalize:layers.append(nn.BatchNorm1d(out_features, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh() # 输出归一化到[-1, 1])def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return img# 判别器网络(评论家)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1)# 注意:没有sigmoid激活函数)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# 初始化网络
generator = Generator().to(device)
discriminator = Discriminator().to(device)# 初始化优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 计算梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples):"""计算WGAN-GP中的梯度惩罚"""# 在真实样本和生成样本之间随机插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 计算插值点的判别器输出d_interpolates = D(interpolates)# 计算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 计算梯度范数gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 训练函数
def train_wgan_gp():# 用于记录损失d_losses = []g_losses = []for epoch in range(n_epochs):for i, (real_imgs, _) in enumerate(dataloader):real_imgs = real_imgs.to(device)batch_size = real_imgs.shape[0]# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 生成随机噪声z = torch.randn(batch_size, latent_dim, device=device)# 生成一批假图像fake_imgs = generator(z)# 判别器前向传播real_validity = discriminator(real_imgs)fake_validity = discriminator(fake_imgs.detach())# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器n_critic = 5if i % n_critic == 0:# ---------------------# 训练生成器# ---------------------optimizer_G.zero_grad()# 生成一批新的假图像gen_imgs = generator(z)# 判别器评估假图像fake_validity = discriminator(gen_imgs)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每个epoch结束后保存生成的图像样本if (epoch + 1) % 10 == 0:save_sample_images(epoch)# 绘制损失曲线plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('wgan_gp_loss.png')plt.close()# 保存样本图像
def save_sample_images(epoch):# 生成并保存样本图像z = torch.randn(25, latent_dim, device=device)gen_imgs = generator(z).detach().cpu()# 将图像像素值从[-1, 1]转换为[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5# 创建图像网格fig, axs = plt.subplots(5, 5, figsize=(10, 10))for i in range(5):for j in range(5):axs[i, j].imshow(gen_imgs[i*5+j, 0, :, :], cmap='gray')axs[i, j].axis('off')# 保存图像plt.savefig(f'wgan_gp_epoch_{epoch+1}.png')plt.close()# 运行训练
if __name__ == "__main__":train_wgan_gp()
这段代码实现了一个基本的WGAN-GP模型,用于生成MNIST数字图像。下面我们来解析代码的关键部分:
- 梯度惩罚计算:
compute_gradient_penalty
函数实现了WGAN-GP的核心——在真实样本和生成样本之间的插值点上计算梯度惩罚。 - 判别器损失:包括真实数据的评论家值、生成数据的评论家值,以及梯度惩罚项。
- 生成器损失:仅包含生成数据的评论家值的负期望。
- 优化器设置:使用Adam优化器,但β1参数设为0.5,这是GAN训练的常见设置。
- 训练循环:判别器和生成器交替训练,但判别器通常训练多次(n_critic=5)后才训练一次生成器。
5. WGAN-GP训练流程图
以下是WGAN-GP的训练流程图,帮助理解整个训练过程:
┌────────────────────┐
│ 初始化网络和优化器 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 开始训练循环 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 从数据集加载真实样本 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 生成随机噪声并产生 │
│ 假样本 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 计算判别器对真实 │
│ 和假样本的输出 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 在样本插值点上计算 │
│ 梯度惩罚 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 计算判别器损失 │
│ 并更新判别器参数 │
└──────────┬─────────┘│▼┌────┴─────┐│ i % n_critic ││ == 0? │└────┬─────┘No │ Yes┌─────────┘ └──────────┐│ ▼│ ┌────────────────────┐│ │ 重新生成假样本 ││ └──────────┬─────────┘│ ││ ▼│ ┌────────────────────┐│ │ 计算生成器损失 ││ │ 并更新生成器参数 ││ └──────────┬─────────┘│ │└─────────────────────────┘│▼
┌────────────────────┐
│ 是否达到预定训练轮数? │
└──────────┬─────────┘No │ Yes┌────┘ └──────────┐│ ▼│ ┌────────────────────┐└──────▶ │ 结束训练 │└────────────────────┘
这个流程图展示了WGAN-GP的训练过程,包括梯度惩罚的计算和判别器多次训练的机制。与普通GAN相比,WGAN-GP的关键区别在于梯度惩罚的引入和目标函数的改变。
6. 条件生成与无监督生成的对比
接下来,我们将探讨条件生成与无监督生成在模式坍塌方面的差异。
6.1 无监督生成与模式坍塌
无监督生成是指生成器仅从随机噪声生成样本,没有额外的条件输入。
模式坍塌(Mode Collapse)是GAN训练中的常见问题,指生成器只学会生成数据分布中的少数几种模式,而忽略了其他模式。例如,在MNIST数据集上,模型可能只生成数字"1"而不生成其他数字。
导致模式坍塌的原因:
- 判别器更新不足:判别器无法有效区分真假样本
- 梯度消失:当判别器表现过好时,生成器梯度接近零
- 目标函数设计问题:JS散度在两个分布不重叠时提供有限的梯度信息
6.2 条件生成对模式坍塌的缓解
条件生成是指生成器不仅接收随机噪声,还接收额外的条件信息(如类别标签)作为输入。
条件GAN(CGAN)通过以下方式缓解模式坍塌:
- 强制生成器覆盖所有类别:通过提供不同的类别条件,迫使生成器学习生成不同类别的样本
- 简化学习任务:条件信息使生成器只需要学习条件分布,而非整个联合分布
- 提供更多监督信号:条件信息为生成器提供了额外的指导
6.3 条件生成与无监督生成的模式坍塌差异表
特性 | 无监督生成 | 条件生成 |
---|---|---|
输入 | 仅随机噪声 | 随机噪声 + 条件信息 |
模式覆盖 | 容易忽略部分模式 | 被条件强制覆盖更多模式 |
生成样本多样性 | 较低,倾向于生成相似样本 | 较高,不同条件生成不同样本 |
训练稳定性 | 较差,易发生模式坍塌 | 较好,条件信息提供稳定指导 |
应用灵活性 | 生成过程不可控 | 可控制生成特定类别/属性的样本 |
实现复杂度 | 相对简单 | 需要额外的条件嵌入机制 |
7. 实现条件WGAN-GP
下面我们将实现一个条件版本的WGAN-GP,以比较其与无监督版本在模式坍塌方面的差异。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超参数
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
n_classes = 10 # MNIST有10个类别
lambda_gp = 10 # 梯度惩罚权重# 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 条件生成器网络
class ConditionalGenerator(nn.Module):def __init__(self):super(ConditionalGenerator, self).__init__()# 嵌入层将类别标签转换为嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 输入层处理噪声和类别嵌入self.input_layer = nn.Linear(latent_dim + n_classes, 128)# 主要模型self.model = nn.Sequential(nn.BatchNorm1d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, noise, labels):# 将标签嵌入向量与噪声拼接label_embedding = self.label_embedding(labels)x = torch.cat([noise, label_embedding], dim=1)# 通过输入层x = self.input_layer(x)# 通过主模型x = self.model(x)# 重塑为图像格式img = x.view(x.size(0), *img_shape)return img# 条件判别器网络
class ConditionalDiscriminator(nn.Module):def __init__(self):super(ConditionalDiscriminator, self).__init__()# 嵌入层将类别标签转换为嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 处理图像和标签self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)) + n_classes, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1))def forward(self, img, labels):# 将图像展平img_flat = img.view(img.size(0), -1)# 获取标签嵌入label_embedding = self.label_embedding(labels)# 拼接图像特征和标签嵌入x = torch.cat([img_flat, label_embedding], dim=1)# 通过判别器网络validity = self.model(x)return validity# 初始化网络
generator = ConditionalGenerator().to(device)
discriminator = ConditionalDiscriminator().to(device)# 初始化优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 计算梯度惩罚(条件版本)
def compute_gradient_penalty(D, real_samples, fake_samples, labels):"""计算条件WGAN-GP的梯度惩罚"""# 在真实样本和生成样本之间随机插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 计算插值点的判别器输出(带条件)d_interpolates = D(interpolates, labels)# 计算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 计算梯度范数gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 训练条件WGAN-GP
def train_conditional_wgan_gp():# 用于记录损失d_losses = []g_losses = []# 用于记录生成样本的多样性(通过类别分布)class_distributions = []for epoch in range(n_epochs):for i, (real_imgs, labels) in enumerate(dataloader):real_imgs = real_imgs.to(device)labels = labels.to(device)batch_size = real_imgs.shape[0]# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 生成随机噪声z = torch.randn(batch_size, latent_dim, device=device)# 为生成器生成随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批假图像fake_imgs = generator(z, gen_labels)# 判别器前向传播real_validity = discriminator(real_imgs, labels)fake_validity = discriminator(fake_imgs.detach(), gen_labels)# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, labels)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器n_critic = 5if i % n_critic == 0:# ---------------------# 训练生成器# ---------------------optimizer_G.zero_grad()# 为生成器生成新的随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批新的假图像gen_imgs = generator(z, gen_labels)# 判别器评估假图像fake_validity = discriminator(gen_imgs, gen_labels)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每个epoch结束后,评估生成样本的类别分布if (epoch + 1) % 10 == 0:class_dist = evaluate_class_distribution()class_distributions.append(class_dist)# 保存生成的图像样本save_sample_images(epoch)# 绘制损失曲线plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('cond_wgan_gp_loss.png')plt.close()# 绘制类别分布变化plot_class_distributions(class_distributions)# 评估生成样本的类别分布
def evaluate_class_distribution():"""评估生成样本在各类别上的分布情况"""# 创建一个预训练的分类器classifier = torchvision.models.resnet18(pretrained=True)# 修改第一个卷积层以适应灰度图classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 修改最后的全连接层以适应MNIST的10个类别classifier.fc = nn.Linear(classifier.fc.in_features, 10)# 加载预先训练好的MNIST分类器权重(这里假设我们有一个预训练的模型)# classifier.load_state_dict(torch.load('mnist_classifier.pth'))# 简化起见,这里我们使用一个简单的CNN分类器classifier = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 这里假设这个简单分类器已经在MNIST上训练好了# 实际应用中,应该加载一个预先训练好的模型# 生成1000个样本z = torch.randn(1000, latent_dim, device=device)# 均匀采样所有类别gen_labels = torch.tensor([i % 10 for i in range(1000)], device=device)gen_imgs = generator(z, gen_labels)# 使用分类器预测类别with torch.no_grad():classifier.eval()preds = torch.softmax(classifier(gen_imgs), dim=1)pred_labels = torch.argmax(preds, dim=1)# 计算每个类别的样本数量class_counts = torch.zeros(10)for i in range(10):class_counts[i] = (pred_labels == i).sum().item() / 1000return class_counts.numpy()# 绘制类别分布变化
def plot_class_distributions(class_distributions):"""绘制生成样本类别分布的变化"""epochs = [10, 20, 30, 40, 50] # 假设每10个epoch评估一次plt.figure(figsize=(12, 8))for i, dist in enumerate(class_distributions):plt.subplot(len(class_distributions), 1, i+1)plt.bar(np.arange(10), dist)plt.ylabel(f'Epoch {epochs[i]}')plt.ylim(0, 0.3) # 限制y轴范围,便于比较if i == len(class_distributions) - 1:plt.xlabel('Digit Class')plt.tight_layout()plt.savefig('class_distribution.png')plt.close()# 保存样本图像(条件版本)
def save_sample_images(epoch):"""保存按类别排列的样本图像"""# 为每个类别生成样本n_row = 10 # 每个类别一行n_col = 10 # 每个类别10个样本fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12))for i in range(n_row):# 固定类别fixed_class = torch.tensor([i] * n_col, device=device)# 随机噪声z = torch.randn(n_col, latent_dim, device=device)# 生成图像gen_imgs = generator(z, fixed_class).detach().cpu()# 转换到[0, 1]范围gen_imgs = 0.5 * gen_imgs + 0.5# 显示图像for j in range(n_col):axs[i, j].imshow(gen_imgs[j, 0, :, :], cmap='gray')axs[i, j].axis('off')plt.savefig(f'cond_wgan_gp_epoch_{epoch+1}.png')plt.close()# 运行条件WGAN-GP训练
if __name__ == "__main__":train_conditional_wgan_gp()
清华大学全五版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。
怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!