VQ-VAE(Neural Discrete Representation Learning)论文解读及实现

pytorch 实现git地址
论文地址:Neural Discrete Representation Learning

1 论文核心知识点

  • encoder
    将图片通过encoder得到图片点表征
    如输入shape [32,3,32,32]
    通过encoder后输出 [32,64,8,8] (其中64位输出维度)

  • 量化码本
    先随机构建一个码本,维度与encoder保持一致
    这里定义512个离散特征,码本shap 为[512,64]

  • encoder 码本中向量最近查找
    encoder输出shape [32,64,8,8], 经过维度变换 shape [3288,64]
    在码本中找到最相近的向量,并替换为码本中相似向量
    输出shape [3288,64],维度变换后,shape 为 [32,64,8,8]

  • decoder
    将上述数据,喂给decoder,还原原始图片

  • loss
    loss 包含两部分
    a . encoder输出和码本向量接近
    b. 重构loss,重构图片与原图片接近

在这里插入图片描述

2 论文实现

2.1 encoder

encoder是常用的图片卷积神经网络
输入x shape [32,3,32,32]
输出 shape [32,128,8,8]

def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Encoder, self).__init__()kernel = 4stride = 2self.conv_stack = nn.Sequential(nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim, h_dim, kernel_size=kernel-1,stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers))def forward(self, x):return self.conv_stack(x)

2.2 VectorQuantizer 向量量化层

  • 输入:
    为encoder的输出z,shape : [32,64,8,8]
  • 码本维度:
    encoder维度变换为[2024,64],和码本embeddign shape [512,64]计算相似度
  • 相似计算:使用 ( x − y ) 2 = x 2 + y 2 − 2 x y (x-y)^2=x^2+y^2-2xy (xy)2=x2+y22xy计算和码本的相似度
  • z_q生成
    然后取码本中最相似的向量替换encoder中的向量
  • z_1维度:
    得到z_q shape [2024,64],经维度变换 shape [32,64,8,8] ,维度与输入z一致
  • 损失函数:
    使 z_q和z接近,构建损失函数
    在这里插入图片描述

decoder 层

decoder层比较简单,与encoder层相反
输入x shape 【32,64,8,8】
输出shape [32,3,32,32]

class Decoder(nn.Module):"""This is the p_phi (x|z) network. Given a latent sample z p_phi maps back to the original space z -> x.Inputs:- in_dim : the input dimension- h_dim : the hidden layer dimension- res_h_dim : the hidden dimension of the residual block- n_res_layers : number of layers to stack"""def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Decoder, self).__init__()kernel = 4stride = 2self.inverse_conv_stack = nn.Sequential(nn.ConvTranspose2d(in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),nn.ConvTranspose2d(h_dim, h_dim // 2,kernel_size=kernel, stride=stride, padding=1),nn.ReLU(),nn.ConvTranspose2d(h_dim//2, 3, kernel_size=kernel,stride=stride, padding=1))def forward(self, x):return self.inverse_conv_stack(x)

2.3 损失函数

损失函数为重构损失和embedding损失之和

  • decoder 输出为图片重构x_hat
  • embedding损失,为encoder和码本的embedding近似损失
  • 重点:(decoder计算损失时,由于中间有取最小值,导致梯度不连续,因此decoder loss 不能直接对encocer推荐进行求导,采用了复制梯度的方式: z_q = z + (z_q - z).detach(),及
    for i in range(args.n_updates):(x, _) = next(iter(training_loader))x = x.to(device)optimizer.zero_grad()embedding_loss, x_hat, perplexity = model(x)recon_loss = torch.mean((x_hat - x)**2) / x_train_varloss = recon_loss + embedding_lossloss.backward()optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/603450.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ROS2/ROS+conda+pytorch配置

0、需求 项目开发中遇到在ROS2中调用pytorch,但pytorch安装在了conda环境下。如果独立安装ros和conda会存在python版本、ubuntu系统版本的问题。网上还没看到比较好的解决方案,通过探索发现以下方案,实现的效果是在一个conda环境中&#xff…

Linux操作系统基础(14):文件管理-文件属性命令

1. 查看文件属性 stat命令用于显示文件的详细信息,包括文件的权限、所有者、大小、修改时间等。 #1.显示文件信息 stat file.txt#2.显示文件系统状态 stat -f file.txt#3.显示以时间戳的形式文件信息 stat -t file.txt2. 修改文件时间戳 touch命令用于创建新的空…

window.print打印事件,固定打印界面,打印成功或取消返回打印前界面,再次点击打印事件不生效

我是弹框中有打印&#xff0c;然后如果还原界面后在点打印事件不生效 我用 window.location.reload() 后刷新界面有返回的界面是关闭了弹框。我需要的是打印成功或取消返回打印不关闭弹框 之前打印代码 我这是是vue3 &#xff0c;我打印界面是单独写的 <printPag ref"…

GitHub Copilot 功能介绍和使用场景

原文 &#xff1a; https://openaigptguide.com/github-copilot/ GitHub Copilot是一款由GitHub、OpenAI和Microsoft联合开发的AI辅助开发工具&#xff0c;它以人工智能的方式提供语法结构、表达式、变量名等的自动补全建议&#xff0c;并对代码进行注释解释&#xff0c;将代码…

linux stop_machine 停机机制应用及一次触发 soft lockup 分析

文章目录 stop_mchine 引起的 soft lockup触发 soft lockup 原因分析&#xff08;一&#xff09;&#xff1a;触发 soft lockup 原因分析&#xff08;二&#xff09;触发 soft lockup 原因分析&#xff08;三&#xff09; stop_mchine 引起的 soft lockup 某次在服务器上某节点…

1389 蓝桥杯 二分查找数组元素 简单

1389 蓝桥杯 二分查找数组元素 简单 //C风格解法1&#xff0c;lower_bound(),通过率100% //利用二分查找的方法在有序的数组中查找&#xff0c;左闭右开 #include <bits/stdc.h> using namespace std;int main(){int data[200];for(int i 0 ; i < 200 ; i) data[i] …

LeetCode简单题记录

1、两数之和&#xff0c;给定数组nums&#xff0c;求和为target的两个数组元素的下标 我用了两个for循环&#xff0c;官方解为 哈希表&#xff0c;知识盲区 class Solution { public:vector<int> twoSum(vector<int>& nums, int target) {unordered_map<i…

React Hooks中useState的介绍,并封装为useSetState函数的使用

useState 允许我们定义状态变量&#xff0c;并确保当这些状态变量的值发生变化时&#xff0c;页面会重新渲染。 useState 返回值 const [state, setState] useState(initialState);useState 返回一个长度为 2 的数组。通常&#xff0c;我们这样定义状态变量&#xff1a; co…

Socket.D 替代 http 协议像 Ajax 一样开发前端接口

我们在"前端接口"开发时&#xff0c;使用 socket.d 协议有什么好处&#xff1a; 功能上可以替代 http 和原生 ws安全&#xff01;安全&#xff01;安全&#xff01;现有的工具想抓包数据&#xff0c;难&#xff01;难&#xff01;难&#xff01;&#xff08;socket.…

向爬虫而生---Redis 拓宽篇3 <GEO模块>

前言: 继上一章: 向爬虫而生---Redis 拓宽篇2 &#xff1c;Pub/Sub发布订阅&#xff1e;-CSDN博客 这一章的用处其实不是特别大,主要是针对一些地图和距离业务的;就是Redis的GEO模块。 GEO模块是Redis提供的一种高效的地理位置数据管理方案&#xff0c;它允许我们存储和查询…

1868_C语言单向链表的实现

Grey 全部学习内容汇总&#xff1a; GitHub - GreyZhang/c_basic: little bits of c. 1868_C语言中简单的链表实现 简单整理一下链表的实现&#xff0c;这一次结合前面看到的一些代码简单修改做一个小结。 主题由来介绍 以前工作之中链表的使用其实不多&#xff0c;主要是…

vue多tab页面全部关闭后自动退出登录

业务场景&#xff1a;主项目是用vue写的单页面应用&#xff0c;但是有多开页面的需求&#xff0c;现在需要在用户关闭了所有的浏览器标签页面后&#xff0c;自动退出登录。 思路&#xff1a;因为是不同的tab页面&#xff0c;我只能用localStorage来通信&#xff0c;新打开一个…

axios拦截器的使用?

Axios是一个基于Promise的HTTP库&#xff0c;可以用于浏览器和Node.js。Axios具有拦截请求和响应的能力&#xff0c;使得我们可以在请求被发送之前或响应被处理之前对其进行修改或查看。下面是一个Axios拦截器的简单示例&#xff1a; 1.添加请求拦截器&#xff1a; axios.in…

LightGlue-OpenCV 实现实时相机图片特征点匹配

LightGlue-OpenCV 文章目录 LightGlue-OpenCVStep 1: 创建虚拟环境Step 2: 安装 LightGlue-OpenCV 并运行Step3: 运行 demo_camera.py效果 原理 LightGlue 是一种新的基于深度神经网络&#xff0c;用来匹配图像中的局部特征的深度匹配器。是 SuperGlue 的加强版本。相比于 Supe…

Qt/QML编程学习之心得:Linux下USB接口使用(25)

很多linux嵌入式系统都有USB接口,那么如何使用USB接口呢? 首先,linux的底层驱动要支持,在linux kernal目录下可以找到对应的dts文件,(device tree) usb0: usb@ee520000{compatible = "myusb,musb";status = "disabled";reg = <0xEE520000 0x100…

【C程序设计】C指针

学习 C 语言的指针既简单又有趣。通过指针&#xff0c;可以简化一些 C 编程任务的执行&#xff0c;还有一些任务&#xff0c;如动态内存分配&#xff0c;没有指针是无法执行的。所以&#xff0c;想要成为一名优秀的 C 程序员&#xff0c;学习指针是很有必要的。 正如您所知道的…

探索LinkedIn:使用TypeScript和jsdom库的高级内容下载器

概述 LinkedIn是一个专业的社交网络平台&#xff0c;拥有超过7亿的用户和数以亿计的职位、公司和教育机构的信息。对于数据分析师、市场营销人员、招聘人员和其他对LinkedIn数据感兴趣的人来说&#xff0c;能够从LinkedIn上获取和分析这些信息是非常有价值的。 因此&#xff0…

如何恢复Mac误删文件?

方法1. 使用撤消命令 当你在 Mac 上删除了错误的文件并立即注意到你的错误时&#xff0c;你可以使用撤消命令立即恢复它。顾名思义&#xff0c;此命令会反转上次完成的操作&#xff0c;并且有多种方法可以调用它。如果你已经采取了其他操作或退出了用于删除文件的应用程序&…

虾皮怎么选品:虾皮(Shopee)跨境电商业务成功的关键步骤

在虾皮&#xff08;Shopee&#xff09;平台上进行跨境电商业务&#xff0c;选品是至关重要的一环。有效的选品策略可以帮助卖家更好地了解市场需求&#xff0c;提高销售业绩和客户满意度。以下是一些成功的选品策略&#xff0c;可以帮助卖家在虾皮平台上取得更好的业务成绩。 先…