人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用。生成对抗网络(GAN)是一种强大的生成模型,在手写数字生成方面具有广泛的应用前景。通过生成逼真的手写数字图像,GAN可以用于数据增强、图像修复、风格迁移等任务,提高模型的性能和泛化能力。生成对抗网络在手写数字生成领域具有广泛的应用前景。主要应用场景包括数据增强、图像修复、风格迁移和跨领域生成。数据增强可以通过生成逼真的手写数字图像,为训练数据集提供更多的样本,提高模型的泛化能力。

一、项目背景

随着深度学习技术的不断发展,生成模型在计算机视觉、自然语言处理等领域取得了显著的成果。生成对抗网络(GAN)作为一种新兴的生成模型,近年来备受关注。在手写数字生成方面,GAN可以生成逼真的手写数字图像,为数据增强、图像修复等任务提供有力支持。

二、生成对抗网络原理

生成对抗网络(GAN)由Goodfellow等人于2014年提出,它由两个神经网络——生成器(Generator)和判别器(Discriminator)——组成。生成器的目标是生成逼真的假样本,而判别器的目标是区分真实样本和生成器生成的假样本。在训练过程中,生成器和判别器相互竞争,不断调整参数,以达到纳什均衡。
GAN的目标是最小化以下价值函数:
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
其中, G G G表示生成器, D D D表示判别器, x x x表示真实样本, z z z表示生成器的输入噪声, p data p_{\text{data}} pdata表示真实数据分布, p z p_z pz表示噪声分布。
在这里插入图片描述

三、生成对抗网络应用场景

生成对抗网络(GAN)在手写数字生成领域的应用具有广泛的前景。以下是几个主要的应用场景:
1.数据增强:通过生成逼真的手写数字图像,GAN可以为训练数据集提供更多的样本,提高模型的泛化能力。
2. 图像修复:GAN可以用于修复损坏或缺失的手写数字图像,提高图像的质量和可读性。
3. 风格迁移:GAN可以将一种手写风格转换为另一种风格,为个性化手写数字生成提供可能。
4. 跨领域生成:GAN可以实现不同手写数字数据集之间的转换,为多任务学习提供支持。

四、生成对抗网络实现手写数字生成

下面我将利用pytorch深度学习框架构建生成对抗网络的生成器模型Generator、判别器模型Discriminator。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image# 超参数设置
batch_size = 128
learning_rate = 0.0002
num_epochs = 80# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 下载并加载训练数据
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)# 定义生成器模型
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(100, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 28*28),nn.Tanh())def forward(self, x):return self.model(x).view(x.size(0), 1, 28, 28)# 定义判别器模型
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(28*28, 1024),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(x.size(0), -1)return self.model(x)# 初始化模型
generator = Generator()
discriminator = Discriminator()# 损失函数和优化器
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=learning_rate)
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for i, (images, _) in enumerate(train_loader):# 确保标签的大小与当前批次的数据大小一致real_labels = torch.ones(images.size(0), 1)fake_labels = torch.zeros(images.size(0), 1)# 训练判别器optimizerD.zero_grad()real_outputs = discriminator(images)d_loss_real = criterion(real_outputs, real_labels)z = torch.randn(images.size(0), 100)fake_images = generator(z)fake_outputs = discriminator(fake_images.detach())d_loss_fake = criterion(fake_outputs, fake_labels)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizerD.step()# 训练生成器optimizerG.zero_grad()fake_images = generator(z)fake_outputs = discriminator(fake_images)g_loss = criterion(fake_outputs, real_labels)g_loss.backward()optimizerG.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')# 保存生成器生成的图片save_image(fake_images.data[:25], './fake_images/fake_images-{}.png'.format(epoch+1), nrow=5, normalize=True)# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

最后我们打开fake_images/文件夹,可以看到生成手写图片的过程:
在这里插入图片描述

五、总结

本项目利用生成对抗网络(GAN)实现了手写数字的生成。通过训练生成器和判别器,我们成功生成了逼真的手写数字图像。这些生成的图像可以应用于数据增强、图像修复、风格迁移等领域,为手写数字识别等相关任务提供有力支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/657486.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

类和对象 第六部分 继承 第一部分:继承的语法

一.继承的概念 继承是面向对象的三大特性之一 有些类与类之间存在特殊的关系,例如下图: 我们可以发现,下级别的成员除了拥有上一级的共性,还有自己的特性,这个时候,我们可以讨论利用继承的技术,…

【前端素材】bootstrap3 实现地产置业公司source网页设计

一、需求分析 地产置业公司的网页通常是该公司的官方网站,旨在向访问者提供相关信息和服务。这些网页通常具有以下功能: 公司介绍:网页通常包含有关公司背景、历史、核心价值观和使命等方面的信息。此部分帮助访问者了解公司的身份和目标。 …

LabVIEW船舶自动识别系统

在现代航海领域,安全高效的船舶自动识别系统对于保障航行安全和提高船舶管理效率非常重要。介绍了利用LabVIEW软件开发的一个船舶自动识别系统,该系统通过先进的数据采集和信号处理技术,显著提升了传统自动识别系统的性能。 这个船舶自动识别…

代理IP在游戏中的作用有哪些?

游戏代理IP的作用是什么?IP代理软件相当于连接客户端和虚拟服务器的软件“中转站”,在我们向远程服务器提出需求后,代理服务器首先获得用户的请求,然后将服务请求转移到远程服务器,然后将远程服务器反馈的结果转移到客…

【lesson1】高并发内存池项目介绍

文章目录 这个项目做的是什么?这个项目的要求的知识储备和难度?什么是内存池池化技术内存池内存池主要解决的问题malloc 这个项目做的是什么? 当前项目是实现一个高并发的内存池,他的原型是google的一个开源项目tcmalloc&#xf…

江西省考报名照不能成功上传的原因

江西省考报名照需要根据以下要求生成: 1、近期6个月,免冠证件照 2、照片背景白背景 3、照片文件jpg格式,20-100kb 4、照片像素大小,295x413像素 5、照片必须使用审核工具审核后才能上传

【Java】内存溢出和内存泄露的区别

目录 概念 内存溢出分类 内存泄漏分类 发生场景以及解决方法 内存溢出 内存泄漏 解决方法 这道题是面试常考的,一定要区分好区别,我之前就是直接认为内存溢出就是内存泄漏了 概念 内存溢出:是指程序在申请内存时,没有足够…

React18-模拟列表数据实现基础表格功能

文章目录 分页功能分页组件有两种接口参数分页类型用户列表参数类型 模拟列表数据分页触发方式实现目录 分页功能 分页组件有两种 table组件自带分页 <TableborderedrowKey"userId"rowSelection{{ type: checkbox }}pagination{{position: [bottomRight],pageSi…

海外云手机对于亚马逊卖家的作用

近年来&#xff0c;海外云手机作为一种新型模式迅速崭露头角&#xff0c;成为专业的出海SaaS平台软件。海外云手机在云端运行和存储数据&#xff0c;通过网页端操作&#xff0c;将手机芯片放置在机房&#xff0c;通过网络连接到服务器&#xff0c;为用户提供便捷的上网功能。因…

Spring Boot 中文件上传

Spring Boot 中文件上传 一、MultipartFile二、单文件上传案例三、多文件上传案例四、Servlet 规范五、Servlet 规范实现文件上传 上传文件大家用的最多的就是 Apache Commons FileUpload&#xff0c;这个库使用非常广泛。Spring Boot3 版本中已经不能使用了。代替它的是 Sprin…

苍穹外卖项目可以写的简历和如何优化简历

文章目录 重点写中规写添加自己个性的项目面试会问道的问题 我是一名双非大二计算机本科生&#xff0c;希望我的分享对你有帮助&#xff0c;点赞关注不迷路。 简历编写一直是很多人求职人的心病&#xff0c;我自己上学期有一门课程是去校内企业面试&#xff0c;当时我就感受出…

MySQL库表操作 作业

题目&#xff1a; 1. sql语句分为几类?2. 表的约束有哪些,分别是什么,设置的语法分别是什么?3. 做出班级表,学生表的E-R图,数据库模型图,以及核心的sql语句. 1. MySQL致力于支持全套ANSI/ISO SQL标准。在MySQL数据库中&#xff0c;SQL语句主要可以划分为以下几类: > DD…

【三维重建】三角化

三角化要解决的问题是&#xff1a; 已知两个相机的内参K、K、相机之间的旋转平移矩阵R、t以及匹配点p、p&#xff0c;如何求得P点的三维坐标&#xff1f; 线性解法

彻底解决 MAC Android Studio gradle async 时出现 “connect timed out“ 问题

最近在编译一个比较老的项目&#xff0c;git clone 之后使用 async 之后出现一下现象&#xff1a; 首先确定是我网络本身是没有问题的&#xff0c;尝试几次重新 async 之后还是出现问题&#xff0c;网上找了一些方法解决了本问题&#xff0c;以此来记录一下问题是如何解决的。 …

【Tomcat与网络5】再论Tomcat的工作过程与两种经典的设计模式

前面两篇&#xff0c;我们重点分析了Tomcat的容器和连接器的基本设计&#xff0c;今天我们来看一下两个机构如何在service的调度下进行协同工作的。 目录 1.模板模式与Tomcat的重用性设计 2.观察者模式与Tomcat可扩展性设计 1.模板模式与Tomcat的重用性设计 首先&#xff0…

油分离器的介绍

压缩机的排气中带有冷冻机油&#xff0c;这些冷冻机油如果随制冷剂蒸汽进入冷凝器、蒸发器后将 在传热表面形成油膜&#xff0c;从而影响换热效果。因此通常在压缩机与冷凝器之间装设油分离器&#xff0c;用 来分离制冷剂蒸汽中挟带的冷冻机油。在氟利昂制冷系统中&#xff0c;…

读AI3.0笔记10_读后总结与感想兼导读

1. 基本信息 AI 3.0 (美)梅拉妮米歇尔 著 四川科学技术出版社,2021年2月出版 1.1. 读薄率 书籍总字数355千字&#xff0c;笔记总字数33830字。 读薄率33830355000≈9.53% 1.2. 读厚方向 千脑智能 脑机穿越 未来呼啸而来 虚拟人 新机器人 如何创造可信的AI 新机器智…

sklearn 计算 tfidf 得到每个词分数

from sklearn.feature_extraction.text import TfidfVectorizer# 语料库 可以换为其它同样形式的单词 corpus [list(range(-5, 5)),list(range(-6,4)),list(range(12)),list(range(13))]# corpus [ # [Two, wrongs, don\t, make, a, right, .], # [The, pen, is, might…

CGAL5.4.1 边塌陷算法

目录 1、使用曲面网格的示例 2、使用默认多面体的示例 3、使用丰富多面体的示例 主要对1、使用曲面网格的示例 进行深度研究 CGAL编译与安装CGAL安装到验证到深入_cgal测试代码-CSDN博客 参考资料CGAL 5.4.5 - Triangulated Surface Mesh Simplification: User Manual …

云原生 k8s 可能使用到的端口整理【不定期更新】

k8s 因为涉及到的组件太多了&#xff0c;所以端口有很多&#xff0c;这里整理了日常所接触的接口&#xff0c;后续有新的再更新。 如果是通过公网 IP 进行安装的时候需要根据实际情况有选择的进行放开&#xff1b;一般只有云厂商会提供公网 IP 访问&#xff0c;自建的话不建议 …