(表征学习论文阅读)FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE

1. 前言

向量量化(Vector Quantization)或称为矢量量化最早在1984年由Gray提出,主要应用于数据压缩、检索领域,具体的阐述可以参考我写的另一篇关于VQ算法的文章。随着基于神经网络的离散表征学习模型的兴起,VQ技术也开始重新被重视。它在图像、音频等表征学习中体现出了优秀的性能,并且有希望成为多模态大语言模型的重要组件。

在AI领域,最为知名应该是VQ-VAE(Vector Quantized-Variational Autoencoder)了,它的思想是将图像 x x x映射为表征 z k × d z^{k \times d} zk×d,其中 z k × d z^{k \times d} zk×d由一组维度为 d d d的特征向量构成,VQ-VAE引入了一个codebook记为 C n × d C^{n \times d} Cn×d z k × d z^{k \times d} zk×d会和 C n × d C^{n \times d} Cn×d中的向量进行距离计算,可以是欧式距离也可以是余弦相似度,用 C n × d C^{n \times d} Cn×d中距离最近或者最相似的向量来表示 z k × d z^{k \times d} zk×d中的向量。这种量化操作往往不可微,因此VQ-VAE使用了一个非常简单的技巧straight through estimator (STE)来解决,具体的实现可以看代码。

VQ-VAE的损失函数主要由三个部分组成,以确保模型能够有效地学习到有用的离散表征,并同时保持输入数据的重建质量:
L = L recon + α L quant + β L commit L = L_{\text{recon}} + \alpha L_{\text{quant}} + \beta L_{\text{commit}} L=Lrecon+αLquant+βLcommit

  • 重建损失(Reconstruction
    Loss):这部分的损失计算了模型重建的输出与原始输入之间的差异。目标是最小化这一差异,以确保重建的数据尽可能接近原数据。常见的重建损失包括均方误差(MSE)或交叉熵损失,具体取决于输入数据的类型。
  • 量化损失(Quantization Loss)或 码本损失(Codebook Loss):在训练过程中,当输入数据通过编码器被编码到潜在空间后,每个潜在表示会被量化为最近的码本向量。量化损失计算潜在表示与其对应的最近码本向量之间的距离。通过最小化量化损失,模型优化码本向量的位置,使其更好地代表输入数据的潜在表示。这有助于模型更准确地量化潜在空间,并提高重建质量。
  • 提交损失(Commitment Loss):提交损失主要用于稳定训练过程,它鼓励编码器生成的潜在表示靠近选中的码本向量。这样做可以防止码本向量在训练过程中出现较大的变动,从而确保模型的稳定性。提交损失通过计算编码器输出的潜在表示与选中的码本向量之间的距离来实现其目标。因此,提交损失主要影响编码器的参数更新,帮助编码器学习生成与码本向量更接近的潜在表示。

虽然VQ-VAE的效果比传统的VAE要好,但是它使用的codebook中的大部分向量并未被利用到,造成了存储和计算的大量浪费,此外,它额外引入的两项损失即codebook loss和commitment loss也带来些许复杂性。

FSQ(FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE)这篇文章的目的就是优化以上两个问题。

2. 方法

作者发现,传统的编码器所得到的表征向量 z z z中的每一个元素(标量)的值并没有一个明确的边界,也就是说 z z z在特征空间中不受任何约束。那么,作者就想到了为 z z z中的每个标量都设定好取值的范围和能够取值的个数。
在这里插入图片描述
假设有一个d维特征向量 z z z,将每个标量 z i z_i zi都限制只能取 L L L个值,将 z i → ⌊ L / 2 ⌋ t a n h ( z i ) z_i \rightarrow \left\lfloor L/2 \right\rfloor tanh(z_i) ziL/2tanh(zi)然后四舍五入为一个整数值。例如图中所示,取d=3,L=3,代表codebook C = { ( − 1 , − 1 , − 1 ) , ( − 1 , − 1 , 0 ) , . . . , ( 1 , 1 , 1 ) } C=\left\{(-1, -1, -1), (-1, -1, 0), ..., (1, 1, 1)\right\} C={(1,1,1),(1,1,0),...,(1,1,1)},一共有27种组合,即一个3维向量的每个标量都有三种值的取法。值得一提的是,FSQ中的codebook不像VQ-VAE那样是显式存在的,而是隐式的,编码器直接输出量化后的特征向量 z ^ \hat{z} z^。因此,FSQ也就没有了VQ-VAE损失的后两项了。
在这里插入图片描述

3. 代码实现

from typing import List, Tuple, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.cuda.amp import autocastfrom einops import rearrange, pack, unpack# helper functionsdef exists(v):return v is not Nonedef default(*args):for arg in args:if exists(arg):return argreturn Nonedef pack_one(t, pattern):return pack([t], pattern)def unpack_one(t, ps, pattern):return unpack(t, ps, pattern)[0]# tensor helpersdef round_ste(z: Tensor) -> Tensor:"""Round with straight through gradients."""zhat = z.round()  # round操作是将z中的元素四舍五入到最接近的整数return z + (zhat - z).detach()class FSQ(Module):def __init__(self,levels: List[int],dim: Optional[int] = None,num_codebooks=1,keep_num_codebooks_dim: Optional[bool] = None,scale: Optional[float] = None,allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)):super().__init__()_levels = torch.tensor(levels, dtype=int32)self.register_buffer("_levels", _levels, persistent=False)  #persistent=False表示不会被保存到checkpoint中_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)self.register_buffer("_basis", _basis, persistent=False)self.scale = scalecodebook_dim = len(levels)  # codebook_dim表示每个codebook的维度self.codebook_dim = codebook_dimeffective_codebook_dim = codebook_dim * num_codebooks  # effective_codebook_dim表示所有codebook的维度的总和self.num_codebooks = num_codebooksself.effective_codebook_dim = effective_codebook_dimkeep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)assert not (num_codebooks > 1 and not keep_num_codebooks_dim)self.keep_num_codebooks_dim = keep_num_codebooks_dimself.dim = default(dim, len(_levels) * num_codebooks)has_projections = self.dim != effective_codebook_dimself.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()self.has_projections = has_projectionsself.codebook_size = self._levels.prod().item()implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)self.allowed_dtypes = allowed_dtypesdef bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:"""Bound `z`, an array of shape (..., d)."""half_l = (self._levels - 1) * (1 + eps) / 2offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)shift = (offset / half_l).atanh()  # atanh是双曲正切函数的反函数,能够将值映射到[-1, 1]之间return (z + shift).tanh() * half_l - offsetdef quantize(self, z: Tensor) -> Tensor:"""Quantizes z, returns quantized zhat, same shape as z."""quantized = round_ste(self.bound(z))half_width = self._levels // 2  # Renormalize to [-1, 1].return quantized / half_widthdef _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:# 将zhat_normalized的值映射到[0, levels]之间half_width = self._levels // 2return (zhat_normalized * half_width) + half_widthdef _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:half_width = self._levels // 2return (zhat - half_width) / half_widthdef codes_to_indices(self, zhat: Tensor) -> Tensor:"""Converts a `code` to an index in the codebook."""assert zhat.shape[-1] == self.codebook_dimzhat = self._scale_and_shift(zhat)return (zhat * self._basis).sum(dim=-1).to(int32)def indices_to_codes(self,indices: Tensor,project_out=True) -> Tensor:"""Inverse of `codes_to_indices`."""is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))indices = rearrange(indices, '... -> ... 1')codes_non_centered = (indices // self._basis) % self._levelscodes = self._scale_and_shift_inverse(codes_non_centered)if self.keep_num_codebooks_dim:codes = rearrange(codes, '... c d -> ... (c d)')if project_out:codes = self.project_out(codes)if is_img_or_video:codes = rearrange(codes, 'b ... d -> b d ...')return codes@autocast(enabled=False)def forward(self, z: Tensor) -> Tensor:"""einstein notationb - batchn - sequence (or flattened spatial dimensions)d - feature dimensionc - number of codebook dim"""orig_dtype = z.dtypeis_img_or_video = z.ndim >= 4# make sure allowed dtypeif z.dtype not in self.allowed_dtypes:z = z.float()# standardize image or video into (batch, seq, dimension)if is_img_or_video:# 将图片和视频的空间、时间维度展平z = rearrange(z, 'b d ... -> b ... d')z, ps = pack_one(z, 'b * d')assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'z = self.project_in(z)z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)codes = self.quantize(z)print(f"codes: {codes}")indices = self.codes_to_indices(codes)codes = rearrange(codes, 'b n c d -> b n (c d)')out = self.project_out(codes)# reconstitute image or video dimensionsif is_img_or_video:out = unpack_one(out, ps, 'b * d')out = rearrange(out, 'b ... d -> b d ...')indices = unpack_one(indices, ps, 'b * c')if not self.keep_num_codebooks_dim:indices = rearrange(indices, '... 1 -> ...')# cast back to original dtypeif out.dtype != orig_dtype:out = out.type(orig_dtype)# return quantized output and indicesreturn out, indices

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/790639.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二维动画制作软件 Animate 2024 for mac激活版

Animate 2024 for Mac是一款功能强大的二维动画制作软件,专为Mac用户打造。它提供了丰富的动画编辑功能,使用户能够轻松创建出生动逼真的动画作品。无论是短片、广告还是游戏等应用领域,Animate 2024都能发挥出出色的表现。 软件下载&#xf…

部署k8s客户端,及docker私仓部署

1.部署一个docker私仓 mkdir /opt/docker/registry #配置仓库密码 mkdir /opt/docker/auth cd /opt/docker/auth htpasswd -Bbn admin admin > htpasswd#运行docker私仓服务,下面端口5000:5000 前面的5000对应本机端口可以自定义 docker run -itd \ -v /opt/d…

【Layui】------ layui实现table表格拖拽行、列位置的示例代码

一、完整的示例代码&#xff1a;&#xff08;请使用layui v2.8.3的版本&#xff09;看懂就能用、不要照搬、照搬会出错误、拷贝重要代码改改符合你自己的需求。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><…

MapReduce [OSDI‘04] 论文阅读笔记

原论文&#xff1a;MapReduce: Simplified Data Processing on Large Clusters (OSDI’04) 1. Map and Reduce Map&#xff1a;处理键值对&#xff0c;生成一组中间键值对Reduce&#xff1a;合并与同一中间键相关的所有中间值process overview&#xff1a;分割输入数据&#x…

DSO9254A安捷伦DSO9254A示波器

181/2461/8938产品概述&#xff1a; 安捷伦DSO9254A的带宽为2.5 GHz&#xff0c;配备15英寸XGA液晶显示屏&#xff0c;采用静音封装&#xff0c;厚度仅为9英寸&#xff08;23厘米&#xff09;&#xff0c;重量仅为26磅&#xff08;11.8千克&#xff09;。DSO9254A集成了一个功…

UE4_自定义反射和折射和法线图

UE4 自定义反射和折射和法线图 2020-05-22 09:36 将ReflectionVector和反射图像进行ViewAlignedReflection,输出的textrue和相机位置CameraPosition的onePlus进行Dot点乘之后乘以一个float系数反射度&#xff0c;输出给固有色&#xff0c;就有反射效果了。球型反射。 折射&…

Coze工作流介绍(一)

Coze工作流介绍 工作流支持通过可视化的方式&#xff0c;对插件、大语言模型、代码块等功能进行组合&#xff0c;从而实现复杂、稳定的业务流程编排&#xff0c;例如旅行规划、报告分析等。 当目标任务场景包含较多的步骤&#xff0c;且对输出结果的准确性、格式有严格要求时…

JAVAEE—Callable接口,ReentrantLock,synchronized的工作过程

文章目录 Callable接口的用法Callable与FutureTask类 加锁的工作过程什么是偏向锁呢&#xff1f;举个例子 轻量级锁重量级锁 ReentrantLockReentrantLock 的用法: Callable接口的用法 Callable 是一个 interface . 相当于把线程封装了一个 “返回值”. 方便程序猿借助多线程的…

Ubuntu20.04使用Neo4j导入CSV数据可视化知识图谱

1.安装JDK&#xff08; Ubuntu20.04 JDK11&#xff09; sudo apt-get install openjdk-11-jdk -y java -version which java ls -l /usr/bin/java ls -l /etc/alternatives/java ls -l /usr/lib/jvm/java-11-openjdk-amd64/bin/java确认安装路径为/usr/lib/jvm/java-11-openjd…

Celery的任务流

Celery的任务流 在之前调用任务的时候只是使用delay()和apply_async()方法。但是有时我们并不想简单的执行单个异步任务&#xff0c;比如说需要将某个异步任务的结果作为另一个异步任务的参数或者需要将多个异步任务并行执行&#xff0c;返回一组返回值&#xff0c;为了实现此…

STL是什么?如何理解STL?

文章目录 1. 什么是STL2. STL的版本3. STL的六大组件4. 如何学习STL5.STL的缺陷 1. 什么是STL STL(standard template libaray-标准模板库)&#xff1a;是C标准库的重要组成部分&#xff0c;不仅是一个可复用的组件库&#xff0c;而且是一个包罗数据结构与算法的软件框架。 2. …

OpenHarmony实战开发-使用一次开发多端部署实现一多设置典型页面

介绍 本示例展示了设置应用的典型页面&#xff0c;其在小窗口和大窗口有不同的显示效果&#xff0c;体现一次开发、多端部署的能力。 1.本示例使用一次开发多端部署中介绍的自适应布局能力和响应式布局能力进行多设备&#xff08;或多窗口尺寸&#xff09;适配&#xff0c;保…

WebGIS 之 vue3+vite+ceisum

1.项目搭建node版本在16以上 1.1创建项目 npm create vite 项目名 1.2选择框架 vuejavaScript 1.3进入项目安装依赖 cd 项目名 npm install 1.4安装cesium依赖 pnpm i cesium vite-plugin-cesium 1.5修改vite.config.js文件 import { defineConfig } from vite import vue fr…

RK3568 RTC驱动实验

RK3568 RTC驱动实验 1. RTC简介 ​ RTC 也就是实时时钟&#xff0c;用于记录当前系统时间&#xff0c;对于 Linux 系统而言时间是非常重要的&#xff0c;使用 Linux 设备的时候也需要查看时间。RTC是Linux的时间系统。 ​ RTC 设备驱动是一个标准的字符设备驱动&#xff0c;…

基于Python的微博旅游情感分析、微博舆论可视化系统

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

Python网络爬虫(三):Selenium--以携程酒店为例

1 Selenium简介 Selenium是一个用于网站应用程序自动化的工具&#xff0c;它可以直接运行在浏览器中&#xff0c;就像真正的用户在操作一样。它相当于一个机器人&#xff0c;可以模拟人类在浏览器上的一些行为&#xff0c;比如输入文本、点击、回车等。Selenium支持多种浏览器&…

记录一次官网访问很慢的情况

客户查看云监控,带宽未超限,客户取的是1分钟的原生值,也就是1分钟也是个平均值。 但是客户的原始值&#xff0c;其实就是1分钟内的平均值。所以客户的瞬时超限&#xff0c;其实是看不出来的。但是后端同事从实时监控里面可以看到超限的情况。 客户升带宽后&#xff0c; 发现还…

Flutter 应用数据持久化指南

1. 介绍 1.1 什么是数据持久化&#xff1f; 数据持久化是指将应用程序中的数据保存在持久存储介质&#xff08;如硬盘、数据库等&#xff09;中的过程。在计算机科学领域&#xff0c;持久化数据是指数据在程序退出或系统关机后仍然存在的能力。这种持久性使得数据可以在不同的…

是德科技keysight 33621A波形发生器

181/2461/8938产品概述&#xff1a; 与上一代DDS波形发生器相比&#xff0c;采用独家Trueform技术的安捷伦HP 33621A波形发生器具有更高的性能、保真度和灵活性。安捷伦HP 33621A 120 MHz、单通道、Trueform arbs&#xff0c;带时序控制和64 MSa存储器&#xff0c;1 ps抖动&am…

go juc 线程中的子类

1.go test() 主死随从 package mainimport ("fmt""strconv""time" )func test() {for i : 1; i < 10; i {fmt.Println("hello " strconv.Itoa(i))//阻塞time.Sleep(time.Second)} } func main() {//开启协程go test()for i : 1; …