在 CelebA 数据集上训练的 PyTorch 中的基本变分自动编码器

摩西·西珀博士

一、说明

        我最近发现自己需要一种方法将图像编码到潜在嵌入中,调整嵌入,然后生成新图像。有一些强大的方法可以创建嵌入从嵌入生成。如果你想同时做到这两点,一种自然且相当简单的方法是使用变分自动编码器。

        这样的深度网络不仅可以进行编码和解码,而且相当简单,我可以在以后的研究中使用它,而不必过多担心编码解码阶段的各种隐藏复杂性。我也更喜欢对软件内部有尽可能多的控制。

        因此,考虑到所有这些规范,我从 GitHub 收集了一些零碎的东西,施展了一些我自己的魔法,最终得到了一个漂亮、简单的变分自动编码器。我将在下面描述主要部分,完整的包可在以下位置找到:

vae-torch-celeba,PyTorch 中 CelebA 数据集的变分自动编码器,下载vae-torch-celeba的源码_GitHub_帮酷

PyTorch 中用于 CelebA 数据集的变分自动编码器 - GitHub - moshesipper/vae-torch-celeba:变分...

它相当小并且完全独立——这就是我的意图!

二、自动编码器

        为了使本文简短易懂,我将避免提供变分自动编码器的冗长概述。此外,您还可以在 Medium 上找到有关基础知识的优秀文章。我只提供三张快速图片。

        这是基本自动编码器的样子:

来源: https: //commons.wikimedia.org/wiki/File :Autoencoder_schema.png

        简而言之,网络将输入数据压缩为潜在向量(也称为嵌入),然后将其解压缩回来。这两个阶段称为编码解码

        变分自动编码器(VAE)看起来非常相似,除了中间的嵌入部分。对于每个输入,VAE 的编码器输出潜在空间中预定义分布的参数,而不是潜在空间中的向量:

来源:https ://commons.wikimedia.org/wiki/File:Reparameterized_Variational_Autoencoder.png

        最后一张图片:如果我们处理的是图像输入,我们需要一个卷积VAE,如下所示:

来源:https://github.com/arthurmeyer/Saliency_Detection_Convolutional_Autoencoder

        注意#1:观察编码器部分如何在每一层中添加越来越多的滤波器,图像变得越来越小;解码器则相反。

        注意#2:注意符号。如果只有一个通道,则术语“过滤器”和“内核”基本相同。对于多个通道,每个过滤器都是一组内核。查看这篇很棒的 Medium 文章:“直观地理解深度学习的卷积”。

三、CelebA数据集

我将使用的数据集是 CelebA,其中包含 202,599 张名人面孔图像。

CelebA 数据集

CelebFaces Attributes Dataset (CelebA) 是一个大规模人脸属性数据集,包含超过 20 万张名人图像……

可以通过以下方式访问它torchvision:

from torchvision.datasets import CelebAtrain_dataset = CelebA(path, split='train')
test_dataset = CelebA(path, split='valid') # or 'test'

四、VAE类

        我的 VAE 基于此PyTorch 示例和存储库的普通 VAE模型(将我使用的普通 VAE 替换为中的任何其他模型PyTorch-VAE应该不会太难)。PyTorch-VAE

        该文件vae.py包含VAE类以及图像大小的定义、两个潜在向量的维度(均值和方差)以及数据集的路径:

CELEB_PATH = './data/'
IMAGE_SIZE = 150
LATENT_DIM = 128
image_dim = 3 * IMAGE_SIZE * IMAGE_SIZE

在课堂上VAE,我使用了以下隐藏过滤器维度:

hidden_dims = [32, 64, 128, 256, 512]

编码器看起来像这样:

in_channels = 3
modules = []
for h_dim in hidden_dims:modules.append(nn.Sequential(nn.Conv2d(in_channels, out_channels=h_dim,kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(h_dim),nn.LeakyReLU()))in_channels = h_dim
self.encoder = nn.Sequential(*modules)

然后是潜在向量:

self.fc_mu = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)
self.fc_var = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)

        最后我们用解码器“倒退”:

hidden_dims.reverse()for i in range(len(hidden_dims) - 1):modules.append(nn.Sequential(nn.ConvTranspose2d(hidden_dims[i],hidden_dims[i + 1],kernel_size=3,stride=2,padding=1,output_padding=1),nn.BatchNorm2d(hidden_dims[i + 1]),nn.LeakyReLU()))self.decoder = nn.Sequential(*modules)

这就是它的要点——还有一些零碎的内容vae.py可以完成这VAE门课。

五、训练

        该文件trainvae.py包含训练我们刚刚编码的 VAE 的代码。老实说,没什么花哨的......有 3 个主要函数:(train随着训练的进行,它也输出损失值),test(它还构建一个重建图像的小样本)和loss_function。训练和测试相当普通,损失函数是标准 VAE,带有重建组件 (MSE) 和 KL 散度组件。

        epoch 上的主循环执行 4 个操作:1) train、2) test、3) 生成随机潜在向量并调用decode以输出相应的输出图像,以及 4) 将 epoch 的模型保存到文件中pth

        以下是示例运行的输出。通过 20 个训练周期,您最终会得到 20 个重建图像文件、20 个潜在采样文件和 20 个 python 模型文件:

        这里reconstruction_20.png,顶行显示 8 张原始图片,底行显示经过训练的 VAE 的相应重建。

        在 epoch 20 时从模型重建(输出)图像。

        这里的sample_20.png,显示了从随机潜在向量生成的 64 张图像:

        只是为了好玩,我添加了一小段代码 — genpics.py— 从数据集中挑选一个随机图像并生成 7 个重建。以下是一些示例(最左边的图像是原始图像):

        最后,我再次放置 GitHub 链接。享受!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/131586.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SparkSQL

1、Spark简介 2、Spark-Core核心算子 3、Spark-Core 4、SparkSQL 文章目录 一、概述1、简介2、DataFrame、DataSet3、SparkSQL特点 二、Spark SQL编程1、SparkSession新API2、DataFrame2.1 创建DataFrame2.2 SQL 语法2.3 DSL语法 3、DataSet4、RDD、DataFrame、DataSet相互转换…

强大的pdf编辑软件:Acrobat Pro DC 2023中文

Acrobat Pro DC 2023是一款强大的PDF编辑和管理软件,它提供了广泛的功能,使用户能够轻松创建、编辑、转换和共享PDF文档。通过直观的界面和先进的工具,用户可以快速进行文本编辑、图像调整、页面管理等操作,同时支持OCR技术&#…

win10 + cmake3.17 + vs2017编译osgearth2.7.0遇到的坑

坑1&#xff1a;debug模式下生成osgEarthAnnotation时 错误&#xff1a;xmemory0(881): error C2440: “初始化”: 无法从“std::pair<const _Kty,_Ty>”转换为 to _Objty 出错位置&#xff1a;src/osgEarthFeatures/FeatureSourceIndexNode.cpp 解决办法&#xff1a; …

unity 使用TriLib插件动态读取外部模型

最近在做动态加载读取外部模型的功能使用了triLib插件&#xff0c;废话不多说直接干货。 第一步下载导入插件&#xff0c;直接分享主打白嫖共享&#xff0c;不搞花里胡哨的。 链接&#xff1a;https://pan.baidu.com/s/1DK474wSrIZ0R6i0EBh5V8A 提取码&#xff1a;tado 导入后第…

Spring Cloud之Seata的学习

目录 案例准备 分布式事务 基本理论 CAP定理 BASE理论 Seata 部署TC服务 数据库准备 修改Nacos配置并导入信息 启动Seata 集成Seata XA模式原理 Seata的XA实现 优点 缺点 实现 AT模式原理 AT模式的脏写问题 Seata的AT实现 XA与AT的区别 TCC模式原理 空回…

有人物联网模块连接阿里云物联网平台的方法

摘要&#xff1a;本文介绍有人物联网模块M100连接阿里云的参数设置&#xff0c;作为说明书的补充。 没有阿里云功能需求的请略过本文&#xff0c;不要浪费您宝贵的时间。 网络选择LTE&#xff0c;请先确保插入的SIM卡有流量。 接下来配置阿里云云服务。如下图所示&#xff0c;…

windows mysql安装

1、首先去官网下载mysql安装包&#xff0c;官网地址&#xff1a;MySQL :: Download MySQL Community Server 2&#xff1a;把安装包放到你安装mysql的地方&#xff0c;然后进行解压缩&#xff0c;注意&#xff0c;解压后的mysql没有配置文件&#xff0c;我们需要创建配置文件 配…

mediasoup webrtc音视频会议搭建

环境ubuntu22.10 nvm --version 0.33.11 node -v v16.20.2 npm -v 8.19.4 node-gyp -v v10.0.1 python3 --version Python 3.10.7 python with pip: sudo apt install python3-pip gcc&g version 12.2.0 (Ubuntu 12.2.0-3ubuntu1) Make 4.2.1 npm install mediasoup3 sudo …

S4.2.4.7 Start of Data Stream Ordered Set (SDS)

一 本章节主讲知识点 1.1 xxx 1.2 sss 1.3 ddd 二 本章节原文翻译 2.1 SDS 数据流开始有序集 SDS 代表传输的数据类型从有序集转为数据流。它会在 Configuration.Idle&#xff0c;Recovery.Idle 和 Tx 的 L0s.FTS 状态发送。Loopback 模式下&#xff0c;主机允许发送 SDS。…

初阶JavaEE(14)表白墙程序

接上次博客&#xff1a;初阶JavaEE&#xff08;13&#xff09;&#xff08;安装、配置&#xff1a;Smart Tomcat&#xff1b;访问出错怎么办&#xff1f;Servlet初识、调试、运行&#xff1b;HttpServlet&#xff1a;HttpServlet&#xff1b;HttpServletResponse&#xff09;-C…

Rust学习日记(二)变量的使用--结合--温度换算/斐波那契数列--实例

前言&#xff1a; 这是一个系列的学习笔记&#xff0c;会将笔者学习Rust语言的心得记录。 当然&#xff0c;这并非是流水账似的记录&#xff0c;而是结合实际程序项目的记录&#xff0c;如果你也对Rust感兴趣&#xff0c;那么我们可以一起交流探讨&#xff0c;使用Rust来构建程…

修复dinput8.dll文件的缺失,以及修复dinput8.dll文件时需要注意什么

dinput8.dll文件通常在使用大型游戏时容易出现dinput8.dll文件丢失的情况&#xff0c;今天这篇文章将要教大家修复dinput8.dll文件的缺失&#xff0c;同时在修复dinput8.dll文件时需要注意些什么&#xff1f;防止文件在修复的过程中出现其他的错误。 dinput8.dll是DirectInput库…

HarmonyOS列表组件

List组件的使用 import router from ohos.routerEntry Component struct Index {private arr: number[] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]build() {Row() {Column() {List({ space: 10 }) {ForEach(this.arr, (item: number) > {ListItem() {Text(${item}).width(100%).heig…

SQL Server2000mdf升级SQL Server2005数据库还原

SQL Server2000数据库还原sqlserver 2000mdf升级 sqlserver 2008数据库还原SQL Server2005数据库脚本 sqlserver数据库低版本升级成高版本 sqlserver数据库版本升级 数据库版本还原 如果本机安装了sqlserver2012或者sqlserver2019等高版本 怎么样才能运行sqlserver2000的数据库…

Make.com实现多个APP应用的自动化的入门指南

Make.com是一款基于云的自动化平台&#xff0c;可帮助用户将多个应用程序连接在一起&#xff0c;并通过设置自动化流程来简化日常任务。Make.com提供丰富的API集成&#xff0c;支持连接各种流行的应用程序&#xff0c;包括社交媒体、电子商务、CRM等。 使用Make.com实现多个AP…

基于8086家具门安全控制系统设计

**单片机设计介绍&#xff0c;基于8086家具门安全控制系统设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 # 8086家具门安全控制系统设计介绍 8086家具门安全控制系统是一种用于保护家具和保证室内安全的系统。该系统基于808…

从0到1:腾讯云服务器使用教程

腾讯云服务器入门教程包括云服务器CPU内存带宽配置选择&#xff0c;选择云服务器CVM或轻量应用服务器&#xff0c;云服务器创建后重置密码、远程连接、搭建程序环境等&#xff0c;腾讯云服务器网txyfwq.com分享从0到1腾讯云服务器入门教程&#xff1a; 目录 腾讯云服务器入门…

6 从物理层到MAC层

1、实现局域网中玩游戏 在早期的80后的大学宿舍中&#xff0c;组件一个宿舍的局域网&#xff0c;以便于宿舍内部可以玩游戏. 第一层&#xff08;物理层&#xff09; 1.首先是实现电脑连接电脑&#xff0c;需要依靠网线&#xff0c;有两个头。 2.一头插在一台电脑的网卡上&am…

WebGL:基础练习 / 简单学习 / demo / canvas3D

一、前置内容 canvas&#xff1a;理解canvas / 基础使用 / 实用demo-CSDN博客 WebGL&#xff1a;开始学习 / 理解 WebGL / WebGL 需要掌握哪些知识 / 应用领域 / 前端值得学WebGL吗_webgl培训-CSDN博客 二、在线运行HTML 用来运行WebGL代码&#xff0c;粘贴--运行&#xff…

大数据毕业设计选题推荐-无线网络大数据平台-Hadoop-Spark-Hive

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…