G9 - ACGAN理论与实战

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 环境
  • 步骤
    • 环境设置
    • 数据准备
    • 工具方法
    • 模型设计
    • 模型训练
    • 模型效果展示
  • 总结与心得体会


上周已经简单的了解了ACGAN的原理,并且不经实践的编写了部分代码,这周复现一下真正的ACGAN

环境

Pytorch: 2.3.1+cu121
Nvidia GTX 4090

步骤

环境设置

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as npdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 全局参数
n_epochs = 200
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 8
latent_dim = 100
n_classes = 10
img_size = 32
channels = 1
sample_interval = 400

数据准备

# 创建中间采样图片的文件夹
import os
os.makedirs('images', exist_ok=True)
# 配置数据集
os.makedirs('data/mnist', exist_ok=True)
dataloader = DataLoader(datasets.MNIST('data/mnist',train=True,download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]),),batch_size=batch_size,shuffle=True,
)

工具方法

# 权重初始化函数
def weights_init_normal(m):classname = m.__class__.__name__if classname.find('Conv') != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm2d') != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)# 日志函数 因为使用了jupyter notebook环境,长时间的任务日志无法直接查看,于是需要打印到文件
import logging
import sys
import datetimedef init_logger(filename, logger_name):'''@brief:initialize logger that redirect info to a file just in case we lost connection to the notebook@params:filename: to which file should we log all the infologger_name: an alias to the logger'''# get current timestamptimestamp = datetime.datetime.utcnow().strftime('%Y%m%d_%H-%M-%S')logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',handlers=[logging.FileHandler(filename=filename),logging.StreamHandler(sys.stdout)])# Testlogger = logging.getLogger(logger_name)logger.info('### Init. Logger {} ###'.format(logger_name))return logger# Initialize
my_logger = init_logger("./ml_notebook.log", "ml_logger")# 生成函数的结果保存
def sample_image(n_row, batches_done):"""保存从0到n_classes的生成数字的图像风格"""# 采样噪声z = torch.randn((n_row**2, latent_dim), device=device)# 为n行生成标签从0到n_classeslabels = torch.tensor([num for _ in range(n_row) for num in range(n_row)], device=device)gen_imgs = generator(z, labels)save_image(gen_imgs.data.cpu(), 'images/%d.png' % batches_done, nrow=n_row, normalize=True)

模型设计

# 生成器
class Generator(nn.Module):def __init__(self):super().__init__()# 标签嵌入self.label_emb = nn.Embedding(n_classes, latent_dim)# 计算上采样前的初始大小self.init_size = img_size // 4# 第一层线性层self.l1 = nn.Sequential(nn.Linear(latent_dim, 128*self.init_size**2))# 卷积层self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):# 标签嵌入到噪声中gen_input = torch.mul(self.label_emb(labels), noise)# 通过第一层线性层out = self.l1(gen_input)# 整形out = out.view(out.shape[0], 128, self.init_size, self.init_size)# 卷积生成图像img = self.conv_blocks(out)return img
# 判别器
class Discriminator(nn.Module):def __init__(self):super().__init__()# 判别器块生成函数def discriminator_block(in_filters, out_filters, bn=True):"""返回每个判别器层"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return block# 卷积层self.conv_blocks = nn.Sequential(*discriminator_block(channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样后,图像的宽高ds_size = img_size // 2 ** 4# 输出层self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# 模型初始化# 损失函数
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()# 初始化生成器和判别器
generator = Generator().to(device)
discriminator = Discriminator().to(device)# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

模型训练

# 训练# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))for epoch in range(n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 图像是 真实的 标签valid = torch.ones((batch_size, 1), requires_grad=False, device=device)# 图像是 生成的 标签fake = torch.zeros((batch_size, 1), requires_grad=False, device=device)real_imgs = imgs.to(device)labels = labels.to(device)# 训练生成器optimizer_G.zero_grad()# 采样噪声和标签作为生成器的输入z = torch.randn((batch_size, latent_dim), device=device)gen_labels = torch.randint(0, 1, (batch_size,), device=device)# 生成一批图像gen_imgs = generator(z, gen_labels)# 损失度量 生成器欺骗判别器的能力validity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# 训练判别器optimizer_D.zero_grad()# 真实图像的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = 0.5 * (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels))# 生成图像的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = 0.5 * (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels))# 判别器的总损失d_loss = 0.5 * (d_real_loss + d_fake_loss)# 计算判别器的准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()if i % 100 == 0:my_logger.info("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" % (epoch, n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)

训练过程

模型效果展示

刚开始训练
训练到最后

总结与心得体会

通过对模型的复现,发现我之前对判别器的理解有偏差,如果在判别器的输入中插入分类信息,等于是将答案直接给了判别器,生成的结果反而不会太好。还有一个和我预想的不一样的地方,在生成器中,将标签嵌入到特征向量使用了矩阵乘法,而没有直接使用concatenate操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/41959.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python如何设计窗口

PyQt是一个基于Qt的接口包,可以直接拖拽控件设计UI界面,下面我简单介绍一下这个包的安装和使用,感兴趣的朋友可以自己尝试一下: 1、首先,安装PyQt模块,这个直接在cmd窗口输入命令“pip install pyqt5”就行…

中金女员工离世悲剧:职场压力、心理健康与社会支持的深刻反思

中金女员工离世背后的深思 2024年7月1日,一则令人痛心的消息在金融界乃至整个网络迅速传播:中金公司一位年仅30岁的女员工郑某露,在降薪和房贷的双重压力下,不幸离世。这一事件不仅让她的家人和朋友陷入了深深的悲痛之中,也引发了社会各界对职场环境、个体心理健康以及社…

使用块的网络 VGG

一、AlexNet与VGG 1、深度学习追求更深更大,使用VGG将卷积层组合为块 2、VGG块:3*3卷积(pad1,n层,m通道)、2*2最大池化层 二、VGG架构 1、多个VGG块后接全连接层 2、不同次数的重复块得到不同的架构&a…

工作手机怎么做好业务员工作微信的监控管理

什么是工作手机管理系统? 工作手机管理系统是专为企业管理设计的员工微信管理,它通过监控通讯记录、保障数据安全、自动检测敏感行为、永久保留客户信息等功能,帮助企业提升销售效率、维护客户资源安全,并确保业务流程的合规性。…

NASA和IBM推出INDUS:高级科学研究的综合大模型

在最近的一项研究中,来自美国宇航局和IBM的一组研究人员合作开发了一种模型,该模型可应用于地球科学,天文学,物理学,天体物理学,太阳物理学,行星科学和生物学以及其他多学科学科。当前的模型&am…

DP:二维费用背包问题

文章目录 🎵二维费用背包问题🎶引言🎶问题定义🎶动态规划思想🎶状态定义和状态转移方程🎶初始条件和边界情况 🎵例题🎶1.一和零🎶2.盈利计划 🎵总结 &#x1…

机器人具身智能Embodied AI

强调智能体(如机器人)通过物理身体在物理世界中的实时感知、交互和学习来执行任务。 通过物理交互来完成任务的智能系统。它由“本体”(即物理身体)和“智能体”(即智能核心)耦合而成,能够在复…

taoCMS v3.0.2 任意文件读取漏洞(CVE-2022-23316)

前言 CVE-2022-23316 是一个影响 taoCMS v3.0.2 的漏洞。这个漏洞允许攻击者通过 admin.php?actionfile&ctrldownload&path../../1.txt 的路径读取任意文件。攻击者可以利用该漏洞读取服务器上的任何文件,只要他们知道文件的路径​ (OpenCVE)​​ (Tenabl…

亚马逊跟卖ERP的自动调价功能,能够简易地批量设置价格规则。

跟卖的智能调价 跟卖智能调价简单说是可以上调,下调就是怎么说?上调就是它根靠根据市场最低的价格情况进行去上调。 然后添加指定条件,到工具栏找到指定条件,点击添加指定条件。 然后选择店铺,比如选择店铺&#xf…

微信⼩程序的电影推荐系统-计算机毕业设计源码76756

摘 要 随着互联网的普及和移动互联网的发展,人们对于获取信息的便捷性和高效性要求越来越高。电影作为一种受众广泛喜爱的娱乐方式,电影推荐系统的出现为用户提供了更加个性化和精准的电影推荐服务。微信小程序作为一种轻量级应用形式,在用户…

算法题-回文子串和最长回文子序列

算法题-回文子串和最长回文子序列 一、647. 回文子串二、516. 最长回文子序列 一、647. 回文子串 中等 给你一个字符串 s ,请你统计并返回这个字符串中 回文子串 的数目。 回文字符串 是正着读和倒过来读一样的字符串。 子字符串 是字符串中的由连续字符组成的一个…

qt 如果把像素点数据变成一个图片

1.概要 图像的本质是什么&#xff0c;就是一个个的像素点&#xff0c;对与显示器来说就是一个二维数组。无论多复杂的图片&#xff0c;对于显示器来说就是一个二维数组。 2.代码 #include "widget.h"#include <QApplication> #include <QImage> #incl…

Java对象通用比对工具

目录 背景 思路 实现 背景 前段时间的任务中&#xff0c;遇到了需要识别两个对象不同属性的场景&#xff0c;如果使用传统的一个个属性比对equals方法&#xff0c;会存在大量的重复工作&#xff0c;而且为对象新增了属性后&#xff0c;比对方法也需要同步修改&#xff0c;不方…

node的下载、安装、配置和使用(node.js下载安装和配置、npm命令汇总、cnpm的使用)

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。 愿将腰下剑,直为斩楼兰。 ——《塞下曲》 文章目录 一、node.js的下载、安装和配置1. node.js下…

集智书童 | 英伟达和斯坦福基于 Transformer 的异常检测最新研究!

本文来源公众号“集智书童”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;英伟达和斯坦福基于 Transformer 的异常检测最新研究&#xff01; 在作者推动各种视觉任务性能边界的同时&#xff0c;模型的大小也在相应增长。为了跟上…

电商视角如何理解动态IP与静态IP

在电子商务的蓬勃发展中&#xff0c;网络基础设施的稳定性和安全性是至关重要的。其中&#xff0c;IP地址作为网络设备间通信的基础&#xff0c;扮演着举足轻重的角色。从电商的视角出发&#xff0c;我们可以将动态IP和静态IP比作电商平台上不同类型的店铺安排&#xff0c;以此…

华为ENSP防火墙+路由器+交换机的常规配置

(防火墙区域DHCP基于接口DHCP中继服务器区域有线区域无线区域&#xff09;配置 一、适用场景&#xff1a; 1、普通企业级网络无冗余网络环境&#xff0c;防火墙作为边界安全设备&#xff0c;分trust&#xff08;内部网络信任区域&#xff09;、untrust&#xff08;外部网络非信…

vulnhub靶场之Jarbas

1 信息收集 1.1 主机发现 arp-scan -l 发现主机IP地址为&#xff1a;192.168.1.16 1.2 端口发现 nmap -sS -sV -A -T5 -p- 192.168.1.16 存在端口22&#xff0c;80&#xff0c;3306&#xff0c;8080 1.3 目录扫描 dirsearch -u 192.168.1.16 2 端口访问 2.1 80端口 2.2…

LRU缓存算法设计

LRU 缓存算法的核⼼数据结构就是哈希链表&#xff0c;双向链表和哈希表的结合体。这个数据结构⻓这样&#xff1a; 创建的需要有两个方法&#xff0c;一个是get方法&#xff0c;一个是put方法。 一些问题&#xff1a;为什么需要使用双向链表呢&#xff1f;因为删除链表的本身&…

[单master节点k8s部署]20.监控系统构建(五)Alertmanager

prometheus将监控到的异常事件发送给Alertmanager&#xff0c;然后Alertmanager将报警信息发送到邮箱等设备。可以从下图看出&#xff0c;push alerts是由Prometheus发起的。 安装Alertmanager config文件 [rootmaster prometheus]# cat alertmanager-cm.yaml kind: ConfigMa…