生成式AI系列 —— DCGAN生成手写数字

1、模型构建

1.1 构建生成器

# 导入软件包
import torch
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Generator, self).__init__()self.layer1 = nn.Sequential(nn.ConvTranspose2d(z_dim, image_size * 32,kernel_size=4, stride=1),nn.BatchNorm2d(image_size * 32),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(nn.ConvTranspose2d(image_size * 32, image_size * 16,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 16),nn.ReLU(inplace=True))self.layer3 = nn.Sequential(nn.ConvTranspose2d(image_size * 16, image_size * 8,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 8),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.ConvTranspose2d(image_size * 8, image_size *4,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 4),nn.ReLU(inplace=True))self.layer5 = nn.Sequential(nn.ConvTranspose2d(image_size * 4, image_size * 2,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 2),nn.ReLU(inplace=True))self.layer6 = nn.Sequential(nn.ConvTranspose2d(image_size * 2, image_size,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size),nn.ReLU(inplace=True))self.last = nn.Sequential(nn.ConvTranspose2d(image_size, 3, kernel_size=4,stride=2, padding=1),nn.Tanh())# 注意:因为是黑白图像,所以只有一个输出通道def forward(self, z):out = self.layer1(z)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":import matplotlib.pyplot as pltG = Generator(z_dim=20, image_size=256)# 输入的随机数input_z = torch.randn(1, 20)# 将张量尺寸变形为(1,20,1,1)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)#输出假图像fake_images = G(input_z)print(fake_images.shape)img_transformed = fake_images[0].detach().numpy().transpose(1, 2, 0)plt.imshow(img_transformed)plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M0oWDbXr-1692468683782)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\Figure_1.png)]

1.1 构建判别器

class Discriminator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Discriminator, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, image_size, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))#注意:因为是黑白图像,所以输入通道只有一个self.layer2 = nn.Sequential(nn.Conv2d(image_size, image_size*2, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer3 = nn.Sequential(nn.Conv2d(image_size*2, image_size*4, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(image_size*4, image_size*8, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer5 = nn.Sequential(nn.Conv2d(image_size*8, image_size*16, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer6 = nn.Sequential(nn.Conv2d(image_size*16, image_size*32, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.last = nn.Conv2d(image_size*32, 1, kernel_size=4, stride=1)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":#确认程序执行D = Discriminator(z_dim=20, image_size=64)#生成伪造图像input_z = torch.randn(1, 20)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)#将伪造的图像输入判别器D中d_out = D(fake_images)#将输出值d_out乘以Sigmoid函数,将其转换成0~1的值print(torch.sigmoid(d_out))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SEGvF8xh-1692468683784)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230817224333376.png)]

2、数据集构建

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
import time
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nnfrom torchvision import transforms
from model.DCGAN import Generator, Discriminator
from matplotlib import pyplot as pltdef make_datapath_list(root):"""创建用于学习和验证的图像数据及标注数据的文件路径列表。 """train_img_list = list() #保存图像文件的路径for img_idx in range(200):img_path = f"{root}/img_7_{str(img_idx)}.jpg"train_img_list.append(img_path)img_path = f"{root}/img_8_{str(img_idx)}.jpg"train_img_list.append(img_path)return train_img_listclass ImageTransform:"""图像的预处理类"""def __init__(self, mean, std):self.data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])def __call__(self, img):return self.data_transform(img)class GAN_Img_Dataset(data.Dataset):"""图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""def __init__(self, file_list, transform):self.file_list = file_listself.transform = transformdef __len__(self):'''返回图像的张数'''return len(self.file_list)def __getitem__(self, index):'''获取经过预处理后的图像的张量格式的数据'''img_path = self.file_list[index]img = Image.open(img_path)  # [ 高度 ][ 宽度 ] 黑白# 图像的预处理img_transformed = self.transform(img)return img_transformed# 创建DataLoader并确认执行结果# 创建文件列表
root = "./img_78"
train_img_list = make_datapath_list(root)# 创建Dataset
mean = (0.5)
std = (0.5)
train_dataset = GAN_Img_Dataset(file_list=train_img_list, transform=ImageTransform(mean, std)
)# 创建DataLoader
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True
)# 确认执行结果
batch_iterator = iter(train_dataloader)  # 转换为迭代器
imges = next(batch_iterator)  # 取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

数据请在访问链接获取:

3、train接口实现


def train_model(G, D, dataloader, num_epochs):# 确认是否能够使用GPU加速device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("使用设备:", device)# 设置最优化算法g_lr, d_lr = 0.0001, 0.0004beta1, beta2 = 0.0, 0.9g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])# 定义误差函数criterion = nn.BCEWithLogitsLoss(reduction='mean')# 使用硬编码的参数z_dim = 20mini_batch_size = 8# 将网络载入GPU中G.to(device)D.to(device)G.train()  # 将模式设置为训练模式D.train()  # 将模式设置为训练模式# 如果网络相对固定,则开启加速torch.backends.cudnn.benchmark = True# 图像张数num_train_imgs = len(dataloader.dataset)batch_size = dataloader.batch_size# 设置迭代计数器iteration = 1logs = []# epoch循环for epoch in range(num_epochs):# 保存开始时间t_epoch_start = time.time()epoch_g_loss = 0.0  # epoch的损失总和epoch_d_loss = 0.0  # epoch的损失总和print('-------------')print('Epoch {}/{}'.format(epoch, num_epochs))print('-------------')print('(train)')# 以minibatch为单位从数据加载器中读取数据的循环for imges in dataloader:# --------------------# 1.判别器D的学习# --------------------# 如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免if imges.size()[0] == 1:continue# 如果能使用GPU,则将数据送入GPU中imges = imges.to(device)# 创建正确答案标签和伪造数据标签# 在epoch最后的迭代中,小批次的数量会减少mini_batch_size = imges.size()[0]label_real = torch.full((mini_batch_size,), 1).to(device)label_fake = torch.full((mini_batch_size,), 0).to(device)# 对真正的图像进行判定d_out_real = D(imges)# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差d_loss_real = criterion(d_out_real.view(-1), label_real.to(torch.float))d_loss_fake = criterion(d_out_fake.view(-1), label_fake.to(torch.float))d_loss = d_loss_real + d_loss_fake# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# --------------------# 2.生成器G的学习# --------------------# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差g_loss = criterion(d_out_fake.view(-1), label_real.to(torch.float))# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# --------------------# 3.记录结果# --------------------epoch_d_loss += d_loss.item()epoch_g_loss += g_loss.item()iteration += 1# epoch的每个phase的loss和准确率t_epoch_finish = time.time()print('-------------')print('epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(epoch, epoch_d_loss / batch_size, epoch_g_loss / batch_size))print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))t_epoch_start = time.time()return G, D

4、训练


G = Generator(z_dim=20, image_size=64)
D = Discriminator(z_dim=20, image_size=64)
# 定义误差函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
num_epochs = 200
G_update, D_update = train_model(G, D, dataloader=train_dataloader, num_epochs=num_epochs
)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EvqY6h3G-1692468683786)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230820020234209.png)]

5、测试

# 将生成的图像和训练数据可视化
# 反复执行本单元中的代码,直到生成感觉良好的图像为止device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)# 生成图像
fake_images = G_update(fixed_z.to(device))# 训练数据
imges = next(iter(train_dataloader))  # 取出位于第一位的元素# 输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):# 将训练数据放入上层plt.subplot(2, 5, i + 1)plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')# 将生成数据放入下层plt.subplot(2, 5, 5 + i + 1)plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/44064.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM——StringTable面试案例+垃圾回收+性能调优+直接内存

JVM——引言JVM内存结构_北岭山脚鼠鼠的博客-CSDN博客 书接上回内存结构——方法区。 这里常量池是运行时常量池。 方法区 面试题 intern()方法 intern() 方法用于在运行时将字符串添加到内部的字符串池stringtable中,并返回字符串池stringtable中的引用。 返…

计算机组成与设计 Patterson Hennessy 笔记(二)MIPS 指令集

计算机的语言:汇编指令集 也就是指令集。本书主要介绍 MIPS 指令集。 汇编指令 算数运算: add a,b,c # abc sub a,b,c # ab-cMIPS 汇编的注释是 # 号。 由于MIPS中寄存器大小32位,是基本访问单位,因此也被称为一个字 word。M…

输入输出+暴力模拟入门:魔法之树、染色の树、矩阵、字母加密、玫瑰鸭

秋招实习刷题网站推荐&#xff1a;codefun2000.com&#xff0c;还有题解博客&#xff1a;blog.codefun2000.com/。以下内容都是来自塔子哥的~ 输入输出 2023.04.15-春招-第三题-魔法之树 //#include<bits/stdc.h> #include<vector> #include<iostream>usin…

Datawhale Django后端开发入门Task01 Vscode配置环境

首先呢放一张运行成功的截图纪念一下&#xff0c;感谢众多小伙伴的帮助呀&#xff0c;之前没有配置这方面的经验 &#xff0c;但还是一步一步配置成功了&#xff0c;所以在此以一个纯小白的经验分享如何配置成功。 1.选择要建立项目的文件夹&#xff0c;打开文件找到目标文件夹…

csapp archlab PartC满分解答

任务 修改ncopy.ys和pipe-full.hcl以尽可能的提高ncopy.ys的运行速度 思路 pipe-full.hcl&#xff1a; 实现iaddq指令&#xff08;家庭作业4.54&#xff09;实现加载转发&#xff08;家庭作业4.57&#xff09; ncopy.ys&#xff1a; 使用循环展开&#xff08;第5.8节&…

openai多模态大模型:clip详解及使用

引言 CLIP全称Constrastive Language-Image Pre-training&#xff0c;是OpenAI推出的采用对比学习的文本-图像预训练模型。CLIP惊艳之处在于架构非常简洁且效果好到难以置信&#xff0c;在zero-shot文本-图像检索&#xff0c;zero-shot图像分类&#xff0c;文本→图像生成任务…

windows服务器下java程序健康检测及假死崩溃后自动重启应用、开机自动启动

前两天由于项目需要&#xff0c;一个windows上的批处理任务&#xff08;kitchen.bat&#xff09;&#xff0c;需要接到mq的消息通知后执行&#xff0c;为了快速实现这里我们通过springboot写了一个jar程序&#xff0c;用于接收mq的消息&#xff0c;并调用bat文件。 本程序需要实…

缺少或找不到vcruntime140_1.dll的解决方法

某天&#xff0c;当我准备打开电脑上的一个应用程序时&#xff0c;突然收到一个错误提示&#xff0c;显示缺少了vcruntime140_1.dll文件。这个文件是一个重要的系统组件&#xff0c;它的丢失导致了我无法正常运行该应用程序。于是&#xff0c;我开始了一场寻找和修复旅程。然而…

这是我的纪念日

我的1460天 写这篇文章&#xff0c;是为了纪念自己这一千多个日日夜夜&#xff0c;我的热爱总算有了回报。 每次看到有小伙伴点赞&#xff0c;评论的时候&#xff0c;我都很开心&#xff0c;我知道自己的选择是正确的&#xff0c;我喜欢分享自己的所见所学&#xff0c;我也很…

python3.7 安装pywin32报错,完美解决方法

本机环境 python&#xff1a;3.7 遇到2种报错 第一种 ImportError: DLL load failed: The specified module could not be found.第二种&#xff1a; import win32gui ModuleNotFoundError: No module named ‘win32gui‘解决方法 我安装pywin32时候&#xff0c;是直接pi…

stm32红绿灯源代码示例(附带Proteus电路图)

本代码不能直接用于红路灯&#xff0c;只是提供一个思路 #include "main.h" #include "gpio.h" void SystemClock_Config(void); void MX_GPIO_Init(void) {GPIO_InitTypeDef GPIO_InitStruct {0};/* GPIO Ports Clock Enable */__HAL_RCC_GPIOB_CLK_ENAB…

Jenkins-CICD-python/Java包升级与回退

Jenkins- CICD流水线 python/Java代码升级与回退 1、执行思路 1.1、代码升级 jenkins上点击 upgrade和 代码版本号 --${tag} jenkins 推送 代码 和 执行脚本 到目标服务器/opt目录下 执行命令 sh run.sh 代码名称 版本号 upgrade 版本号 来自jenkins的 构建参数中的 标签…

Gitlab-第四天-CD到k8s集群的坑

一、.gitlab-ci.yml #CD到k8s集群的 stages: - deploy-test build-image-deploy-test: stage: deploy-test image: bitnami/kubectl:latest # 使用一个包含 kubectl 工具的镜像 tags: - k8s script: - ls -al - kubectl apply -f deployment.yaml # 根据实际情况替换…

LlamaGPT -基于Llama 2的自托管类chatgpt聊天机器人

LlamaGPT一个自托管、离线、类似 ChatGPT 的聊天机器人&#xff0c;由 Llama 2 提供支持。100% 私密&#xff0c;不会有任何数据离开你的设备。 推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 1、如何安装LlamaGPT LlamaGPT可以安装在任何x86或arm64系统上。 首先确保…

PHP8的字符串操作3-PHP8知识详解

今天继续分享字符串的操作&#xff0c;前面说到了字符串的去除空格和特殊字符&#xff0c;获取字符串的长度&#xff0c;截取字符串、检索字符串。 今天继续分享字符串的其他操作。如&#xff1a;替换字符串、分割和合成字符串。 5、替换字符串 替换字符串就是对指定字符串中…

vue浏览器插件安装-各种问题

方法1&#xff1a;vue.js devtolls插件下载 https://blog.csdn.net/qq_55640378/article/details/131553642 下载地址&#xff1a; Tags vuejs/devtools GitHub npm install 或是 cnpm install 遇到的报错 设置淘宝镜像源&#xff08;推荐使用nrm&#xff0c;这一步是为…

使用IText导出复杂pdf

1、问题描述 需要将发票导出成pdf&#xff0c;要求每页都必须包含发票信息和表头行。 2、解决方法 使用IText工具实现PDF导出 IText8文档&#xff1a;Examples (itextpdf.com) 3、我的代码 引入Itext依赖&#xff0c;我这里用的是8.0.1版本 <dependency><groupId>…

uniapp 上传比较大的视频文件就超时

uni.uploadFile&#xff0c;上传超过10兆左右的文件就报错err&#xff1a;uploadFile:fail timeout&#xff0c;超时 解决&#xff1a; 在manifest.json文件中做超时配置 uni.uploadFile({url: this.action,method: "POST",header: {Authorization: uni.getStorage…

Python编程——列表解析与常用操作

作者&#xff1a;Insist-- 个人主页&#xff1a;insist--个人主页 本文专栏&#xff1a;Python专栏 专栏介绍&#xff1a;本专栏为免费专栏&#xff0c;并且会持续更新python基础知识&#xff0c;欢迎各位订阅关注。 目录 一、列表是什么&#xff1f; 二、列表的特点 1、元素…

pyltp 0.2.1安装

1. LTP及pyltp pyltp是 LTP的 Python封装&#xff0c;它里面提供了包括分词&#xff0c;词性标注&#xff0c;命名实体识别&#xff0c;句法分析等等能力。 比较坑的是我们可能无法直接通过pip install pyltp0.2.1方式来安装&#xff0c;所以本文就简单记录下如何通过源码安装…