经典神经网络(11)VQ-VAE模型及其在MNIST数据集上的应用

经典神经网络(11)VQ-VAE模型及其在MNIST数据集上的应用

  • 我们之前已经了解了PixelCNN模型。

    经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

  • 今天,我们了解下DeepMind在2017年提出的一种基于离散隐变量(Discrete Latent variables)的生成模型:VQ-VAE。

  • VQ-VAE采用离散隐变量,而不是像VAE那样采用连续的隐变量。其实VQ-VAE本质上是一种AE,只能很好地完成图像压缩,把图像变成一个短得多的向量,而不支持随机图像生成。

  • 那么,VQ-VAE会被归类到图像生成模型中呢?这是因为VQ-VAE单独训练了一个基于自回归的模型如PixelCNN来学习先验(prior),对VQ-VAE的离散编码空间采样。而不是像VAE那样采用一个固定的先验(标准正态分布)。

  • 此外,VQ-VAE还是一个强大的无监督表征学习模型,它学习的离散编码具有很强的表征能力:

    • OpenAI在2021年发布的文本转图像模型DALL-E就是基于VQ-VAE。
    • 另外,在BEiT中也用VQ-VAE得到的离散编码作为训练目标。
    • 注:推荐下EleutherAI团队的lucidrains(Phil Wang)的github,他开源复现了ViT、AlphaFold 2、DALLE、 DALLE2、imagen等项目
      https://github.com/lucidrains

1 VQ-VAE

1.1 从AE到VQ-VAE

  • AE是一类能够把图片压缩成较短的向量的神经网络模型。

    • AE的编码器编码出来的向量空间是不规整的。也就是说,解码器只认识经编码器编出来的向量,而不认识其他的向量。
    • 如下图,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。
      在这里插入图片描述
  • 只要AE的编码空间比较规整,符合某个简单的数学分布(比如最常见的标准正态分布,如下图所示),那我们就可以从这个分布里随机采样向量,再让解码器根据这个向量来完成随机图片生成了。

    • VAE就是这样一种改进版的AE,它用一些巧妙的方法约束了编码向量z,使得z满足标准正态分布。
    • 训练完成后,我们就可以扔掉编码器,用来自标准正态分布的随机向量和解码器来实现随机图像生成了。

    在这里插入图片描述

  • VQ-VAE的作者认为,VAE的生成图片之所以质量不高,是因为图片被编码成了连续向量。而实际上,把图片编码成离散向量会更加自然。

  • 至于离散编码的原因,作者解释如下:https://avdnoord.github.io/homepage/slides/SANE2017.pdf

    在这里插入图片描述

1.2 VQVAE概述

把图像编码成离散向量后,会带来两个问题:

  • 第一个问题是,神经网络会默认输入满足一个连续的分布,而不善于处理离散的输入。

    • 如果你直接输入0, 1, 2这些数字,神经网络会默认1是一个处于0, 2中间的一种状态。为了解决这一问题,我们可以借鉴NLP中对于离散单词的处理方法。
    • 我们可以把嵌入层加到VQ-VAE的解码器前,这个嵌入层就是embedding space(嵌入空间),也称codebook
    • 注意:其实Encoder编码出来的是二维离散编码,下图画的是一维。

    在这里插入图片描述

  • 另一个问题是离散向量不好采样。

    • VAE之所以把图片编码成符合正态分布的连续向量,就是为了能在图像生成时把编码器扔掉,让随机采样出的向量也能通过解码器变成图片。现在,VQ-VAE把图片编码了一个离散向量,这个离散向量构成的空间是不好采样的。
    • VQ-VAE的作者之前设计了一种图像生成网络,叫做PixelCNN。可以用PixelCNN生成离散编码,再利用VQ-VAE的解码器把离散编码变成图像。
  • VQ-VAE的架构图,如下图所示:

    • 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成latent image(下图zq),也能把latent image(下图zq)变回图像。
    • 训练PixelCNN,让它学习怎么生成latent image(下图zq)
    • 生成(采样)时,先用PixelCNN采样出latent image(下图zq),再用VQ-VAE把latent image(下图zq)翻译成最终的生成图像。

在这里插入图片描述

1.3 VQ-VAE设计细节

1.3.1 关联编码器的输出与解码器的输入

如何关联编码器的输出与解码器的输入呢?

  • 假设嵌入空间codebook已经训练完毕,那么对于编码器的每个输出向量 z e ( x ) ze(x) ze(x),我们需要找出它在嵌入空间里的最近邻 z q ( x ) zq(x) zq(x),把 z e ( x ) ze(x) ze(x)替换成 z q ( x ) zq(x) zq(x)作为解码器的输入。
  • 方式是:求最近邻,即先计算向量与嵌入空间K个向量每个向量的距离,再对距离数组取一个argmin,求出最近的下标(如上图中的shape为[1,7,7]),最后用下标去嵌入空间里取向量,就得到了 z q zq zq(如上图中的shape为[1,32,7,7])。下标构成的多维数组,也正是VQ-VAE的离散编码。

1.3.2 梯度复制

  • 我们现在能把编码器和解码器拼接到一起,但怎么让梯度从解码器的输入 z q ( x ) zq(x) zq(x)传到 z e ( x ) ze(x) ze(x)?从 z e ( x ) ze(x) ze(x) z q ( x ) zq(x) zq(x)的变换是一个从数组里取值,这个操作无法求导。
  • VQ-VAE使用了一种叫做"straight-through estimator"的技术【即前向传播和反向传播的计算可以不对应】来完成梯度复制。VQ-VAE使用了一种叫做sg(stop gradient,停止梯度)的运算:

s g ( x ) = { x , 前向传播 0 , 反向传播 前向传播时, s g 里的值不变;反向传播时, s g 按值为 0 求导,即此次计算无梯度。 sg(x)=\begin{cases} x, & 前向传播\\ 0,& 反向传播 \end{cases}\\ 前向传播时,sg里的值不变;反向传播时,sg按值为0求导,即此次计算无梯度。 sg(x)={x,0,前向传播反向传播前向传播时,sg里的值不变;反向传播时,sg按值为0求导,即此次计算无梯度。

由于VQ-VAE其实是一个AE,误差函数里应该只有原图像和目标图像的重建误差:
L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z q ( x ) ) ∣ ∣ 2 2 L_{reconstruct}=||x-decoder(z_q(x))||_2^2 Lreconstruct=∣∣xdecoder(zq(x))22
我们现在利用sg运算,设计新的重建误差:
L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 前向传播时,就是拿解码器的输入 z q ( x ) 来算误差: L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + z q ( x ) − z e ( x ) ) ∣ ∣ 2 2 = ∣ ∣ x − d e c o d e r ( z q ( x ) ) ∣ ∣ 2 2 反向传播时,等价于把解码器的梯度全部传给 z e ( x ) : L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 = ∣ ∣ x − d e c o d e r ( z e ( x ) ) ∣ ∣ 2 2 L_{reconstruct}=||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2\\ 前向传播时,就是拿解码器的输入z_q(x)来算误差:\\ L_{reconstruct}=||x-decoder(z_e(x)+z_q(x)-z_e(x))||_2^2\\ =||x-decoder(z_q(x))||_2^2\\ 反向传播时,等价于把解码器的梯度全部传给z_e(x):\\ L_{reconstruct}=||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2\\ =||x-decoder(z_e(x))||_2^2 Lreconstruct=∣∣xdecoder(ze(x)+sg[zq(x)ze(x)])22前向传播时,就是拿解码器的输入zq(x)来算误差:Lreconstruct=∣∣xdecoder(ze(x)+zq(x)ze(x))22=∣∣xdecoder(zq(x))22反向传播时,等价于把解码器的梯度全部传给ze(x)Lreconstruct=∣∣xdecoder(ze(x)+sg[zq(x)ze(x)])22=∣∣xdecoder(ze(x))22
在PyTorch里,(x).detach()就是sg(x),它的值在前向传播时取x,反向传播时取0

# stop gradient
decoder_input = ze + (zq - ze).detach()
# decode
x_hat = decoder(decoder_input)
# l_reconstruct
l_reconstruct = mse_loss(x, x_hat)

1.3.3 优化嵌入空间codebook

嵌入空间的优化目标是什么呢?嵌入空间的每一个向量应该能概括一类编码器输出的向量。因此,嵌入空间的向量应该和其对应编码器输出尽可能接近。
L e = ∣ ∣ z e ( x ) − z q ( x ) ∣ ∣ 2 2 z e ( x ) 是编码器的输出向量, z q ( x ) 是其在嵌入空间的最近邻向量 L_e=||z_e(x)-z_q(x)||_2^2\\ z_e(x)是编码器的输出向量,z_q(x)是其在嵌入空间的最近邻向量 Le=∣∣ze(x)zq(x)22ze(x)是编码器的输出向量,zq(x)是其在嵌入空间的最近邻向量
作者认为,编码器和嵌入向量的学习速度应该不一样快。

于是,他们再次使用了停止梯度的技巧,把上面那个误差函数拆成了两部分。其中,β控制了编码器的相对学习速度。作者发现,算法对β的变化不敏感,β取0.1~2.0都差不多。
L e = ∣ ∣ s g [ z e ( x ) ] − z q ( x ) ∣ ∣ 2 2 + β ∣ ∣ z e ( x ) − s g [ z q ( x ) ] ∣ ∣ 2 2 L_e=||sg[z_e(x)]-z_q(x)||_2^2+\beta||z_e(x)-sg[z_q(x)]||_2^2\\ Le=∣∣sg[ze(x)]zq(x)22+β∣∣ze(x)sg[zq(x)]22

# vq loss
l_embedding = mse_loss(ze.detach(), zq)
# commitment loss
l_commitment = mse_loss(ze, zq.detach())

VQ-VAE总体的损失函数可以写成:
L t o t a l = L r e c o n s t r u c t + L e = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 + α ∣ ∣ s g [ z e ( x ) ] − z q ( x ) ∣ ∣ 2 2 + β ∣ ∣ z e ( x ) − s g [ z q ( x ) ] ∣ ∣ 2 2 L_{total}=L_{reconstruct} + L_e \\ =||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2 +\alpha||sg[z_e(x)]-z_q(x)||_2^2\\+\beta||z_e(x)-sg[z_q(x)]||_2^2 Ltotal=Lreconstruct+Le=∣∣xdecoder(ze(x)+sg[zq(x)ze(x)])22+α∣∣sg[ze(x)]zq(x)22+β∣∣ze(x)sg[zq(x)]22

# reconstruct loss
l_reconstruct = mse_loss(x, x_hat)
# vq loss
l_embedding = mse_loss(ze.detach(), zq)
# commitment loss
l_commitment = mse_loss(ze, zq.detach())# total loss
loss = l_reconstruct + \l_w_embedding * l_embedding + l_w_commitment * l_commitment

1.3.4 先验模型PixelCNN

  • 训练好VQ-VAE后,还需要训练一个先验模型来完成数据生成,论文中采用PixelCNN模型。
  • 这里我们不再是学习生成原始的pixels,而是学习生成离散编码:
    • 首先,我们需要用已经训练好的VQ-VAE模型对训练图像推理,得到每张图像对应的离散编码;
    • 然后用一个PixelCNN来对离散编码进行建模
    • 最后的预测层采用基于softmax的多分类,类别数为embedding空间的大小K。
  • 那么,生成图像的过程就比较简单了,首先用训练好的PixelCNN模型来采样一个离散编码样本(上图中shape为[1, 32, 7, 7]),然后送入VQ-VAE的decoder中,得到生成的图像。
  • 实际上,PixelCNN不是唯一可用的拟合离散分布的模型。我们可以把它换成Transformer,甚至是diffusion模型。

2 VQ-VAE模型在MNIST数据集上的应用

这里使用的模型为Gated PixelCNN模型,具体可参考:

经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

网络结构图如下所示:

在这里插入图片描述

2.1 VQ-VAE模型

VQVAE的编码器和解码器的结构很简单,仅由普通的上/下采样层和残差块组成。

  • 编码器先是有两个3x3卷积+2倍下采样卷积的模块,再有两个残差块(ReLU, 3x3卷积, ReLU, 1x1卷积);
  • 解码器则反过来,先有两个残差块,再有两个3x3卷积+2倍上采样反卷积的模块。
# Reference: https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/VQVAE
import os
import timeimport cv2
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fimport torchvision
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from GatedPixelCNNDemo import GatedPixelCNN, GatedBlockclass ResidualBlock(nn.Module):def __init__(self, dim):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(dim, dim, kernel_size=1)def forward(self, x):tmp = self.relu(x)tmp = self.conv1(tmp)tmp = self.relu(tmp)tmp = self.conv2(tmp)return x + tmpclass VQVAE(nn.Module):def __init__(self, input_dim, dim, n_embedding):super().__init__()# 1、编码器self.encoder = nn.Sequential(nn.Conv2d(input_dim, dim, kernel_size=4, stride=2, padding=1),nn.ReLU(),nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1),nn.ReLU(),nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),ResidualBlock(dim),ResidualBlock(dim))self.vq_embedding = nn.Embedding(n_embedding, dim)# 初始化为均匀分布self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding, 1.0 / n_embedding)# 2、解码器self.decoder = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1),ResidualBlock(dim),ResidualBlock(dim),nn.ConvTranspose2d(dim, dim, 4, 2, 1),nn.ReLU(),nn.ConvTranspose2d(dim, input_dim, 4, 2, 1))self.n_downsample = 2def forward(self, x):# encode [N, 1, 28, 28] -> [N, 32, 7, 7]ze = self.encoder(x)# ze: [N, C, H, W]# embedding [K, C]  [32, 32]embedding = self.vq_embedding.weight.dataN, C, H, W = ze.shapeK, _ = embedding.shape# 求解最近邻embedding_broadcast = embedding.reshape(1, K, C, 1, 1)ze_broadcast = ze.reshape(N, 1, C, H, W)distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)nearest_neighbor = torch.argmin(distance, 1)# make C to the second dimzq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)# stop gradientdecoder_input = ze + (zq - ze).detach()# decodex_hat = self.decoder(decoder_input)return x_hat, ze, zq@torch.no_grad()def encode(self, x):ze = self.encoder(x)embedding = self.vq_embedding.weight.data# ze: [N, C, H, W]# embedding [K, C]N, C, H, W = ze.shapeK, _ = embedding.shapeembedding_broadcast = embedding.reshape(1, K, C, 1, 1)ze_broadcast = ze.reshape(N, 1, C, H, W)distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)nearest_neighbor = torch.argmin(distance, 1)return nearest_neighbor@torch.no_grad()def decode(self, discrete_latent):zq = self.vq_embedding(discrete_latent).permute(0, 3, 1, 2)x_hat = self.decoder(zq)return x_hat# Shape: [C, H, W]def get_latent_HW(self, input_shape):C, H, W = input_shapereturn (H // 2**self.n_downsample, W // 2**self.n_downsample)

2.2 先验模型

我们已经有了一个普通的PixelCNN模型GatedPixelCNN

  • 需要在整个模型的最前面套一个嵌入层,嵌入层的嵌入个数等于离散编码的个数(color_level),嵌入长度等于模型的特征长度(p)。
  • 由于嵌入层会直接输出一个长度为p的向量,我们还需要把第一个模块的输入通道数改成p
# 继承自我们之前实现的模型GatedPixelCNN
class PixelCNNWithEmbedding(GatedPixelCNN):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__(n_blocks, p, linear_dim, bn, color_level)self.embedding = nn.Embedding(color_level, p)self.block1 = GatedBlock('A', p, p, bn)def forward(self, x):x = self.embedding(x)x = x.permute(0, 3, 1, 2).contiguous()return super().forward(x)

2.3 两种模型的训练

  • 下面就是常规的训练代码
  • 先训练VQVAE、再训练PixelCNN
def train_vqvae(model: VQVAE,img_shape=None,device='cuda',ckpt_path='./model.pth',batch_size=64,dataset_type='MNIST',lr=1e-3,n_epochs=100,l_w_embedding=1,l_w_commitment=0.25):print('batch size:', batch_size)dataloader = get_dataloader(dataset_type,batch_size,img_shape=img_shape)model.to(device)model.train()optimizer = torch.optim.Adam(model.parameters(), lr)mse_loss = nn.MSELoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x in dataloader:current_batch_size = x.shape[0]x = x.to(device)x_hat, ze, zq = model(x)# 1、reconstruct lossl_reconstruct = mse_loss(x, x_hat)# 2、vq loss + commitment lossl_embedding = mse_loss(ze.detach(), zq)l_commitment = mse_loss(ze, zq.detach())# total lossloss = l_reconstruct + \l_w_embedding * l_embedding + l_w_commitment * l_commitmentoptimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), ckpt_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')print('Done')def train_generative_model(vqvae: VQVAE,model,img_shape=None,device='cuda',ckpt_path='./gen_model.pth',dataset_type='MNIST',batch_size=64,n_epochs=50):print('batch size:', batch_size)dataloader = get_dataloader(dataset_type,batch_size,img_shape=img_shape)vqvae.to(device)vqvae.eval()model.to(device)model.train()optimizer = torch.optim.Adam(model.parameters(), 1e-3)# 交叉熵损失loss_fn = nn.CrossEntropyLoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x in dataloader:current_batch_size = x.shape[0]with torch.no_grad():x = x.to(device)# 1、训练好的VQ-VAE模型对训练图像推理,得到每张图像对应的离散编码x = vqvae.encode(x)# 2、用一个PixelCNN来对离散编码进行建模predict_x = model(x)# 3、预测层采用基于softmax的多分类loss = loss_fn(predict_x, x)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), ckpt_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')print('Done')def reconstruct(model, x, device, dataset_type='MNIST'):model.to(device)model.eval()with torch.no_grad():x_hat, _, _ = model(x)n = x.shape[0]n1 = int(n**0.5)x_cat = torch.concat((x, x_hat), 3)x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)cv2.imwrite(f'work_dirs/vqvae_reconstruct_{dataset_type}.jpg', x_cat)
class MNISTImageDataset(Dataset):def __init__(self, img_shape=(28, 28)):super().__init__()self.img_shape = img_shapeself.mnist = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist')def __len__(self):return len(self.mnist)def __getitem__(self, index: int):img = self.mnist[index][0]pipeline = transforms.Compose([transforms.Resize(self.img_shape),transforms.ToTensor()])return pipeline(img)def get_dataloader(type,batch_size,img_shape=None,dist_train=False,num_workers=0,**kwargs):if type == 'MNIST':if img_shape is not None:dataset = MNISTImageDataset(img_shape)else:dataset = MNISTImageDataset()if dist_train:sampler = DistributedSampler(dataset)dataloader = DataLoader(dataset,batch_size=batch_size,sampler=sampler,num_workers=num_workers)return dataloader, samplerelse:dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers)return dataloadercfg = dict(dataset_type='MNIST',img_shape=(1, 28, 28),dim=32,n_embedding=32,batch_size=32,n_epochs=20,l_w_embedding=1,l_w_commitment=0.25,lr=2e-4,n_epochs_2=50,batch_size_2=32,pixelcnn_n_blocks=15,pixelcnn_dim=128,pixelcnn_linear_dim=32,vqvae_path='./model_mnist.pth',gen_model_path='./gen_model_mnist.pth')if __name__ == '__main__':os.makedirs('work_dirs', exist_ok=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'img_shape = cfg['img_shape']# 初始化模型vqvae = VQVAE(img_shape[0], cfg['dim'], cfg['n_embedding'])gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],cfg['pixelcnn_dim'],cfg['pixelcnn_linear_dim'], True,cfg['n_embedding'])# 1. Train VQVAEtrain_vqvae(vqvae,img_shape=(img_shape[1], img_shape[2]),device=device,ckpt_path=cfg['vqvae_path'],batch_size=cfg['batch_size'],dataset_type=cfg['dataset_type'],lr=cfg['lr'],n_epochs=cfg['n_epochs'],l_w_embedding=cfg['l_w_embedding'],l_w_commitment=cfg['l_w_commitment'])# 2. Test VQVAE by visualizaing reconstruction resultvqvae.load_state_dict(torch.load(cfg['vqvae_path']))dataloader = get_dataloader(cfg['dataset_type'],16,img_shape=(img_shape[1], img_shape[2]))img = next(iter(dataloader)).to(device)reconstruct(vqvae, img, device, cfg['dataset_type'])# 3. Train Generative model (Gated PixelCNN)vqvae.load_state_dict(torch.load(cfg['vqvae_path']))train_generative_model(vqvae,gen_model,img_shape=(img_shape[1], img_shape[2]),device=device,ckpt_path=cfg['gen_model_path'],dataset_type=cfg['dataset_type'],batch_size=cfg['batch_size_2'],n_epochs=cfg['n_epochs_2'])# 4. Sample VQVAEvqvae.load_state_dict(torch.load(cfg['vqvae_path']))gen_model.load_state_dict(torch.load(cfg['gen_model_path']))sample_imgs(vqvae,gen_model,cfg['img_shape'],device=device,n_sample=1,dataset_type=cfg['dataset_type'])

2.4 图像生成(采样)

def sample_imgs(vqvae: VQVAE,gen_model,img_shape,n_sample=81,device='cuda',dataset_type='MNIST'):vqvae = vqvae.to(device)vqvae.eval()gen_model = gen_model.to(device)gen_model.eval()C, H, W = img_shapeH, W = vqvae.get_latent_HW((C, H, W))input_shape = (n_sample, H, W)# 初始化为0x = torch.zeros(input_shape).to(device).to(torch.long)with torch.no_grad():# 逐像素预测for i in range(H):for j in range(W):output = gen_model(x)prob_dist = F.softmax(output[:, :, i, j], -1)# 从概率分布中采样pixel = torch.multinomial(prob_dist, 1)x[:, i, j] = pixel[:, 0]# 解码imgs = vqvae.decode(x)imgs = imgs * 255imgs = imgs.clip(0, 255)imgs = einops.rearrange(imgs,'(n1 n2) c h w -> (n1 h) (n2 w) c',n1=int(n_sample**0.5))imgs = imgs.detach().cpu().numpy().astype(np.uint8)cv2.imwrite(f'work_dirs/vqvae_sample_{dataset_type}.jpg', imgs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/854896.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

celery骚操作之把任务写在类中可能造成的问题

celery注册异步函数是模块级别的,也就是同个模块不能有同名函数,比如搞个骚操作,将celery任务写在类中如下(注意这个静态方法是个特殊的装饰器,他实际是个描述器,他必须写在最上面) 实际注册的任务是apps.business.tas…

3D视觉引导机器人提升生产线的自动化水平和智能化程度

随着智能化技术的不断发展,汽车制造企业正积极寻求提升智能化水平的途径。富唯智能的3D视觉引导机器人抓取技术为汽车制造企业提供了一种高效、智能的自动化解决方案。 项目目标 某汽车制造企业希望通过引入智能化技术提升生产线的自动化水平和智能化程度。他们希望…

python错题(3)

round四舍五入 title()把单词首字母大写 all() 函数用于判断给定的可迭代参数 iterable 中的所有元素是否都为 TRUE,如果是返回 True,否则返回 False。 元素除了是 0、空、None、False 外都算 True 。空元组、空列表返回值为True,这里要特…

2023数A题——WLAN网络信道接入机制建模

A题——WLAN网络信道接入机制建模 思路:该题主要考察的WLAN下退避机制建模仿真。 资料获取 问题1: 假设AP发送包的载荷长度为1500Bytes(1Bytes 8bits),PHY头时长为13.6μs,MAC头为30Bytes,MA…

是否可以外链代发?

当然是可以的,代发外链是一种有效的提升网站SEO排名和流量的方法。通过在高质量的网站上发布包含你网站链接的内容,可以提高你网站的权重和可信度。而在所有代发外链的方式中,GPB外链无疑是最好的选择。 GPB外链,每一条GPB外链都是…

【UE4】角色御剑飞行的蓝图实现

沉沉更鼓急,渐渐人声绝 吹灯窗更明,月照一天雪 UE4简单的实现御剑飞行的功能 契子✨ 所谓的御剑飞行的原理就跟 《御板》 飞行的原理差不多,不过是在人物脚上插把剑在飞行的时候显示出来罢了。简单来讲就是只要渲染做的足够牛,土鸡…

App上架和推广前的准备

众所周知,App推广的第一步是上架各大应用下载市场,然后才是其他推广渠道。所以本文主要分两部分,第一部分主要介绍的是上架各大应用市场方面的准备,第二部分主要介绍的是其他渠道推广方面的准备。 一、App上架前的准备 1.1 上架…

Servlet基础(续集2)

HttpServletResponse web服务器接收到客户端的http的请求,针对这个请求,分别创建一个代表请求的HttpServletRequest对象,代表响应的一个HttpServletResponse 如果要获取客户端请求过来的参数:找HttpServletRequest如果要给客户端…

【前端面试】二叉树递归模板和题解

递归模板和步骤 递归题目的通用步骤递归模板总结1. 树的遍历(DFS)2. 二叉树的最大深度3. 二叉树的最近公共祖先 递归题目的记忆技巧 递归题目的通用步骤 明确递归函数的功能:确定递归函数的输入参数和返回值,明确函数的功能。基准…

如何在本地部署ChatTTS? 完美部署 简单几步 cpu gpu cuda

前言 最近,24-05-27号,github上出现了一个新项目,ChatTTS。该项目提供了一个文本转语音(Text To Speech)的开源方案,同时支持中文和英文。在官网的演示视频中,可以看到合成效果高度接近真人。 到目前(06-04)为止,已经有18.3k的star。 那我们就来看看这个模型的基本…

63、上海大学:MSConvNet-多尺度卷积神经网络解码大鼠运动疲劳数据[攒劲的模型来喽]

1、介绍&#xff1a; 文章&#xff1a;<A multiscale convolutional neural network based on time-frequency features for decoding rat exercise fatigue LFP >&#xff0c;本文由上海大学于2024.4.8日发表于<Biomedical Signal Processing and Control >&…

语音翻译软件app排名来啦,这些工具让旅游畅通无阻

#这个夏天我们一定要去看海# 出国旅行时&#xff0c;语言障碍常常是最让人头疼的问题之一。 特别是在像缅甸这样英语并不普及的国家&#xff0c;基本的日常交流&#xff0c;比如用餐或问路&#xff0c;都可能成为难题。 然而&#xff0c;随着技术的进步&#xff0c;现在有了…

全功能知识付费小程序源码系统 界面支持万能DIY装修 带完整的安装代码包以及搭建部署教程

系统概述 在当今数字化时代&#xff0c;知识付费已经成为一种重要的商业模式。为了满足市场对于便捷、高效、个性化的知识付费解决方案的需求&#xff0c;小编给大家分享一款全功能知识付费小程序源码系统。这一系统不仅具备界面支持万能 DIY 装修的独特优势&#xff0c;还配备…

游戏开发丨基于PyGame的消消乐小游戏

文章目录 写在前面PyGame消消乐注意事项系列文章写在后面 写在前面 本期内容&#xff1a;基于pygame实现喜羊羊与灰太狼版消消乐小游戏 下载地址&#xff1a;https://download.csdn.net/download/m0_68111267/88700193 实验环境 python3.11及以上pycharmpygame 安装pygame…

Cocos2dlua棋牌Lua解密

点击上方↑↑↑蓝字[协议分析与还原]关注我们 “ 介绍使用libcocos2dlua.so库的游戏的解密分析方法。” Cocos2dlua是一款流行的游戏引擎&#xff0c;常用于开发棋牌游戏。为了保护游戏代码&#xff0c;Cocos2dlua通常会对游戏脚本lua文件进行加密&#xff0c;生成Luac文件&…

电脑已删除的文件在回收站找不到怎么办?数据恢复办法分享!

电脑中的数据已经成为了我们生活和工作的重要部分。无论是珍贵的照片、重要的文档&#xff0c;还是日常的工作文件&#xff0c;我们都希望能够妥善保存很久。 然而&#xff0c;误删除文件的情况时有发生&#xff0c;而当我们急切地打开回收站试图找回这些文件时&#xff0c;却…

这些已经死去的软件,依旧无可替代

互联网这条长河里&#xff0c;软件们就像流星一样&#xff0c;一闪而过。有的软件火过一段时间&#xff0c;然后就慢慢消失了。 说不定有些软件你以前天天用&#xff0c;但不知道从什么时候开始就不再用了。时间一天天过去&#xff0c;我们的热情、记忆都在消退&#xff0c;还…

[巨详细]使用HBuilder-X新建uniapp项目教程

文章目录 安装HBuilder-X启动uniapp项目其他&#xff1a;下载预览浏览器下载终端插件 安装HBuilder-X 详细步骤可看上文》》 启动uniapp项目 先打开HBuilder-X 点击新建项目 选择uniapp侧边栏&#xff0c;mian中的点击浏览 选择已经安装到本地的uniapp项目&#xff0c;并输…

数据分析中的数学:从基础到应用20240617

数据分析中的数学&#xff1a;从基础到应用 数据分析离不开数学的支持&#xff0c;统计学和概率论是其重要组成部分。本文将通过几个具体的实例&#xff0c;详细讲解数据分析中常用的数学知识&#xff0c;并通过Python代码演示如何应用这些知识。 1. 描述性统计 基本概念和用…

运营一个商城网站需要办理什么许可证?

搭建一个商城网站以下资质是必须要办理的&#xff1a;网站ICP备案以及增值电信业务经营许可证。 一、网站ICP备案 国家对提供互联网信息服务的ICP实行许可证制度。从而&#xff0c;ICP证成为网络经营的许可证&#xff0c;经营性网站必须办理ICP证&#xff0c;否则就属于非法经营…