【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)

关于

 

近年来,基于卷积网络(CNN)的监督学习已经 在计算机视觉应用中得到了广泛的采用。相比之下,无监督 使用 CNN 进行学习受到的关注较少。在这项工作中,我们希望能有所帮助 缩小了 CNN 在监督学习和无监督学习方面的成功之间的差距。我们介绍一类称为深度卷积生成的 CNN 对抗性网络(DCGAN),具有一定的架构限制,以及 证明他们是无监督学习的有力候选人。训练 在各种图像数据集上,我们展示了令人信服的证据,表明我们的深度卷积对抗对学习了从对象部分到 生成器和鉴别器中的场景。此外,我们使用学到的 新任务的特征 - 证明它们作为一般图像表示的适用性。(https://arxiv.org/pdf/1511.06434.pdf)

工具

 数据集

方法实现

加载必要的库函数和自定义函数

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Ffrom torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
def get_sample_image(G, n_noise):"""save sample 100 images"""z = torch.randn(100, n_noise).to(DEVICE)y_hat = G(z).view(100, 28, 28) # (100, 28, 28)result = y_hat.cpu().data.numpy()img = np.zeros([280, 280])for j in range(10):img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)return img

定义判别模型

class Discriminator(nn.Module):"""Convolutional Discriminator for MNIST"""def __init__(self, in_channel=1, num_classes=1):super(Discriminator, self).__init__()self.conv = nn.Sequential(# 28 -> 14nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),# 14 -> 7nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),# 7 -> 4nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.AvgPool2d(4),)self.fc = nn.Sequential(# reshape input, 128 -> 1nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x, y=None):y_ = self.conv(x)y_ = y_.view(y_.size(0), -1)y_ = self.fc(y_)return y_

定义生成模型

class Generator(nn.Module):"""Convolutional Generator for MNIST"""def __init__(self, input_size=100, num_classes=784):super(Generator, self).__init__()self.fc = nn.Sequential(nn.Linear(input_size, 4*4*512),nn.ReLU(),)self.conv = nn.Sequential(# input: 4 by 4, output: 7 by 7nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(),# input: 7 by 7, output: 14 by 14nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(),# input: 14 by 14, output: 28 by 28nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),nn.Tanh(),)def forward(self, x, y=None):x = x.view(x.size(0), -1)y_ = self.fc(x)y_ = y_.view(y_.size(0), 512, 4, 4)y_ = self.conv(y_)return y_

 模型超参数定义配置

batch_size = 64criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))max_epoch = 30 # need more than 20 epochs for training generator
step = 0
n_critic = 1 # for training more k steps about Discriminator
n_noise = 100D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake

 模型训练

for epoch in range(max_epoch):for idx, (images, labels) in enumerate(data_loader):# Training Discriminatorx = images.to(DEVICE)x_outputs = D(x)D_x_loss = criterion(x_outputs, D_labels)z = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))D_z_loss = criterion(z_outputs, D_fakes)D_loss = D_x_loss + D_z_lossD.zero_grad()D_loss.backward()D_opt.step()if step % n_critic == 0:# Training Generatorz = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))G_loss = criterion(z_outputs, D_labels)D.zero_grad()G.zero_grad()G_loss.backward()G_opt.step()if step % 500 == 0:print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))if step % 1000 == 0:G.eval()img = get_sample_image(G, n_noise)imsave('./{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')G.train()step += 1

测试生成效果

# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')

 

模型和状态参量保存

def save_checkpoint(state, file_name='checkpoint.pth.tar'):torch.save(state, file_name)# Saving params.
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')

应用

DCGAN作为一个成熟的生成模型,在自然图像,医学图像,医学电生理信号数据分析中,都可以用来实现数据的合成,达到数据增强的目的,同时,如何减少增强数据对于后端任务的不利干扰,也是一个需要关注的方面。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/777246.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

aws使用记录

数据传输(S3) 安装命令行 安装awscli: https://docs.aws.amazon.com/zh_cn/cli/latest/userguide/getting-started-install.html#getting-started-install-instructions 直到 aws configure list 可以运行 身份验证: 运行: aws config…

字符计数.

题目描述 给定一个单词,请计算这个单词中有多少个元音字母,多少个辅音字母。 元音字母包括 a,e,i,o,u,共五个,其他均为辅音字母。 输入描述 输入格式 输入一行,包含一个单词,!单词中只包含小写英文字母。单…

【QGIS从shp文件中筛选目标区域导出为shp】

文章目录 1、写在前面2、QGIS将shp文件中目标区域输出为shp2.1、手动点选2.2、高级过滤 3、上述shp完成后,配合python的shp文件,即可凸显研究区域了 1、写在前面 利用shp文件制作研究区域mask,Matlab版本,请点击 Matlab利用shp文…

网络编程综合项目-多用户通信系统

文章目录 1.项目所用技术栈本项目使用了java基础,面向对象,集合,泛型,IO流,多线程,Tcp字节流编程的技术 2.通信系统整体分析主要思路(自己理解)1.如果不用多线程2.使用多线程3.对多线…

uniapp-Form示例(uviewPlus)

示例说明 Vue版本&#xff1a;vue3 组件&#xff1a;uviewPlus&#xff08;Form 表单 | uview-plus 3.0 - 全面兼容nvue的uni-app生态框架 - uni-app UI框架&#xff09; 说明&#xff1a;表单组建、表单验证、提交验证等&#xff1b; 截图&#xff1a; 示例代码 <templat…

3月28日,每日信息差

&#x1f384; 亚马逊未来15年或斥资1480亿美元在全球建设并运营数据中心 &#x1f30d; 中国移动正式商用5.5G 网络&#xff0c;Find X7系列成为首款5.5G手机 &#x1f30b; UC网盘官宣上传、下载不限速 &#x1f381; 优酷发布行业首款影视制作车 ✨ 阿里云联发科联手为手机芯…

O2OA(翱途)开发平台-快速入门开发一个门户实例

O2OA(翱途)开发平台[下称O2OA开发平台或者O2OA]拥有门户页面定制与集成的能力&#xff0c;平台通过门户定制&#xff0c;可以根据企业的文化&#xff0c;业务需要设计符合企业需要的统一信息门户&#xff0c;系统首页等UI界面。本篇主要介绍通过门户管理系统如何快速的进行一个…

第一次开发微信小程序的总结与心得

我们的小程序开发团队有三个人——两个前端和一个后端&#xff0c;一个没有产品经理和UI设计师的队伍&#xff0c;一个小程序开发零经验的队伍&#xff0c;却需要完成需求分析、页面设计、代码编写、功能测试、小程序上线的整个过程&#xff0c;所以在开发过程中&#xff0c;我…

团队建设与管理案例分析题

习题一 最近A公司接了一个信息系统运维项目&#xff0c;而且非常重视&#xff0c;任命了有丰富售后服务张某为系统规划与管理师&#xff0c;全权授权张某负责该项目&#xff0c;并要求他负责企业运维服务能力建设和提升。张某也学习了大量项目管理知识和运维管理知识&#xff…

刷题有一个疑问

public class Main {public static void main(String[] args) { // boolean a zhengchu(2,10); // System.out.print(a); // boolean a btrue(126); // System.out.print(a);int decimalNumber 10; // 十进制数String binaryString Integer.toS…

学点儿Java_Day12_IO流

1 IO介绍以及分类 IO: Input Output 流是一组有顺序的&#xff0c;有起点和终点的字节集合&#xff0c;是对数据传输的总称或抽象。即数据在两设备间的传输称为流&#xff0c;流的本质是数据传输&#xff0c;根据数据传输特性将流抽象为各种类&#xff0c;方便更直观的进行数据…

C++取经之路(其二)——含数重载,引用。

含数重载: 函数重载是指&#xff1a;在c中&#xff0c;在同一作用域&#xff0c;函数名相同&#xff0c;形参列表不相同(参数个数&#xff0c;或类型&#xff0c;或顺序)不同&#xff0c;C语言不支持。 举几个例子&#xff1a; 1.参数类型不同 int Add(int left, int right)…

【任职资格】某大型制造型企业任职资格体系项目纪实

该企业以业绩、责任、能力为导向&#xff0c;确定了分层分类的整体薪酬模式&#xff0c;但是每一名员工到底应该拿多少工资&#xff0c;同一个岗位的人员是否应该拿同样的工资是管理人员比较头疼的事情。华恒智信顾问认为&#xff0c;通过任职资格评价能实现真正的人岗匹配&…

基于Transformer的医学图像分类研究

医学图像分类目前面临的挑战 医学图像分类需要研究人员同时具备医学图像分析和数字图像的知识背景。由于图像尺度、数据格式和数据类别分布的影响&#xff0c;现有的模型方法&#xff0c;如传统的机器学习的识别方法和基于深度卷积神经网络的方法&#xff0c;取得的识别准确度…

微软AI 程序员AutoDev,自主执行工程任务生成代码

全球首个 AI 程序员 Devin 的横空出世&#xff0c;可能成为软件和 AI 发展史上一个重要的节点。它掌握了全栈的技能&#xff0c;不仅可以写代码 debug&#xff0c;训模型&#xff0c;还可以去美国最大求职网站 Upwork 上抢单。 Devin 诞生之后&#xff0c;让码农纷纷恐慌。最近…

银河麒麟 v10 sp2 aarch64架构制作openssh 9.7p1 rpm包(显示openssl版本信息)—— 筑梦之路

【国产化适配】银河麒麟v10 sp2 aarch64 制作openssh 9.6p1 rpm——筑梦之路_openssh 9.6ky10-CSDN博客 之前做过openssh 9.6 p1 rpm包&#xff0c;使用的是官方的spec文件&#xff0c;没有修改过&#xff0c;不过最新版9.7已经默认不使用openssl&#xff0c;因此制作出来的rp…

详解 WebWorker 的概念、使用场景、示例

前言 提到 WebWorker,可能有些小伙伴比较陌生,不知道是做什么的,甚至不知道使用场景,今天这篇文章就带大家一起简单了解一下什么是 webworker! 概念 WebWorker 实际上是运行在浏览器后台的一个单独的线程,因此可以执行一些耗时的操作而不会阻塞主线程。WebWorker 通过…

智慧光伏:企业无纸化办公

随着科技的快速发展&#xff0c;光伏技术不仅成为推动绿色能源革命的重要力量&#xff0c;更在企业办公环境中扮演起引领无纸化办公的重要角色。智慧光伏不仅为企业提供了清洁、可持续的能源&#xff0c;更通过智能化的管理方式&#xff0c;推动企业向无纸化办公转型&#xff0…

题目:古典问题:有一对兔子,从出生后第3个月起每个月都生一对兔子,小兔子长到第三个月 后每个月又生一对兔子,假如兔子都不死,问每个月的兔子总数为多少?

题目&#xff1a;古典问题&#xff1a;有一对兔子&#xff0c;从出生后第3个月起每个月都生一对兔子&#xff0c;小兔子长到第三个月    后每个月又生一对兔子&#xff0c;假如兔子都不死&#xff0c;问每个月的兔子总数为多少&#xff1f; There is no nutrition in the bl…

Linux小程序: 手写自己的shell

注意&#xff1a; 本文章只是为了理解shell内部的工作原理&#xff0c; 所以并没有完成shell的所有工作&#xff0c; 只是完成了shell里的一小部分工作 #include <stdio.h> #include <unistd.h> #include <string.h> #include <stdlib.h> #include &l…