在kaggle中用GPU使用CGAN生成指定mnist手写数字

文章目录

  • 1项目介绍
  • 2参考文章
  • 3代码的实现过程及对代码的详细解析
    • 独热编码
    • 定义生成器
    • 定义判别器
    • 打印我们的引导信息
    • 模型训练
    • 迭代过程中生成的图片
    • 损失函数的变化
  • 4总结
  • 5 模型相关的文件

1项目介绍

在GAN的基础上进行有条件的引导生成图片cgan

2参考文章

GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

3代码的实现过程及对代码的详细解析


import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)import os
for dirname, _, filenames in os.walk('/kaggle/input'):for filename in filenames:print(os.path.join(dirname, filename))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image

独热编码


# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):return torch.eye(class_count)[x, :]  # 切片选取,第一维选取第x个,第二维全要

torch.eye(10)函数的作用是生成一个10*10的对角矩阵
该函数的作用是得到第x个位置为1的独热编码,如果传入为列表,则得到一个矩阵
在这里插入图片描述

 
transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])#minist数据集中的图片数据的维度是[batch_size, 1, 28, 28],其中batch_size是每个批次的图像数量。这个数据集中的每个图像都是28x28像素的灰度图像,因此它们只有一个通道
dataset = torchvision.datasets.MNIST('data',train=True,transform=transform,target_transform=one_hot,download=True)
#这里target_transform参数的作用是对标签进行转换。在这个例子中,它的作用是将标签转换为one-hot编码。
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)

定义生成器


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()#因此,这个函数的输入张量维度为[batch_size, 10]和[batch_size, 100],输出张量维度为[batch_size, 1, 1, 1]。self.linear1 = nn.Linear(10, 128 * 7 * 7)self.bn1 = nn.BatchNorm1d(128 * 7 * 7)self.linear2 = nn.Linear(100, 128 * 7 * 7)self.bn2 = nn.BatchNorm1d(128 * 7 * 7)#这个函数的作用是将一个输入张量进行反卷积操作,得到一个输出张量。#nn.ConvTranspose2d函数的作用是将一个256通道的输入张量转换为一个128通道的输出张量,使用3x3的卷积核进行卷积操作,并在卷积操作后进行1像素的paddingself.deconv1 = nn.ConvTranspose2d(256, 128,kernel_size=(3, 3),padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64,kernel_size=(4, 4),stride=2,padding=1)self.bn4 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 1,kernel_size=(4, 4),stride=2,padding=1)def forward(self, x1, x2):x1 = F.relu(self.linear1(x1))x1 = self.bn1(x1)x1 = x1.view(-1, 128, 7, 7)x2 = F.relu(self.linear2(x2))x2 = self.bn2(x2)x2 = x2.view(-1, 128, 7, 7)#将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量x = torch.cat([x1, x2], axis=1)x = F.relu(self.deconv1(x))#形状变为为[64, 128, 7, 7]的张量x = self.bn3(x)x = F.relu(self.deconv2(x))#形状变为为[64, 64, 14, 14]的张量x = self.bn4(x)# 形状变为为[64, 1, 28, 28]的张量x = torch.tanh(self.deconv3(x))return x

生成器对数据的处理过程:
这个函数对于输入张量[64, 1, 28, 28]的维度变化过程如下:
输入张量维度为[64, 1, 28, 28]
经过线性变换和ReLU激活函数处理后,得到两个形状为[64, 128 * 7 * 7]的张量
将两个张量分别通过BatchNorm1d进行归一化处理
将两个处理后的结果reshape成形状为[64, 128, 7, 7]的张量
将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量
经过反卷积操作得到输出张量,维度为[64, 1, 28, 28]

定义判别器


# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.linear = nn.Linear(10, 1*28*28)self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值def forward(self, x1, x2):#leak_relu激活函数:它在输入小于0时返回一个小的斜率,而在输入大于等于0时返回输入本身x1 =F.leaky_relu(self.linear(x1))x1 = x1.view(-1, 1, 28, 28)#torch.cat([x1, x2], axis=1)函数将张量x1和张量x2沿着第二个维度(即列)拼接起来x = torch.cat([x1, x2], axis=1)#处理过后变为(64,2,28,28)x = F.dropout2d(F.leaky_relu(self.conv1(x)))#维度变为(64,64,13,13)x = F.dropout2d(F.leaky_relu(self.conv2(x)))#维度变为(64,128,6,6)x = self.bn(x)x = x.view(-1, 128*6*6)#最后键位了64*1(同时把值映射到0~1之间)x = torch.sigmoid(self.fc(x))return x
# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)# 损失计算函数
loss_function = torch.nn.BCELoss()# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):#生成器生成取片,label_input为输入的引导信息,noise_input为随机的噪声点predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())#numpy.squeeze()函数的作用是去掉矩阵里维度为1的维度。fig = plt.figure(figsize=(4, 4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i + 1)plt.imshow((predictions[i] + 1) / 2, cmap='gray')plt.axis("off")from IPython.display import FileLinkplt.savefig('data/img/image_at_epoch_{:04d}.png'.format(epoch))plt.show()
import os 
os.makedirs("data/img")

打印我们的引导信息

noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)
tensor([1, 3, 5, 4, 9, 3, 0, 0, 1, 3, 4, 5, 9, 2, 3, 7])

模型训练

D_loss = []
G_loss = []


# 训练循环
for epoch in range(150):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader.dataset)# 对全部的数据集做一次迭代#dataloader中的图像是四维的。在for循环中,每次迭代会返回一个batch_size大小的数据#其中每个数据都是一个四维张量,形状为[batch_size, channels, height, width]for step, (img, label) in enumerate(dataloader):img = img.to(device)label = label.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device)d_optim.zero_grad()real_output = dis(label, img)d_real_loss = loss_function(real_output,torch.ones_like(real_output, device=device))#torch.ones_like(real_output, device=device)函数的作用是生成一个与real_output形状相同的张量,其中所有元素都为1。                         d_real_loss.backward() #求解梯度# 得到判别器在生成图像上的损失gen_img = gen(label,random_noise)fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果d_fake_loss = loss_function(fake_output,torch.zeros_like(fake_output, device=device))d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optim.step()  # 优化# 得到生成器的损失g_optim.zero_grad()fake_output = dis(label, gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output, device=device))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_loss.item()g_epoch_loss += g_loss.item()with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)if epoch % 10 == 0:print('Epoch:', epoch)generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)print("epoch:{}/150".format(epoch))plt.plot(D_loss, label='D_loss')
plt.plot(G_loss, label='G_loss')
plt.legend()
plt.show()

迭代过程中生成的图片

迭代1次
在这里插入图片描述
迭代10次
在这里插入图片描述
迭代20次
在这里插入图片描述

迭代30次
在这里插入图片描述

迭代40次
在这里插入图片描述

迭代150次
在这里插入图片描述

损失函数的变化

在这里插入图片描述

4总结

cGAN相比于GAN而言,将label的信息通过一系列的卷积操作和图像的信息融合在一起,然后放进模型进行训练,让我们的模型能和label相匹配的图像,从而在我们给出制定的数字label时能够生成对应的数字图片,实现了引导的过程。

5 模型相关的文件

模型的相关文件:提取码(ujki)

本模型是放在kaggle中运行的,kaggle的部署流程请参考:在kaggle中用GPU训练模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/56987.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go 自学:切片slices

以下代码展示了两种建立slice的方法。 我们可以使用sort函数给slice排序。 package mainimport ("fmt""sort" )func main() {var fruitList []string{"Apple", "Tomato", "Peach"}fmt.Printf("Type of fruitlist is…

Django基础6——数据模型关系

文章目录 一、基本了解二、一对一关系三、一对多关系3.1 增删改查3.2 案例:应用详情页3.2 案例:新建应用页 四、多对多关系4.1 增删改查4.2 案例:应用详情页4.3 案例:部署应用页 一、基本了解 常见数据模型关系: 一对一…

好用的网页制作工具就是这6个,快点来看!

对于网页设计师来说,好用的网页设计工具是非常重要的,今天本文收集了6个好用的网页设计工具供设计师自由挑选使用。在这6个好用的网页设计工具的帮助下,设计师将获得更高的工作效率和更精致的网页设计效果,接下来,就一…

【Java alibabahutool】JSON、Map、实体对象间的相互转换

首先要知道三者的互转关系&#xff0c;可以先将JSON理解成是String类型。这篇博文主要是记录阿里巴巴的JSONObject的两个方法。toJSONString()以及parseObject()方法。顺便巩固Map与实体对象的转换技巧。 引入依赖 <!-- 阿里巴巴 JSON转换 以下二选一即可 没有去细研究两者…

Centos 7.6 安装mongodb

以下是在CentOS 7.6上安装MongoDB的步骤&#xff1a; 打开终端并以root用户身份登录系统。 创建一个新的MongoDB存储库文件 /etc/yum.repos.d/mongodb-org-4.4.repo 并编辑它。 sudo vi /etc/yum.repos.d/mongodb-org-4.4.repo在编辑器中&#xff0c;添加下面的内容到文件中并…

什么是字体图标(Icon Font)?如何在网页中使用字体图标?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 字体图标&#xff08;Icon Font&#xff09;⭐ 如何在网页中使用字体图标⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&a…

基于MATLAB/Simulink的三相并网逆变器dq阻抗建模及扫频仿真

目录 整体系统介绍理论模型MATLAB实现 基于Simulink的阻抗扫频仿真整体思路注意事项流程框图 其他 本文主要介绍三相并网逆变器dq阻抗建模的相关知识&#xff0c;和大家分享一下怎么使用MATLAB/Simulink来进行理论模型的搭建以及如何通过扫频获取阻抗模型&#xff0c;一方面是给…

Node基础--npm相关内容

下面,我们一起来看看Node中的至关重要的一个知识点-----npm 1.npm概述 npm(Node Package Manager),CommonJS包规范是理论,npm是其中一种实践。 对于Node而言,NPM帮助其完成了第三方模块的发布、安装和依赖等。借助npm,Node与第三方模块之间形成了很好的一个 生态系统。(类…

Netty简易聊天室

文章目录 本文目的参考说明环境说明maven依赖日志配置单元测试 功能介绍开发步骤 本文目的 通过一个简易的聊天室案例&#xff0c;讲述Netty的基本使用。同时分享案例代码。项目中用到了log4j2&#xff0c;junit5&#xff0c;同时分享这些基础组件的使用。项目中用到了awt&…

Rust踩雷笔记(5)——刷点链表的题(涉及智能指针Box,持续更新)

目录 leetcode 2 两数相加——模式匹配单链表Box 只能说Rust链表题的画风和C完全不一样&#xff0c;作为新手一时间还不太适应&#xff0c;于是单独为链表、智能指针开一篇&#xff0c;主要记录leetcode相关题型的答案以及注意事项。 leetcode 2 两数相加——模式匹配单链表Bo…

小程序开发之登录授权

小程序开发登录授权流程 看懂这张图登录授权就没问题了&#xff08;哈哈哈哈哈&#xff09; 说明&#xff1a; 调用 wx.login() 获取 临时登录凭证code &#xff0c;并回传到开发者服务器。 调用 auth.code2Session 接口&#xff0c;换取 用户唯一标识 OpenID 和 会话密钥 sess…

【C++】Cmake使用教程(看这一篇就够了)

文章目录 引言一 环境搭建二 简单入门2.1 项目结构2.2 示例源码2.3 运行查看 三 编译多个源文件3.1 在同一个目录下有多个源文件3.1.1 简单版本3.1.1.1 项目结构3.1.1.2 示例代码3.1.1.3 运行查看 3.1.2 进阶版本3.1.2.1 项目结构3.1.2.2 示例源码3.1.2.3 运行查看 3.2 在不同目…

TabView 初始化与自定义 TabBar 属性相关

SWift TabView 与 UIKit 中的 UITabBarController 如出一辙.在 TabView 组件中配置对应的图片和标题; 其中,Tag 用来设置不同 TabView 可动态设置当前可见 Tab;另也有一些常用的属性与 UIKit 中的类似,具体可以按需参考 api 中属性进行单独修改定制; 在 iOS 15.0 之后还可设置角…

系列十一、AOP

一、概述 1.1、官网 AOP的中文名称是面向切面编程或者面向方面编程&#xff0c;利用AOP可以对业务逻辑的各个部分进行隔离&#xff0c;从而使得业务逻辑各部分之间的耦合度降低&#xff0c;提高程序的可重用性&#xff0c;同时提高了开发的效率。 1.2、通俗描述 不通过…

[JDK8环境下的HashMap类应用及源码分析] capacity实验

🌹作者主页:青花锁 🌹简介:Java领域优质创作者🏆、Java微服务架构公号作者😄、CSDN博客专家 🌹简历模板、学习资料、面试题库、技术互助 🌹文末获取联系方式 📝 系列文章目录 [Java基础] StringBuffer 和 StringBuilder 类应用及源码分析 [Java基础] 数组应用…

横扫“盲区”、“看透”缺陷,维视智造推出短波红外相机

在可见光领域&#xff0c;工业相机的视觉应用已经十分成熟&#xff0c;但在日常的客户咨询中&#xff0c;我们也经常接到一些“超纲需求”——客户想要检测“白底上的白色缺陷”、“不透明包装内的透明物体有无”等&#xff0c;均属于可见光无法实现的检测&#xff0c;而市面上…

API接口开发管理平台--多领域企业数字化管理的解决方案

随着数字化时代的到来&#xff0c;企业需要进行数字化转型才能更好地适应市场需求和用户需求。而API接口则是数字化转型中的重要组成部分&#xff0c;可以帮助企业更好地管理信息&#xff0c;提高效率。本文将介绍一种解决方案--API接口开发管理平台&#xff0c;该平台开发出多…

无线 CBTC 与互联互通

1 无线 CBTC 的技术与经济优势 由于无线 CB TC 可采用移动闭塞的制式&#xff0c;列车能以较小的间隔运行&#xff0c;可使运 营商实现“ 高密度”的运营模式&#xff0c;这使系统可在同样满足客运需求的基 础上&#xff0c;缩短旅客的候车时间&#xff0c;缩小站台长度和候…

微信小程序开发教学系列(13)- 小程序支付与电商功能

13. 小程序支付与电商功能 小程序支付与电商功能是微信小程序中非常重要的一部分&#xff0c;通过支付功能可以实现用户购买商品、支付订单等功能。下面将介绍小程序支付的接入和配置以及实现电商功能和订单管理的方法。 13.1 小程序支付的接入和配置 13.1.1 获取微信支付商…

Pygame编程(1)初始化和退出模块

初始化和退出模块 pygame使用基础流程 初始化模块设置主屏窗口程序主循环&#xff08;处理键盘、鼠标、游戏杆、触摸屏等事件&#xff09;退出模块终止程序 import sys import pygame from pygame.locals import *# 1.初始化模块 pygame.init()# 2.设置主屏窗口 display pyg…