PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)

PyTorch深度学习实战(26)——卷积自编码器

    • 0. 前言
    • 1. 卷积自编码器
    • 2. 使用 t-SNE 对相似图像进行分组
    • 小结
    • 系列链接

0. 前言

我们已经学习了自编码器 (AutoEncoder) 的原理,并使用 PyTorch 搭建了全连接自编码器,但我们使用的数据集较为简单,每张图像只有一个通道(每张图像都为黑白图像)且图像相对较小 (28 x 28)。但在现实场景中,图像数据通常为彩色图像( 3 个通道)且图像尺寸通常较大。在本节中,我们将实现能够处理多维输入图像的卷积自编码器,为了与普通自编码器进行对比,同样使用 MNIST 数据集。

1. 卷积自编码器

与传统的全连接自编码器不同,卷积自编码器 (Convolutional Autoencoder) 利用卷积层和池化层替代了全连接层,以处理具有高维空间结构的图像数据。这样的设计使得卷积自编码器能够在较少的参数量下对输入数据进行降维和压缩,同时保留重要的空间特征。卷积自编码器架构如下所示:

卷积自编码器

从上图中可以看出,输入图像被表示为瓶颈层中的潜空间变量,用于重建图像。图像经过多次卷积(编码器)得到低维潜空间表示,然后在解码器中,将潜空间变量还原为原始尺寸,使解码器的输出能够近似恢复原始输入。
本质上,卷积自编码器在其网络中使用卷积、池化操作来代替原始自编码器的全连接操作,并使用反卷积操作 (Conv2DTranspose) 对特征图进行上采样。了解卷积自编码器的原理后,使用 PyTorch 实现此架构。

(1) 数据集的加载和构建方式与全连接自编码器完全相同:

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import numpy as np
from matplotlib import pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),transforms.Lambda(lambda x: x.to(device))
])trn_ds = MNIST('MNIST/', transform=img_transform, train=True, download=True)
val_ds = MNIST('MNIST/', transform=img_transform, train=False, download=True)batch_size = 256
trn_dl = DataLoader(trn_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

(2) 定义神经网络类 ConvAutoEncoder

定义 __init__ 方法:

class ConvAutoEncoder(nn.Module):def __init__(self):super().__init__()

定义编码器架构:

        self.encoder = nn.Sequential(nn.Conv2d(1, 32, 3, stride=3, padding=1), nn.ReLU(True),nn.MaxPool2d(2, stride=2),nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(True),nn.MaxPool2d(2, stride=1))

在以上代码中,通道数最初由 1 开始,逐渐增加到 64,同时通过 nn.MaxPool2dnn.Conv2d 操作减小输入图像尺寸。

定义解码器架构:

        self.decoder = nn.Sequential(nn.ConvTranspose2d(64, 32, 3, stride=2), nn.ReLU(True),nn.ConvTranspose2d(32, 16, 5, stride=3, padding=1), nn.ReLU(True),nn.ConvTranspose2d(16, 1, 2, stride=2, padding=1), nn.Tanh())

定义前向传播方法 forward

    def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x

(3) 使用 summary 方法获取模型摘要信息:

model = ConvAutoEncoder().to(device)
from torchsummary import summary
summary(model, (1,28,28))
输出结果如下所示:
```shell
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1           [-1, 32, 10, 10]             320ReLU-2           [-1, 32, 10, 10]               0MaxPool2d-3             [-1, 32, 5, 5]               0Conv2d-4             [-1, 64, 3, 3]          18,496ReLU-5             [-1, 64, 3, 3]               0MaxPool2d-6             [-1, 64, 2, 2]               0ConvTranspose2d-7             [-1, 32, 5, 5]          18,464ReLU-8             [-1, 32, 5, 5]               0ConvTranspose2d-9           [-1, 16, 15, 15]          12,816ReLU-10           [-1, 16, 15, 15]               0ConvTranspose2d-11            [-1, 1, 28, 28]              65Tanh-12            [-1, 1, 28, 28]               0
================================================================
Total params: 50,161
Trainable params: 50,161
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.14
Params size (MB): 0.19
Estimated Total Size (MB): 0.34
----------------------------------------------------------------

从以上模型架构信息可以看出,使用尺寸为 batch size x 64 x 2 x 2MaxPool2d-6 层作为瓶颈层。

模型训练过程,训练和验证损失随时间的变化以及对输入图像的重建结果如下:

def train_batch(input, model, criterion, optimizer):model.train()optimizer.zero_grad()output = model(input)loss = criterion(output, input)loss.backward()optimizer.step()return loss@torch.no_grad()
def validate_batch(input, model, criterion):model.eval()output = model(input)loss = criterion(output, input)return lossmodel = ConvAutoEncoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)num_epochs = 20
train_loss_epochs = []
val_loss_epochs = []
for epoch in range(num_epochs):N = len(trn_dl)trn_loss = []val_loss = []for ix, (data, _) in enumerate(trn_dl):loss = train_batch(data, model, criterion, optimizer)pos = (epoch + (ix+1)/N)trn_loss.append(loss.item())train_loss_epochs.append(np.average(trn_loss))N = len(val_dl)for ix, (data, _) in enumerate(val_dl):loss = validate_batch(data, model, criterion)pos = epoch + (1+ix)/Nval_loss.append(loss.item())val_loss_epochs.append(np.average(val_loss))epochs = np.arange(num_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r-', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()for _ in range(5):ix = np.random.randint(len(val_ds))im, _ = val_ds[ix]_im = model(im[None])[0]plt.subplot(121)# fig, ax = plt.subplots(1,2,figsize=(3,3)) plt.imshow(im[0].detach().cpu(), cmap='gray')plt.title('input')plt.subplot(122)plt.imshow(_im[0].detach().cpu(), cmap='gray')plt.title('prediction')plt.show()

模型性能监测图像重建结果图像重建结果

从上图中,我们可以看到卷积自编码器重建后的图像比全连接自编码器更清晰,可以通过改变编码器和解码器中的通道数,观察模型训练结果。在下一节中,我们将根据瓶颈层潜变量对相似图像进行分组

2. 使用 t-SNE 对相似图像进行分组

假设相似的图像具有相似的潜变量(也称嵌入),而不相似的图像具有不同的潜变量,使用自编码器,可以在低维空间中表示图像。接下来,我们继续学习图像的相似度度量,在二维空间中绘制潜变量,使用 t-SNE 技术将卷积自编码器的 64 维向量缩减至到 2 维空间。
2 维空间中,我们可以方便的可视化潜变量,以观察相似图像是否具有相似的潜变量,相似图像在二维平面中应该聚集在一起。接下里,我们在二维平面中表示所有测试图像的潜变量。

(1) 初始化列表,以便存储潜变量 (latent_vectors) 和相应的图像类别(存储每个图像的类别只是为了验证同一类别的图像是否具有较高的相似性,并不会在训练过程使用):

latent_vectors = []
classes = []

(2) 遍历验证数据加载器 (val_dl) 中的图像,并存储编码器的输出 (model.encoder(im).view(len(im),-1)) 和每个图像 (im) 对应的类别 (clss):

for im,clss in val_dl:latent_vectors.append(model.encoder(im).view(len(im),-1))classes.extend(clss)

(3) 连接潜变量 (latent_vectors) NumPy 数组:

latent_vectors = torch.cat(latent_vectors).cpu().detach().numpy()

(4) 导入 t-SNE 库 (TSNE),并将潜变量转换为二维向量 (TSNE(2)) ,以便进行绘制:

from sklearn.manifold import TSNE
tsne = TSNE(2)

(5) 通过在图像潜变量 (latent_vectors) 上运行 fit_transform 方法来拟合 t-SNE

clustered = tsne.fit_transform(latent_vectors)

(6) 拟合 t-SNE 后绘制数据点:

fig = plt.figure(figsize=(12,10))
cmap = plt.get_cmap('Spectral', 10)
plt.scatter(*zip(*clustered), c=classes, cmap=cmap)
plt.colorbar(drawedges=True)
plt.show()

聚类结果

可以看到同一类别的图像能够聚集在一起,即相似的图像将具有相似的潜变量值。

小结

卷积自编码器是一种基于卷积神经网络结构的自编码器,适用于处理图像数据。卷积自编码器在图像处理领域有广泛的应用,包括图像去噪、图像压缩、图像生成等任务。通过训练卷积自编码器,可以提取出输入图像的关键特征,并实现对图像数据的降维和压缩,同时保留重要的空间信息。在本节中,我们介绍了卷积自编码器的模型架构,使用 PyTorch 从零开始实现在 MNIST 数据集上训练了一个简单的卷积自编码器,并使用 t-SNE 技术在二维平面中表示了所有测试图像的潜变量。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/235905.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PHP入门】2.2 流程控制

-流程控制- 流程控制:代码执行的方向 2.2.1控制分类 顺序结构:代码从上往下,顺序执行。(代码执行的最基本结构) 分支结构:给定一个条件,同时有多种可执行代码(块)&am…

阿里推荐 LongAdder ,不推荐 AtomicLong !

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、CAS 1.1 CAS 全称 1.2 通俗理解CAS 1.3 CAS的问题 1.4 解决 ABA 问题 二、LongAdder 2.1 什么是 LongAdder 2.2 为什么推…

用JVS低代码实现业务流程的撤回和重新开始

在当今的数字化时代,业务流程的效率和准确性对于企业的运营至关重要。在实际业务场景中,我们可能需要处理一些复杂的流程,例如申请审批流程、合同签订流程等。这些流程在执行过程中可能会遇到各种情况,例如某个审批步骤需要重新审…

❀My虚拟机上的ftp服务器搭建(centos)❀

❀My虚拟机上的ftp服务器搭建(centos)❀ 在CentOS上搭建FTP服务器可以使用vsftpd软件,下面是详细的搭建教程: ①安装vsftpd软件 在终端中输入以下命令进行安装: sudo yum install vsftpd ②配置vsftpd 打开vsftpd的配置文件,…

【深度学习】序列生成模型(五):评价方法计算实例:计算BLEU-N得分【理论到程序】

文章目录 一、BLEU-N得分(Bilingual Evaluation Understudy)1. 定义2. 计算N1N2BLEU-N 得分 3. 程序 给定一个生成序列“The cat sat on the mat”和两个参考序列“The cat is on the mat”“The bird sat on the bush”分别计算BLEU-N和ROUGE-N得分(N1或…

WEB渗透—PHP反序列化(六)

Web渗透—PHP反序列化 课程学习分享(课程非本人制作,仅提供学习分享) 靶场下载地址:GitHub - mcc0624/php_ser_Class: php反序列化靶场课程,基于课程制作的靶场 课程地址:PHP反序列化漏洞学习_哔哩…

Ubuntu 22.04 禁用(彻底移除)Snap

什么是Snaps Snaps 是 Ubuntu 的母公司 Canonical 于 2016 年 4 月发布 Ubuntu 16.04 LTS(Long Term Support,长期支持版)时引入的一种容器化的软件包格式。自 Ubuntu 16.04 LTS 起,Ubuntu 操作系统可以同时支持 Snap 及 Debian …

3dsmax渲染太慢,用云渲染农场多少钱?

对于许多从事计算机图形设计的创作者来说,渲染速度慢是一个常见问题,尤其是对于那些追求极致出图效果的室内设计师和建筑可视化师,他们通常使用3ds Max这样的工具,而高质量的渲染经常意味着长时间的等待。场景复杂、细节丰富&…

APView500PV电能质量在线监测装置——安科瑞 顾烊宇

概述 APView500PV电能质量在线监测装置采用了高性能多核平台和嵌入式操作系统,遵照IEC61000-4-30《测试和测量技术-电能质量测量方法》中规定的各电能质量指标的测量方法进行测量,集谐波分析、波形采样、电压暂降/暂升/中断、闪变监测、电压不平衡度监测…

CentOS操作学习(二)

上一篇学习了CentOS的常用指令CentOS指令学习-CSDN博客 现在我们接着学习 一、Vi编辑器 这是CentOS中自带的编辑器 三种模式 进入编辑模式后 i:在光标所在字符前开始插入a:在光标所在字符串后开始插入o:在光标所在行的下面另起一新行插入…

命令执行 [SWPUCTF 2021 新生赛]easyrce

打开题目 提示要用url传参,但实际是用url进行一些系统命令执行 那我们就用whoami命令来查看用户和权限 那我们直接用ls / 去查看当下根目录下有哪些文件 我们看到根目录下有flag 直接cat读取就行 知识点: system system是一个函数 用来运行外部的程序…

4.CentOS7开启ssh

Centos7开启ssh 通过命令查看是否安装了ssh服务 rpm -qa | grep openssh 修改主配置文件 vim /etc/ssh/sshd_config 将PermitRootLogin,RSAAuthentication,PubkeyAuthentication的设置打开 RSAAuthentication yes# 启用 RSA 认证PubkeyAuthenticatio…

19_20-Golang中的切片

**Golang **中的切片 主讲教师:(大地) 合作网站:www.itying.com** **(IT 营) 我的专栏:https://www.itying.com/category-79-b0.html 1、为什么要使用切片 因为数组的长度是固定的并且数组长…

【.NET后端工具系列】MediatR实现进程内消息通讯

阅读本文你的收获 学习MediatR工具,实现进程内消息发送和处理过程的解耦学习MediatR的两种消息处理模式了解中介者模式和其好处 一、什么是MediatR? MediatR是一款基于中介者模式的思想而实现的.NET库,支持.NET Framework和跨平台 的.NET C…

aws配置以及下载 spaceNet6 数据集

一:注册亚马逊账号 注册的时候,唯一需要注意的是信用卡绑定,这个可以去淘宝买,搜索aws匿名卡。 注册完记得点击登录,记录一下自己的账户ID哦! 二:登录自己的aws账号 2.1 首先创建一个用户 首…

从YOLOv1到YOLOv8的YOLO系列最新综述【2023年4月】

作者:Juan R. Terven 、Diana M. Cordova-Esparaza 摘要:YOLO已经成为机器人、无人驾驶汽车和视频监控应用的核心实时物体检测系统。我们对YOLO的演变进行了全面的分析,研究了从最初的YOLO到YOLOv8每次迭代的创新和贡献。我们首先描述了标准…

研发管理-代码管理篇

前言: 工作了这些年,工作了三家公司,也用过主流的代码管理平台,比如SVN,git系列(gitlib,gitee),各有优点,我个人比较喜欢SVN,多人协作的代码管理难免会有代码冲突&#…

2024年【北京市安全员-B证】考试试卷及北京市安全员-B证复审模拟考试

题库来源:安全生产模拟考试一点通公众号小程序 北京市安全员-B证考试试卷根据新北京市安全员-B证考试大纲要求,安全生产模拟考试一点通将北京市安全员-B证模拟考试试题进行汇编,组成一套北京市安全员-B证全真模拟考试试题,学员可…

深入了解 npm 命令

目录 前言1 初始化项目2 安装依赖3 更新依赖4 发布包5 卸载包6 查看依赖7 运行脚本8 包搜索9 查看包信息结语 前言 在现代 Web 开发中,JavaScript 是一种至关重要的语言,而 npm(Node Package Manager)作为 Node.js 平台的默认软件…