pytorch实战-图像生成与对抗

1 概述

what:给定一句话,或一些要求,按要求生成需要的图像。

本篇总结主要包含反卷积和GAN(generative adversial network, GAN)

2 反卷积与图像生成

what:反卷积可以看成卷积的反操作,但不完全一样,不是把卷积反过来就是反卷积。即给定特征,反向生成输入。但反卷积运算的卷积核与卷积运算的不同

效果:卷积是大图像越来越小,反卷积可以图像越来越大

2.1 反卷积运算

卷积核不同:卷积卷积核旋转180度可得到反卷积运算的卷积核

padding:如果希望反卷积运算后,图像大小保持不变,需要计算padding并给输入图像补padding

2.2 反池化运算

反池化有很多方法,有一种卷积运算方法可以近似省略池化(因为效果相近),即给卷积运算加步伐。即每一个卷积核在原图像运算完,朝下一个运算窗口移动的步数。默认步数是1.步数大于1的效果很接近卷积+池化运算效果。这样的卷积运算,可以看成步数为1的卷积运算+池化运算,即省略了池化运算

步伐>2的卷积效果:卷积得到的图像比步伐小的图像更小。因此反卷积时,也需要处理此种情况

2.3 反卷积和分数步伐

步伐>2的卷积,可以通过分数步伐的反卷积恢复。即对输入图像每个像素点之间补充空白点,卷积步长越大,反卷积补的像素间空白点就越多

2.4 批正则化技术

概念:是每一层神经网络层和非线性运算层之间加入的一个线性运算层,逻辑为y=ax+b。a,b为要学习的参数,x为一批里归一化处理后的输入:(x-mean(x))/std

3 图像生成-最小均方差模型

3.1 思路

输入是一个数字,输出是一个数字的手写图像。通过反卷积网络实现这样的输入与输出

3.2 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fimport torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utilimport matplotlib.pyplot as pyplot
import numpy as np
import osoutput_img_size = 28
input_dim = 100
channel_num = 1
features_num = 64
batch_size = 64print(f'prepare datasets begin')
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
itype = torch.cuda.LongTensor if use_cuda else torch.LongTensortrain_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
index_verify = range(len(test_dataset))[:5000]
index_test = range(len(test_dataset))[5000:]sampler_verify = torch.utils.data.sampler.SubsetRandomSampler(index_verify)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(index_test)verify_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_verify)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_test)class AntiCNN(nn.Module):def __init__(self):super(AntiCNN, self).__init__()self.model = nn.Sequential()self.model.add_module('deconv1', nn.ConvTranspose2d(input_dim, features_num * 2, 5, 2, 0, bias=False))self.model.add_module('batch_norm1', nn.BatchNorm2d(features_num * 2))self.model.add_module('relu1', nn.ReLU(True))self.model.add_module('deconv2', nn.ConvTranspose2d(features_num * 2, features_num, 5, 2, 0, bias=False))self.model.add_module('batch_norm2', nn.BatchNorm2d(features_num))self.model.add_module('relu2', nn.ReLU(True))self.model.add_module('deconv3', nn.ConvTranspose2d(features_num, channel_num, 4, 2, 0, bias=False))self.model.add_module('sigmoid', nn.Sigmoid())def forward(self, input):output = inputfor _, module in self.model.named_children():output = module(output)return outputdef weight_init(module):class_name = module.__class__.__name__if class_name.find('conv') != -1:module.weight.data.normal_(0, 0.02) # convey mean and stdif class_name.find('norm') != -1:module.weight.data.normal_(1, 0.02)def resize_to_img(img):return img.data.expand(batch_size, 3, output_img_size, output_img_size)def imgshow(input, title=None):if input.size()[0] > 1:input = input.numpy().transpose((1, 2, 0))else:input = input[0].numpy()min_val, max_val = np.amin(input), np.amax(input)if max_val > min_val:input = (input - min_val) / (max_val - min_val)pyplot.imshow(input)if title:pyplot.title(title)pyplot.pause(0.001)def main():net = AntiCNN()net = net.cuda() if use_cuda else netcriterion = nn.MSELoss()optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)samples = np.random.choice(10, batch_size)samples = torch.from_numpy(samples).type(dtype)step = 0num_epoch = 2record = []print('train begin')for epoch in range(num_epoch):print(f'the no.{epoch} epoch')train_loss = []for batch_index, (data, target) in enumerate(train_loader):target, data = data.clone().detach().requires_grad_(True), target.clone().detach()#target, data = target.cuda(), data.cuda() if use_cuda else target, dataif use_cuda:target, data = target.cuda(), data.cuda()data = data.type(dtype)data = data.resize(data.size()[0], 1, 1, 1)data = data.expand(data.size()[0], input_dim, 1, 1)net.train()output = net(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()step += 1loss = loss.cpu() if use_cuda else losstrain_loss.append(loss.data.numpy())if batch_index % 300 == 0:net.eval()verify_loss = []index = 0for data, target in verify_loader:target, data = data.clone().detach().requires_grad_(True), target.clone().detach()index += 1# target, data = target.cuda(), data.cuda() if use_cuda else target, dataif use_cuda:target, data = target.cuda(), data.cuda()data = data.type(dtype)data = data.resize(data.size()[0], 1, 1, 1)data = data.expand(data.size()[0], input_dim, 1, 1)output = net(data)loss = criterion(output, target)loss = loss.cpu() if use_cuda else lossverify_loss.append(loss.data.numpy())print(f'now no.{batch_index} batch. train loss:{np.mean(train_loss):.4f}, verify loss:{np.mean(verify_loss):.4f}')record.append([np.mean(train_loss), np.mean(verify_loss)])with torch.no_grad():samples.resize_(batch_size, 1, 1, 1)samples = samples.data.expand(batch_size, input_dim, 1, 1)# samples = samples.cuda() if use_cuda else samplesif use_cuda:samples = samples.cuda()fake_u = net(samples)# fake_u = fake_u.cuda() if use_cuda else fake_uif use_cuda:fake_u = fake_u.cuda()img = resize_to_img(fake_u)os.makedirs(os.path.realpath('./pytorch/jizhi/image_generate/temp1'), exist_ok=True)util.save_image(img, os.path.realpath(f'./pytorch/jizhi/image_generate/temp1/fake{epoch}.png'))pyplot.show()if __name__ == '__main__':main()

发现图片很模糊,可能是均方误差算的是所有手写数字的平均值,且每个图像没有明显模式,倒是平均值就是很模糊。咋整呢?可以尝试用之前的手写数字图像识别器帮助矫正MSE

4 图像生成-生成器-识别器模型

5 图像生成-GAN

6 小结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/647818.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

紫光展锐T760_芯片性能介绍_展锐T760安卓核心板定制

展锐T760核心板是一款基于国产5G芯片的智能模块,采用紫光展锐T760制程工艺为台积电6nm工艺,支持工艺具有出色的能效表现。其采用主流的44架构的八核设计,包括4颗2.2GHz A76核心和4颗A55核心设计,内存单元板载可达8GB Ram256GB ROM…

uniapp vuecli项目融合[小记]:将多个项目融合,打包成一个小程序/App,拆分多个H5应用

前言: 目前两个uniapp vuecli开发的项目【A、B】,新规划的项目C:需要融合项目B 80%的功能模块,同时也需要涵盖项目A的所有功能模块。 应用需求: 1、新项目C【小程序】可支持切换到应用A/C界面【内部通过初始化、路由跳…

0125-1-vue3初体验

vue3尝鲜体验 初始化 安装vue/clinext: yarn global add vue/clinext # OR npm install -g vue/clinext然后在 Vue 项目运行: vue upgrade --next项目目录 vue3-template ├── index.html // html模板 ├── mock // mock数据 │ └── user.…

qt学习:QListWidget控件+自定义条目项+双击删除+单击获取

目录 图片 头函数 接口 显示案例 方法1 方法2 方法3 方法4 自定义 方法5 在方法4上实现 图片 头函数 #include <QListWidgetItem> 接口 //不怎么常用void addItem(const QString &label)void addItems(const QStringList &labels) //自定义条目项…

Redis客户端之Redisson(二)Redisson分布式锁

一、原理&#xff1a; Redisson并没有通过setNx命令来实现加锁&#xff0c;而是基于 Redis 看⻔狗机制&#xff0c;自己实现了一套分布式锁逻辑。 1、加锁机制&#xff1a; 二、使用方法&#xff1a;

EasyExcel实现下载模板

实体类&#xff1a; package com.aicut.monitor.domain;import com.alibaba.excel.annotation.ExcelIgnore; import com.alibaba.excel.annotation.ExcelIgnoreUnannotated; import com.alibaba.excel.annotation.ExcelProperty; import com.alibaba.excel.annotation.write.s…

YOLOv8全网独家首发:Powerful-IoU更好、更快的收敛IoU | 2024年最新IoU

💡💡💡本文独家改进:Powerful-IoU更好、更快的收敛IoU,是一种结合了目标尺寸自适应惩罚因子和基于锚框质量的梯度调节函数的损失函数 💡💡💡MS COCO和PASCAL VOC数据集实现涨点 收录 YOLOv8原创自研 https://blog.csdn.net/m0_63774211/category_12511737.htm…

【新课上架】安装部署系列Ⅲ—Oracle 19c Data Guard部署之两节点RAC部署实战

01 课程介绍 Oracle Real Application Clusters (RAC) 是一种跨多个节点分布数据库的企业级解决方案。它使组织能够通过实现容错和负载平衡来提高可用性和可扩展性&#xff0c;同时提高性能。本课程基于当前主流版本Oracle 19cOEL7.9解析如何搭建2节点RAC对1节点单机的DATA GU…

数学知识第一期 质数

前言 本文是关于质数的一些内容&#xff0c;希望能够对大家有帮助&#xff01;&#xff01;&#xff01; 一、质数的基本内容 定义&#xff1a; 质数又称素数。一个大于1的自然数&#xff0c;除了1和它自身外&#xff0c;不能被其他自然数整除的数叫做质数&#xff1b;否则…

Go Zero微服务个人探究之路(十)实战走通微服务前台请求调用的一套流程model->rpc微服务->apiHTTP调用

前言 Go语言凭借低占用&#xff0c;高并发等优秀特性成为后台编程语言的新星&#xff0c;GoZero框架由七牛云技术副总裁团队编写&#xff0c;目前已经成为Go微服务框架里star数量最多的框架 本文记录讲述笔者一步步走通前台向后台发出请求&#xff0c;后台api调用rpc服务的相…

VR数字展厅,平面静态跨越到3D立体化时代

近些年&#xff0c;VR的概念被越来越多的人提起&#xff0c;较为常见的形式就是VR数字展厅。VR数字展厅的出现&#xff0c;让各地以及各行业的展厅展馆的呈现和宣传都发生了很大的改变和革新&#xff0c;同时也意味着展览传播的方式不再局限于原来的图文、视频&#xff0c;而是…

【Redis】list以及他的应用场景

介绍 &#xff1a;list 即是 链表。链表是一种非常常见的数据结构&#xff0c;特点是易于数据元素的插入和删除并且且可以灵活调整链表长度&#xff0c;但是链表的随机访问困难。许多高级编程语言都内置了链表的实现比如 Java 中的 LinkedList&#xff0c;但是 C 语言并没有实现…

Spring Boot如何统计一个Bean中方法的调用次数

目录 实现思路 前置条件 实现步骤 首先我们先自定义一个注解 接下来定义一个切面 需要统计方法上使用该注解 测试 实现思路 通过AOP即可实现&#xff0c;通过AOP对Bean进行代理&#xff0c;在每次执行方法前或者后进行几次计数统计。这个主要就是考虑好如何避免并发情况…

JavaScript Proxy 对象、eval函数详解

&#x1f9d1;‍&#x1f393; 个人主页&#xff1a;《爱蹦跶的大A阿》 &#x1f525;当前正在更新专栏&#xff1a;《VUE》 、《JavaScript保姆级教程》、《krpano》、《krpano中文文档》 ​ 目录 ✨ 前言 ✨ 正文 Proxy 什么是 Proxy 代理 handlers get 捕获器 se…

Oracle ORA-09925

Error : 30: Read-only file system 造成这个问题的原因大多数是因为非正常关机后导致文件系统受损引起的&#xff0c;在系统重启之后&#xff0c;受损分区就会被Linux自动挂载为只读。 解决办法之一&#xff1a; 重启系统

聚观早报 | 特斯拉公布2023年财报;五菱红1号电池正式发布

聚观早报每日整理最值得关注的行业重点事件&#xff0c;帮助大家及时了解最新行业动态&#xff0c;每日读报&#xff0c;就读聚观365资讯简报。 整理丨Cutie 1月26日消息 特斯拉公布2023年财报 五菱红1号电池正式发布 Redmi Note 13 Pro新春版开售 三星Galaxy S24系列发布…

HCIA——29HTTP、万维网、HTML、PPP、ICMP;万维网的工作过程;HTTP 的特点HTTP 的报文结构的选择、解答

学习目标&#xff1a; 计算机网络 1.掌握计算机网络的基本概念、基本原理和基本方法。 2.掌握计算机网络的体系结构和典型网络协议&#xff0c;了解典型网络设备的组成和特点&#xff0c;理解典型网络设备的工作原理。 3.能够运用计算机网络的基本概念、基本原理和基本方法进行…

在windows上用python版tensorrt推理

文章目录 尝试一&#xff1a;利用torch导出的pth文件&#xff0c;调用torch2trt来进行trt推理1.1 搭建环境1.2 如何trt推理1.3 遇到的问题 尝试二&#xff1a;把onnx模型转为sim版的onnx模型2.1 搭建onnxim环境2.2 使用 onnxsim 尝试三&#xff1a;把onnx-sim转到trt&#xff0…

Python之代码覆盖率框架coverage使用介绍

Python代码覆盖率工具coverage.py其实是一个第三方的包&#xff0c;同时支持Python2和Python3版本。 安装也非常简单&#xff0c;直接运行&#xff1a; pip install coverage 安装完成后&#xff0c;会在Python环境下的\Scripts下看到coverage.exe&#xff1b; 首先我们编写…

【LIBS】交叉编译TCPDUMP

目录 1. 安装编译工具2. 设置环境变量3. 编译libpcap3.1 安装依赖3.2 交叉编译 4. 编译TCPDUMP4.1 克隆仓库与生成构建环境4.2 静态链接LIBPCAP4.3 动态链接LIBPCAP4.4 构建与安装 5. 查看交叉编译结果5.1 文件布局 1. 安装编译工具 sudo apt-get install -y autoconf automak…