CGAN 原理及实现
- 一、CGAN 原理
- 1.1 基本概念
- 1.2 与传统GAN的区别
- 1.3 目标函数
- 1.4 损失函数
- 1.5 条件信息的融合方式
- 1.6 与其他GAN变体的对比
- 1.7 CGAN的应用
- 1.8 改进与变体
- 二、CGAN 实现
- 2.1 导包
- 2.2 数据加载和处理
- 2.3 构建生成器
- 2.4 构建判别器
- 2.5 训练和保存模型
- 2.6 绘制训练损失
- 2.7 图片转GIF
- 2.8 模型加载和生成
一、CGAN 原理
1.1 基本概念
条件生成对抗网络
(Conditional GAN, CGAN)是GAN的一种扩展,它在生成器和判别器中都加入了额外的条件信息
y y y。这个条件信息可以是类别标签、文本描述或其他形式的辅助信息。
1.2 与传统GAN的区别
- 传统GAN: G ( z ) G(z) G(z) → 生成样本, D ( x ) D(x) D(x) → 判断真实/生成
CGAN
: G ( z ∣ y ) G(z|y) G(z∣y) → 基于条件 y y y 生成样本, D ( x ∣ y ) D(x|y) D(x∣y) → 基于条件 y y y 判断真实/生成
1.3 目标函数
CGAN的目标函数可以表示为: m i n G m a x D V ( D , G ) = 𝔼 x ∼ p data [ l o g D ( x ∣ y ) ] + 𝔼 z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ∣ y ) ∣ y ) ) ] min_G max_D V(D,G) = 𝔼_{x \sim p_{\text{data}}}[log D(x|y)] + 𝔼_{z \sim p_z(z)}[log(1 - D(G(z|y)|y))] minGmaxDV(D,G)=Ex∼pdata[logD(x∣y)]+Ez∼pz(z)[log(1−D(G(z∣y)∣y))],其中 y y y 是条件信息。
1.4 损失函数
(1) 判别器(Discriminator)的损失函数
\space \space 判别器需要同时判断:
真实图像是否真实
(且匹配其标签)生成图像是否虚假
(且匹配其标签)
损失函数公式:
L D = E x , y ∼ p data [ log D ( x ∣ y ) ] ⏟ 真实样本损失 + E z ∼ p z , y ∼ p labels [ log ( 1 − D ( G ( z ∣ y ) ∣ y ) ] ⏟ 生成样本损失 \mathcal{L}_D = \underbrace{\mathbb{E}_{x,y \sim p_{\text{data}}}[\log D(x|y)]}_{\text{真实样本损失}} + \underbrace{\mathbb{E}_{z \sim p_z, y \sim p_{\text{labels}}}[\log (1 - D(G(z|y)|y)]}_{\text{生成样本损失}} LD=真实样本损失 Ex,y∼pdata[logD(x∣y)]+生成样本损失 Ez∼pz,y∼plabels[log(1−D(G(z∣y)∣y)]
(2)生成器(Generator)的损失函数
\space 生成器的目标是欺骗判别器
,使其认为生成的图像是真实的(且匹配条件标签 y y y)。
损失函数公式:
L G = E z ∼ p z , y ∼ p labels [ log ( 1 − D ( G ( z ∣ y ) ∣ y ) ] ⏟ 原始形式 或 − E z , y [ log D ( G ( z ∣ y ) ∣ y ) ] ⏟ 改进形式 \mathcal{L}_G = \underbrace{\mathbb{E}_{z \sim p_z, y \sim p_{\text{labels}}}[\log (1 - D(G(z|y)|y)]}_{\text{原始形式}} \quad \text{或} \quad \underbrace{-\mathbb{E}_{z,y}[\log D(G(z|y)|y)]}_{\text{改进形式}} LG=原始形式 Ez∼pz,y∼plabels[log(1−D(G(z∣y)∣y)]或改进形式 −Ez,y[logD(G(z∣y)∣y)]
1.5 条件信息的融合方式
在损失计算中,条件标签通过以下方式参与:
- 生成器输入:噪声
z
和标签y
拼接后输入生成器gen_input = torch.cat([z, label_embed], dim=1) # z: [batch, z_dim], label_embed: [batch, embed_dim]
- 判别器输入:图像和标签拼接后输入判别器
# 图像x: [batch, C, H, W], 标签扩展为 [batch, embed_dim, H, W] label_expanded = label_embed.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W) disc_input = torch.cat([x, label_expanded], dim=1) # 沿通道维度拼接
1.6 与其他GAN变体的对比
损失函数特性 | 标准GAN | 条件GAN(CGAN) | WGAN-GP |
---|---|---|---|
判别器输出 | 概率值 (0~1) | 条件概率值 | 未限制的分数 |
生成器目标 | 欺骗判别器 | 生成符合标签的图像 | 最小化Wasserstein距离 |
梯度稳定性 | 易崩溃 | 依赖条件强度 | 通过梯度惩罚稳定 |
1.7 CGAN的应用
- 图像生成:根据类别标签生成特定类型的图像
- 图像到图像转换:如将语义标签图转换为真实图像
- 文本到图像生成:根据文本描述生成图像
- 数据增强:为特定类别生成额外的训练样本
1.8 改进与变体
- AC-GAN:辅助分类器GAN,在判别器中增加分类任务
- InfoGAN:学习可解释的潜在表示
- StackGAN:分阶段生成高分辨率图像
- ProGAN:渐进式生成高分辨率图像
条件GAN通过引入条件信息
,使得生成过程更加可控,能够生成特定类别的样本,在实际应用中具有广泛的用途。
二、CGAN 实现
2.1 导包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as npimport os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm # 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 设置日志
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) # 生成当前时间格式(例如:2024-03-15_14-30-00)
log_dir = os.path.join("./logs/cgan", time_str) # 设置日志路径,格式如:./logs/cgan/2024-03-15_14-30-00
os.makedirs(log_dir, exist_ok=True) # 自动创建目录
writer = SummaryWriter(log_dir=log_dir) # 初始化 SummaryWriteros.makedirs("./img/cgan_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录
2.2 数据加载和处理
# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,32,32)):transform = transforms.Compose([transforms.Resize((img_shape[1],img_shape[2])),transforms.ToTensor(), # 将图像转换为张量transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到[-1,1]])# 下载训练集和测试集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建 DataLoadertrain_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)return train_loader, test_loader
2.3 构建生成器
class Generator(nn.Module):"""生成器"""def __init__(self, img_shape=(1,32,32),latent_dim=100,num_classes=10,label_embed_dim=10):"""Args:img_shape (int, optional): 生成图片大小,默认CHW=1*32*32latent_dim (int, optional): 潜在噪声向量的维度。默认100维,作为生成器的随机输入种子。 num_classes (int, optional): 类别数量。默认10(例如MNIST的0-9数字分类)。决定标签嵌入矩阵的行数。label_embed_dim (int, optional): 标签嵌入向量的维度。默认10维。将离散标签映射为连续向量的维度,影响条件信息的表达能力。 """super(Generator, self).__init__()# 定义嵌入层 [batch_szie]-> [batch_size,label_embed_dim]=[64,10]self.label_embed = nn.Embedding(num_classes, label_embed_dim) # num_classes 个类别, label_embed_dim 维嵌入# 定义网络块def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))return layers# 定义模型架构self.model = nn.Sequential(*block(latent_dim + label_embed_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))), # [batch_size,1024]-> [batch_size,1*32*32]nn.Tanh() # 输出归一化到[-1,1] )def forward(self, z, labels):# 嵌入标签 [batch_size]-> [batch_size,label_embed_dim]=[64,10]label_embed = self.label_embed(labels)# 拼接嵌入标签和噪声 ->[batch_size,latent_dim + label_embed_dim]=[64,100+10]gen_input = torch.cat([label_embed, z], dim=1)# 生成图片-> [batch_size,C,H,W]=[64,1,32,32]img = self.model(gen_input) # -> [batch_size,C*H*W]=[64,1*32*32]img = img.view(img.shape[0], *img_shape) # [batch_size,C*H*W]-> [batch_size,C,H,W]=[64,1,32,32]return img
2.4 构建判别器
class Discriminator(nn.Module):"""判别器"""def __init__(self, img_shape=(1,32,32),label_embed_dim=10):super(Discriminator, self).__init__()# 定义嵌入层 [batch_szie]-> [batch_size,label_embed_dim]=[64,10]self.label_embed = nn.Embedding(num_classes, label_embed_dim) # num_classes 个类别, label_embed_dim 维嵌入# 定义模型结构self.model = nn.Sequential(nn.Linear(label_embed_dim+ int(np.prod(img_shape)), 512), # [64,10+1*32*32]-> [64,512]nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(512, 1),)def forward(self, img, labels):# 嵌入标签 [batch_size]-> [batch_size,label_embed_dim]=[64,10]label_embed = self.label_embed(labels)# 输入图片展平[64,1,32,32]-> [64,1*32*32]img=img.view(img.shape[0], -1)# 拼接嵌入标签和输入图片 ->[batch_size,label_embed_dim + C*H*W]=[64,10+1*32*32]dis_input = torch.cat([label_embed, img], dim=1)# 进行判定validity = self.model(dis_input)return validity # -> [64,1]
2.5 训练和保存模型
1. 定义保存生成样本
def sample_image(G,n_row, batches_done,latent_dim=100,device=device):"""Saves a grid of generated digits ranging from 0 to n_classes"""# 随机噪声-> [n_row ** 2,latent_dim]=[100,100]z=torch.normal(0,1,size=(n_row ** 2,latent_dim),device=device) #从正态分布中抽样# 条件标签->[100]labels = torch.arange(n_row, dtype=torch.long, device=device).repeat_interleave(n_row)gen_imgs = G(z, labels)save_image(gen_imgs.data, "./img/cgan_mnist/%d.png" % batches_done, nrow=n_row, normalize=True)
2. 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本
img_shape = (1,32,32) # 图片大小
num_classes=10 # 分类数
label_embed_dim=10 # 嵌入维数# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr,betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr,betas=(0.5, 0.999))# 损失函数
loss_fn=nn.BCEWithLogitsLoss()# 开始训练
dis_costs,gen_costs = [],[] # 记录生成器和判别器每次迭代的开销(损失)
start_time = time.time() # 计时器
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):# 进入训练模式G.train()D.train()#记录生成器G和判别器D的总损失(1个 epoch 内)gen_loss_sum,dis_loss_sum=0.0,0.0loop = tqdm(train_loader, desc=f"第{epoch+1}轮")for i, (real_imgs, real_labels) in enumerate(loop):real_imgs=real_imgs.to(device) # [B,C,H,W]real_labels=real_labels.to(device) # [B]# 平滑真假标签,2维[B,1]valid_labels = torch.empty(real_imgs.shape[0], 1, device=device).uniform_(0.9, 1.0).requires_grad_(False) # 替代1.0fake_labels = torch.empty(real_imgs.shape[0], 1, device=device).uniform_(0.0, 0.1).requires_grad_(False) # 替代0.0# -----------------# 训练生成器# -----------------# 获取噪声样本[batch_size,latent_dim]及对应的条件标签 [batch_size]z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device) #从正态分布中抽样gen_labels = torch.randint(0, num_classes, (real_imgs.shape[0],), device=device, dtype=torch.long) # 0~9整数之间,随机抽 real_imgs.shape[0]次# 计算生成器损失gen_imgs=G(z,gen_labels)gen_loss=loss_fn(D(gen_imgs,gen_labels),valid_labels)# 更新生成器参数optimizer_G.zero_grad() #梯度清零gen_loss.backward() #反向传播,计算梯度optimizer_G.step() #更新生成器# -----------------# 训练判别器# -----------------# 计算判别器损失# Step-1:对真实图片损失valid_loss=loss_fn(D(real_imgs,real_labels),valid_labels)# Step-2:对生成图片损失fake_loss=loss_fn(D(gen_imgs.detach(),gen_labels),fake_labels)# Step-3:整体损失dis_loss=(valid_loss+fake_loss)/2.0# 更新判别器参数optimizer_D.zero_grad() #梯度清零dis_loss.backward() #反向传播,计算梯度optimizer_D.step() #更新判断器 # 对生成器和判别器每次迭代的损失进行累加gen_loss_sum+=gen_lossdis_loss_sum+=dis_lossgen_costs.append(gen_loss.item())dis_costs.append(dis_loss.item())# 每 sample_interval 次迭代保存生成样本batches_done = epoch * loader_len + iif batches_done % sample_interval == 0:sample_image(G=G,n_row=10, batches_done=batches_done)# 更新进度条loop.set_postfix(mean_gen_loss=f"{gen_loss_sum/(loop.n + 1):.8f}",mean_dis_loss=f"{dis_loss_sum/(loop.n + 1):.8f}")writer.add_scalars(main_tag="Train Losses", tag_scalar_dict={"Generator": gen_loss,"Discriminator": dis_loss,},global_step=batches_done # X轴坐标)
writer.close()
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/CGAN_G.pth")
torch.save(D.state_dict(), "./model/CGAN_D.pth")
2.6 绘制训练损失
# 创建画布
plt.figure(figsize=(10, 5))
ax1 = plt.subplot(1, 1, 1)# 绘制曲线
ax1.plot(range(len(gen_costs)), gen_costs, label='Generator loss', linewidth=2)
ax1.plot(range(len(dis_costs)), dis_costs, label='Discriminator loss', linewidth=2)ax1.set_xlabel('Iterations', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('CGAN Training Loss', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, linestyle='--', alpha=0.6)ax2 = ax1.twiny() # 创建共享Y轴的新X轴
newlabel = list(range(epochs+1)) # Epoch标签 [0,1,2,...]
iter_per_epoch = len(train_loader) # 每个epoch的iteration次数
newpos = [e*iter_per_epoch for e in newlabel] # 计算Epoch对应的iteration位置ax2.set_xticks(newpos[::10])
ax2.set_xticklabels(newlabel[::10]) ax2.xaxis.set_ticks_position('bottom')
ax2.xaxis.set_label_position('bottom')
ax2.spines['bottom'].set_position(('outward', 45)) # 坐标轴下移45点
ax2.set_xlabel('Epochs') # 设置轴标签
ax2.set_xlim(ax1.get_xlim()) # 与主X轴范围同步plt.tight_layout()
plt.savefig('cgan_loss.png', dpi=300)
plt.show()


2.7 图片转GIF
from PIL import Imagedef create_gif(img_dir="./img/cgan_mnist", output_file="./img/cgan_mnist/cgan_figure.gif", duration=100):images = []img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]img_paths_sorted = sorted(img_paths,key=lambda x: (int(x.split('.')[0]), # (如 400.png 的 400)))for img_file in img_paths_sorted:img = Image.open(os.path.join(img_dir, img_file))images.append(img)images[0].save(output_file, save_all=True, append_images=images[1:], duration=duration, loop=0)print(f"GIF已保存至 {output_file}")
create_gif()

2.8 模型加载和生成
#载入训练好的模型
G = Generator() # 定义模型结构
G.load_state_dict(torch.load("./model/CGAN_G.pth",weights_only=True,map_location=device)) # 加载保存的参数
G.to(device) # 将模型移动到设备(GPU 或 CPU)
G.eval() # 将模型设置为评估模式# 获取噪声样本[10,100]及对应的条件标签 [10]
z=torch.normal(0,1,size=(10,100),device=device) #从正态分布中抽样
gen_labels = torch.arange(10, dtype=torch.long, device=device) #0~9整数#生成假样本
gen_imgs=G(z,gen_labels).view(-1,32,32) # 4维->3维
gen_imgs=gen_imgs.detach().cpu().numpy()# #绘制
plt.figure(figsize=(3, 2))
for i in range(10):plt.subplot(2, 5, i + 1) plt.xticks([], []) plt.yticks([], []) plt.imshow(gen_imgs[i], cmap='gray') plt.title(f"Figure {i}", fontsize=5)
plt.tight_layout()
plt.show()
