经典神经网络(11)VQ-VAE模型及其在MNIST数据集上的应用
-
我们之前已经了解了PixelCNN模型。
经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用
-
今天,我们了解下DeepMind在2017年提出的一种基于离散隐变量(Discrete Latent variables)的生成模型:VQ-VAE。
-
VQ-VAE采用离散隐变量,而不是像VAE那样采用连续的隐变量。其实VQ-VAE本质上是一种AE,只能很好地完成图像压缩,把图像变成一个短得多的向量,而不支持随机图像生成。
-
那么,VQ-VAE会被归类到图像生成模型中呢?
这是因为VQ-VAE单独训练了一个基于自回归的模型如PixelCNN来学习先验(prior),对VQ-VAE的离散编码空间采样。而不是像VAE那样采用一个固定的先验(标准正态分布)。 -
此外,VQ-VAE还是一个强大的无监督表征学习模型,它学习的离散编码具有很强的表征能力:
- OpenAI在2021年发布的文本转图像模型DALL-E就是基于VQ-VAE。
- 另外,在BEiT中也用VQ-VAE得到的离散编码作为训练目标。
注:推荐下EleutherAI团队的lucidrains(Phil Wang)的github,他开源复现了ViT、AlphaFold 2、DALLE、 DALLE2、imagen等项目
https://github.com/lucidrains
1 VQ-VAE
1.1 从AE到VQ-VAE
-
AE是一类能够把图片压缩成较短的向量的神经网络模型。
- AE的编码器编码出来的向量空间是不规整的。也就是说,解码器只认识经编码器编出来的向量,而不认识其他的向量。
- 如下图,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。
-
只要AE的编码空间比较规整,符合某个简单的数学分布(比如最常见的标准正态分布,如下图所示),那我们就可以从这个分布里随机采样向量,再让解码器根据这个向量来完成随机图片生成了。
- VAE就是这样一种改进版的AE,它用一些巧妙的方法约束了编码向量z,使得z满足标准正态分布。
- 训练完成后,我们就可以扔掉编码器,用来自标准正态分布的随机向量和解码器来实现随机图像生成了。
-
VQ-VAE的作者认为,VAE的生成图片之所以质量不高,是因为图片被编码成了连续向量。而实际上,把图片编码成离散向量会更加自然。
-
至于离散编码的原因,作者解释如下:https://avdnoord.github.io/homepage/slides/SANE2017.pdf
1.2 VQVAE概述
把图像编码成离散向量后,会带来两个问题:
-
第一个问题是,神经网络会默认输入满足一个连续的分布,而不善于处理离散的输入。
- 如果你直接输入0, 1, 2这些数字,神经网络会默认1是一个处于0, 2中间的一种状态。为了解决这一问题,我们可以借鉴NLP中对于离散单词的处理方法。
- 我们可以把嵌入层加到VQ-VAE的解码器前,这个嵌入层就是
embedding space(嵌入空间)
,也称codebook
。 注意:其实Encoder编码出来的是二维离散编码,下图画的是一维。
-
另一个问题是离散向量不好采样。
- VAE之所以把图片编码成符合正态分布的连续向量,就是为了能在图像生成时把编码器扔掉,让随机采样出的向量也能通过解码器变成图片。现在,VQ-VAE把图片编码了一个离散向量,这个离散向量构成的空间是不好采样的。
- VQ-VAE的作者之前设计了一种图像生成网络,叫做PixelCNN。可以用PixelCNN生成离散编码,再利用VQ-VAE的解码器把离散编码变成图像。
-
VQ-VAE的架构图,如下图所示:
- 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成
latent image(下图zq)
,也能把latent image(下图zq)
变回图像。 - 训练PixelCNN,让它学习怎么生成
latent image(下图zq)
。 - 生成(采样)时,先用PixelCNN采样出
latent image(下图zq)
,再用VQ-VAE把latent image(下图zq)
翻译成最终的生成图像。
- 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成
1.3 VQ-VAE设计细节
1.3.1 关联编码器的输出与解码器的输入
如何关联编码器的输出与解码器的输入呢?
假设嵌入空间codebook已经训练完毕
,那么对于编码器的每个输出向量 z e ( x ) ze(x) ze(x),我们需要找出它在嵌入空间里的最近邻 z q ( x ) zq(x) zq(x),把 z e ( x ) ze(x) ze(x)替换成 z q ( x ) zq(x) zq(x)作为解码器的输入。- 方式是:求最近邻,即先计算向量与嵌入空间K个向量每个向量的距离,再对距离数组取一个
argmin
,求出最近的下标(如上图中的shape为[1,7,7]),最后用下标去嵌入空间里取向量,就得到了 z q zq zq(如上图中的shape为[1,32,7,7])。下标构成的多维数组,也正是VQ-VAE的离散编码。
1.3.2 梯度复制
- 我们现在能把编码器和解码器拼接到一起,但怎么让梯度从解码器的输入 z q ( x ) zq(x) zq(x)传到 z e ( x ) ze(x) ze(x)?从 z e ( x ) ze(x) ze(x)到 z q ( x ) zq(x) zq(x)的变换是一个从数组里取值,这个操作无法求导。
- VQ-VAE使用了一种叫做"straight-through estimator"的技术【即前向传播和反向传播的计算可以不对应】来完成梯度复制。VQ-VAE使用了一种叫做sg(stop gradient,停止梯度)的运算:
s g ( x ) = { x , 前向传播 0 , 反向传播 前向传播时, s g 里的值不变;反向传播时, s g 按值为 0 求导,即此次计算无梯度。 sg(x)=\begin{cases} x, & 前向传播\\ 0,& 反向传播 \end{cases}\\ 前向传播时,sg里的值不变;反向传播时,sg按值为0求导,即此次计算无梯度。 sg(x)={x,0,前向传播反向传播前向传播时,sg里的值不变;反向传播时,sg按值为0求导,即此次计算无梯度。
由于VQ-VAE其实是一个AE,误差函数里应该只有原图像和目标图像的重建误差:
L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z q ( x ) ) ∣ ∣ 2 2 L_{reconstruct}=||x-decoder(z_q(x))||_2^2 Lreconstruct=∣∣x−decoder(zq(x))∣∣22
我们现在利用sg运算,设计新的重建误差:
L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 前向传播时,就是拿解码器的输入 z q ( x ) 来算误差: L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + z q ( x ) − z e ( x ) ) ∣ ∣ 2 2 = ∣ ∣ x − d e c o d e r ( z q ( x ) ) ∣ ∣ 2 2 反向传播时,等价于把解码器的梯度全部传给 z e ( x ) : L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 = ∣ ∣ x − d e c o d e r ( z e ( x ) ) ∣ ∣ 2 2 L_{reconstruct}=||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2\\ 前向传播时,就是拿解码器的输入z_q(x)来算误差:\\ L_{reconstruct}=||x-decoder(z_e(x)+z_q(x)-z_e(x))||_2^2\\ =||x-decoder(z_q(x))||_2^2\\ 反向传播时,等价于把解码器的梯度全部传给z_e(x):\\ L_{reconstruct}=||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2\\ =||x-decoder(z_e(x))||_2^2 Lreconstruct=∣∣x−decoder(ze(x)+sg[zq(x)−ze(x)])∣∣22前向传播时,就是拿解码器的输入zq(x)来算误差:Lreconstruct=∣∣x−decoder(ze(x)+zq(x)−ze(x))∣∣22=∣∣x−decoder(zq(x))∣∣22反向传播时,等价于把解码器的梯度全部传给ze(x):Lreconstruct=∣∣x−decoder(ze(x)+sg[zq(x)−ze(x)])∣∣22=∣∣x−decoder(ze(x))∣∣22
在PyTorch里,(x).detach()
就是sg(x),它的值在前向传播时取x
,反向传播时取0
。
# stop gradient
decoder_input = ze + (zq - ze).detach()
# decode
x_hat = decoder(decoder_input)
# l_reconstruct
l_reconstruct = mse_loss(x, x_hat)
1.3.3 优化嵌入空间codebook
嵌入空间的优化目标是什么呢?嵌入空间的每一个向量应该能概括一类编码器输出的向量。因此,嵌入空间的向量应该和其对应编码器输出尽可能接近。
L e = ∣ ∣ z e ( x ) − z q ( x ) ∣ ∣ 2 2 z e ( x ) 是编码器的输出向量, z q ( x ) 是其在嵌入空间的最近邻向量 L_e=||z_e(x)-z_q(x)||_2^2\\ z_e(x)是编码器的输出向量,z_q(x)是其在嵌入空间的最近邻向量 Le=∣∣ze(x)−zq(x)∣∣22ze(x)是编码器的输出向量,zq(x)是其在嵌入空间的最近邻向量
作者认为,编码器和嵌入向量的学习速度应该不一样快。
于是,他们再次使用了停止梯度的技巧,把上面那个误差函数拆成了两部分。其中,β控制了编码器的相对学习速度。作者发现,算法对β的变化不敏感,β取0.1~2.0都差不多。
L e = ∣ ∣ s g [ z e ( x ) ] − z q ( x ) ∣ ∣ 2 2 + β ∣ ∣ z e ( x ) − s g [ z q ( x ) ] ∣ ∣ 2 2 L_e=||sg[z_e(x)]-z_q(x)||_2^2+\beta||z_e(x)-sg[z_q(x)]||_2^2\\ Le=∣∣sg[ze(x)]−zq(x)∣∣22+β∣∣ze(x)−sg[zq(x)]∣∣22
# vq loss
l_embedding = mse_loss(ze.detach(), zq)
# commitment loss
l_commitment = mse_loss(ze, zq.detach())
VQ-VAE总体的损失函数可以写成:
L t o t a l = L r e c o n s t r u c t + L e = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 + α ∣ ∣ s g [ z e ( x ) ] − z q ( x ) ∣ ∣ 2 2 + β ∣ ∣ z e ( x ) − s g [ z q ( x ) ] ∣ ∣ 2 2 L_{total}=L_{reconstruct} + L_e \\ =||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2 +\alpha||sg[z_e(x)]-z_q(x)||_2^2\\+\beta||z_e(x)-sg[z_q(x)]||_2^2 Ltotal=Lreconstruct+Le=∣∣x−decoder(ze(x)+sg[zq(x)−ze(x)])∣∣22+α∣∣sg[ze(x)]−zq(x)∣∣22+β∣∣ze(x)−sg[zq(x)]∣∣22
# reconstruct loss
l_reconstruct = mse_loss(x, x_hat)
# vq loss
l_embedding = mse_loss(ze.detach(), zq)
# commitment loss
l_commitment = mse_loss(ze, zq.detach())# total loss
loss = l_reconstruct + \l_w_embedding * l_embedding + l_w_commitment * l_commitment
1.3.4 先验模型PixelCNN
- 训练好VQ-VAE后,还需要训练一个先验模型来完成数据生成,论文中采用PixelCNN模型。
- 这里我们不再是学习生成原始的pixels,而是学习生成离散编码:
- 首先,我们需要用已经训练好的VQ-VAE模型对训练图像推理,得到每张图像对应的离散编码;
- 然后用一个PixelCNN来对离散编码进行建模
- 最后的预测层采用基于softmax的多分类,类别数为embedding空间的大小K。
- 那么,生成图像的过程就比较简单了,首先用训练好的PixelCNN模型来采样一个离散编码样本(上图中shape为[1, 32, 7, 7]),然后送入VQ-VAE的decoder中,得到生成的图像。
- 实际上,PixelCNN不是唯一可用的拟合离散分布的模型。
我们可以把它换成Transformer,甚至是diffusion模型。
2 VQ-VAE模型在MNIST数据集上的应用
这里使用的模型为Gated PixelCNN模型,具体可参考:
经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用
网络结构图如下所示:
2.1 VQ-VAE模型
VQVAE的编码器和解码器的结构很简单,仅由普通的上/下采样层和残差块组成。
- 编码器先是有两个3x3卷积+2倍下采样卷积的模块,再有两个残差块(ReLU, 3x3卷积, ReLU, 1x1卷积);
- 解码器则反过来,先有两个残差块,再有两个3x3卷积+2倍上采样反卷积的模块。
# Reference: https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/VQVAE
import os
import timeimport cv2
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fimport torchvision
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from GatedPixelCNNDemo import GatedPixelCNN, GatedBlockclass ResidualBlock(nn.Module):def __init__(self, dim):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(dim, dim, kernel_size=1)def forward(self, x):tmp = self.relu(x)tmp = self.conv1(tmp)tmp = self.relu(tmp)tmp = self.conv2(tmp)return x + tmpclass VQVAE(nn.Module):def __init__(self, input_dim, dim, n_embedding):super().__init__()# 1、编码器self.encoder = nn.Sequential(nn.Conv2d(input_dim, dim, kernel_size=4, stride=2, padding=1),nn.ReLU(),nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1),nn.ReLU(),nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),ResidualBlock(dim),ResidualBlock(dim))self.vq_embedding = nn.Embedding(n_embedding, dim)# 初始化为均匀分布self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding, 1.0 / n_embedding)# 2、解码器self.decoder = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1),ResidualBlock(dim),ResidualBlock(dim),nn.ConvTranspose2d(dim, dim, 4, 2, 1),nn.ReLU(),nn.ConvTranspose2d(dim, input_dim, 4, 2, 1))self.n_downsample = 2def forward(self, x):# encode [N, 1, 28, 28] -> [N, 32, 7, 7]ze = self.encoder(x)# ze: [N, C, H, W]# embedding [K, C] [32, 32]embedding = self.vq_embedding.weight.dataN, C, H, W = ze.shapeK, _ = embedding.shape# 求解最近邻embedding_broadcast = embedding.reshape(1, K, C, 1, 1)ze_broadcast = ze.reshape(N, 1, C, H, W)distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)nearest_neighbor = torch.argmin(distance, 1)# make C to the second dimzq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)# stop gradientdecoder_input = ze + (zq - ze).detach()# decodex_hat = self.decoder(decoder_input)return x_hat, ze, zq@torch.no_grad()def encode(self, x):ze = self.encoder(x)embedding = self.vq_embedding.weight.data# ze: [N, C, H, W]# embedding [K, C]N, C, H, W = ze.shapeK, _ = embedding.shapeembedding_broadcast = embedding.reshape(1, K, C, 1, 1)ze_broadcast = ze.reshape(N, 1, C, H, W)distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)nearest_neighbor = torch.argmin(distance, 1)return nearest_neighbor@torch.no_grad()def decode(self, discrete_latent):zq = self.vq_embedding(discrete_latent).permute(0, 3, 1, 2)x_hat = self.decoder(zq)return x_hat# Shape: [C, H, W]def get_latent_HW(self, input_shape):C, H, W = input_shapereturn (H // 2**self.n_downsample, W // 2**self.n_downsample)
2.2 先验模型
我们已经有了一个普通的PixelCNN模型GatedPixelCNN
- 需要在整个模型的最前面套一个嵌入层,嵌入层的嵌入个数等于离散编码的个数(
color_level
),嵌入长度等于模型的特征长度(p
)。 - 由于嵌入层会直接输出一个长度为
p
的向量,我们还需要把第一个模块的输入通道数改成p
。
# 继承自我们之前实现的模型GatedPixelCNN
class PixelCNNWithEmbedding(GatedPixelCNN):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__(n_blocks, p, linear_dim, bn, color_level)self.embedding = nn.Embedding(color_level, p)self.block1 = GatedBlock('A', p, p, bn)def forward(self, x):x = self.embedding(x)x = x.permute(0, 3, 1, 2).contiguous()return super().forward(x)
2.3 两种模型的训练
- 下面就是常规的训练代码
- 先训练VQVAE、再训练PixelCNN
def train_vqvae(model: VQVAE,img_shape=None,device='cuda',ckpt_path='./model.pth',batch_size=64,dataset_type='MNIST',lr=1e-3,n_epochs=100,l_w_embedding=1,l_w_commitment=0.25):print('batch size:', batch_size)dataloader = get_dataloader(dataset_type,batch_size,img_shape=img_shape)model.to(device)model.train()optimizer = torch.optim.Adam(model.parameters(), lr)mse_loss = nn.MSELoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x in dataloader:current_batch_size = x.shape[0]x = x.to(device)x_hat, ze, zq = model(x)# 1、reconstruct lossl_reconstruct = mse_loss(x, x_hat)# 2、vq loss + commitment lossl_embedding = mse_loss(ze.detach(), zq)l_commitment = mse_loss(ze, zq.detach())# total lossloss = l_reconstruct + \l_w_embedding * l_embedding + l_w_commitment * l_commitmentoptimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), ckpt_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')print('Done')def train_generative_model(vqvae: VQVAE,model,img_shape=None,device='cuda',ckpt_path='./gen_model.pth',dataset_type='MNIST',batch_size=64,n_epochs=50):print('batch size:', batch_size)dataloader = get_dataloader(dataset_type,batch_size,img_shape=img_shape)vqvae.to(device)vqvae.eval()model.to(device)model.train()optimizer = torch.optim.Adam(model.parameters(), 1e-3)# 交叉熵损失loss_fn = nn.CrossEntropyLoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x in dataloader:current_batch_size = x.shape[0]with torch.no_grad():x = x.to(device)# 1、训练好的VQ-VAE模型对训练图像推理,得到每张图像对应的离散编码x = vqvae.encode(x)# 2、用一个PixelCNN来对离散编码进行建模predict_x = model(x)# 3、预测层采用基于softmax的多分类loss = loss_fn(predict_x, x)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), ckpt_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')print('Done')def reconstruct(model, x, device, dataset_type='MNIST'):model.to(device)model.eval()with torch.no_grad():x_hat, _, _ = model(x)n = x.shape[0]n1 = int(n**0.5)x_cat = torch.concat((x, x_hat), 3)x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)cv2.imwrite(f'work_dirs/vqvae_reconstruct_{dataset_type}.jpg', x_cat)
class MNISTImageDataset(Dataset):def __init__(self, img_shape=(28, 28)):super().__init__()self.img_shape = img_shapeself.mnist = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist')def __len__(self):return len(self.mnist)def __getitem__(self, index: int):img = self.mnist[index][0]pipeline = transforms.Compose([transforms.Resize(self.img_shape),transforms.ToTensor()])return pipeline(img)def get_dataloader(type,batch_size,img_shape=None,dist_train=False,num_workers=0,**kwargs):if type == 'MNIST':if img_shape is not None:dataset = MNISTImageDataset(img_shape)else:dataset = MNISTImageDataset()if dist_train:sampler = DistributedSampler(dataset)dataloader = DataLoader(dataset,batch_size=batch_size,sampler=sampler,num_workers=num_workers)return dataloader, samplerelse:dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers)return dataloadercfg = dict(dataset_type='MNIST',img_shape=(1, 28, 28),dim=32,n_embedding=32,batch_size=32,n_epochs=20,l_w_embedding=1,l_w_commitment=0.25,lr=2e-4,n_epochs_2=50,batch_size_2=32,pixelcnn_n_blocks=15,pixelcnn_dim=128,pixelcnn_linear_dim=32,vqvae_path='./model_mnist.pth',gen_model_path='./gen_model_mnist.pth')if __name__ == '__main__':os.makedirs('work_dirs', exist_ok=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'img_shape = cfg['img_shape']# 初始化模型vqvae = VQVAE(img_shape[0], cfg['dim'], cfg['n_embedding'])gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],cfg['pixelcnn_dim'],cfg['pixelcnn_linear_dim'], True,cfg['n_embedding'])# 1. Train VQVAEtrain_vqvae(vqvae,img_shape=(img_shape[1], img_shape[2]),device=device,ckpt_path=cfg['vqvae_path'],batch_size=cfg['batch_size'],dataset_type=cfg['dataset_type'],lr=cfg['lr'],n_epochs=cfg['n_epochs'],l_w_embedding=cfg['l_w_embedding'],l_w_commitment=cfg['l_w_commitment'])# 2. Test VQVAE by visualizaing reconstruction resultvqvae.load_state_dict(torch.load(cfg['vqvae_path']))dataloader = get_dataloader(cfg['dataset_type'],16,img_shape=(img_shape[1], img_shape[2]))img = next(iter(dataloader)).to(device)reconstruct(vqvae, img, device, cfg['dataset_type'])# 3. Train Generative model (Gated PixelCNN)vqvae.load_state_dict(torch.load(cfg['vqvae_path']))train_generative_model(vqvae,gen_model,img_shape=(img_shape[1], img_shape[2]),device=device,ckpt_path=cfg['gen_model_path'],dataset_type=cfg['dataset_type'],batch_size=cfg['batch_size_2'],n_epochs=cfg['n_epochs_2'])# 4. Sample VQVAEvqvae.load_state_dict(torch.load(cfg['vqvae_path']))gen_model.load_state_dict(torch.load(cfg['gen_model_path']))sample_imgs(vqvae,gen_model,cfg['img_shape'],device=device,n_sample=1,dataset_type=cfg['dataset_type'])
2.4 图像生成(采样)
def sample_imgs(vqvae: VQVAE,gen_model,img_shape,n_sample=81,device='cuda',dataset_type='MNIST'):vqvae = vqvae.to(device)vqvae.eval()gen_model = gen_model.to(device)gen_model.eval()C, H, W = img_shapeH, W = vqvae.get_latent_HW((C, H, W))input_shape = (n_sample, H, W)# 初始化为0x = torch.zeros(input_shape).to(device).to(torch.long)with torch.no_grad():# 逐像素预测for i in range(H):for j in range(W):output = gen_model(x)prob_dist = F.softmax(output[:, :, i, j], -1)# 从概率分布中采样pixel = torch.multinomial(prob_dist, 1)x[:, i, j] = pixel[:, 0]# 解码imgs = vqvae.decode(x)imgs = imgs * 255imgs = imgs.clip(0, 255)imgs = einops.rearrange(imgs,'(n1 n2) c h w -> (n1 h) (n2 w) c',n1=int(n_sample**0.5))imgs = imgs.detach().cpu().numpy().astype(np.uint8)cv2.imwrite(f'work_dirs/vqvae_sample_{dataset_type}.jpg', imgs)