经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

1 PixelCNN

  • PixelCNN是DeepMind团队在论文Pixel Recurrent Neural Networks (16.01)提出的一种生成模型,实际上这篇论文共提出了两种架构:PixelRNNPixelCNN,两者的主要区别是前者用LSTM来建模,而PixelCNN是基于CNN的,相比RNN,CNN计算更高效,我们这里只讨论PixelCNN。

  • PixelCNN借用了NLP里的方法来生成图像。对于自然图像,每个像素值的取值范围为0~255,共256个离散值。PixelCNN模型会根据前i - 1个像素输出第i个像素的概率分布。

  • 训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数

  • 采样时(图像生成时),会根据前i - 1个像素直接从预测的概率分布(多项分布)里采样出第i个像素。

1.1 单通道PixelCNN

1.1.1 掩码卷积

我们现在知道了PixelCNN的大体思路,就是根据前i - 1个像素输出第i个像素的概率分布。我们现在只考虑单通道图像,每个像素的颜色取值只有256种,那么很容易想到下面的实现方式:

在这里插入图片描述

但是只输出一个像素的概率分布,这样训练效率太低了。

  • 在训练时,我们可以输入一幅图像,同时让模型输出图像每一点像素的概率分布(如下图所示),这样就能通过每个像素的真值和模型预测的概率分布求交叉熵损失函数,进行并行训练。
  • 我们能这么做的原因是:在训练时,整幅训练图像是已知的,因此我们可以在一次前向传播后得到图像每一处的概率分布。
  • 当然,我们需要找到每个像素都忽略后续像素的信息的方法,即论文中提出的掩码卷积机制,我们后面再讲。

在这里插入图片描述

但是在生成图像(采样)时,还是要一个像素一个像素的生成(如下所示)

  • 在采样时,我们会先根据前i - 1个像素输出第i个像素的概率分布。
  • 然后,我们会从第i个像素的概率分布中进行采样(如下面代码所示)
# 假设颜色取值范围为[0, 7],下面为概率分布
prob_dist = torch.tensor([[0.1347, 0.1356, 0.1048, 0.1314, 0.1329, 0.1256, 0.1326, 0.1025]])# 我们并不是取概率最大的像素,而是从概率分布中采样(例如下面取像素值6)
# torch.multinomial会从input这个概率分布中,取num_samples个值
pixel = torch.multinomial(input=prob_dist, num_samples=1).float() # tensor([[6.]])

在这里插入图片描述

我们现在已经知道了训练及采样的大体过程。但是,我们现在还是有一个疑问,如何保证训练时候,每个像素都忽略后续像素的信息?

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。

  • 具体来说,PixelCNN使用了两类掩码卷积:
    • 我们把两类掩码卷积分别称为「A类」和「B类」。
    • 二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。
    • A类和B类的唯一区别在于:卷积核的中心像素是否产生贡献
    • CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积

在这里插入图片描述

我们来分析下这样设计的优点:

  • 对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)
    • 经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。
    • 再经过两次B类掩码卷积后,中心像素能够看到左上角大部分像素的信息(如下图所示,我们发现还是会看漏少部分的信息,后面的Gated PixelCNN对此进行了改进)。
    • 这满足PixelCNN的约束。

在这里插入图片描述

  • 如果一直使用A类掩码卷积,每次卷积后中心像素都会看漏一些信息,最终就会导致看漏很多信息

在这里插入图片描述

  • 如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。

总结如下:

  • 逐像素预测只依赖于前面的像素,因此在选择卷积核时要进行掩码操作避免看到未来的值,因此,在第一层预测时可采用掩码卷积A
  • 由于CNN的逐像素预测是多层卷积,所以当第一层结束后,图像缺失部分已经有了预测值,因此在进行下一次/层卷积操作时可以利用当前像素的预测值,因此采用下列掩码卷积B
  • 需要注意的是,这里只考虑了单通道,如果扩展到RGB三个通道时,该如何进行mask呢?

1.1.2 PixelCNN的网络架构

  • 利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。
  • 我们可以用任意一种CNN架构来实现PixelCNN。
  • 下图红色框所示部分是PixelCNN的网络结构,其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

在这里插入图片描述

1.1.3 PixelCNN在MNIST数据集上的应用

1.1.3.1 模型

实现PixelCNN,最重要的是实现掩码卷积。

  • 掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。
  • 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import osclass MaskConv2d(nn.Module):"""掩码卷积的实现思路:在卷积核组上设置一个mask,在前向传播的时候,先让卷积核组乘mask,再做普通的卷积"""def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]# 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2] = 1mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1# 为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作mask = mask.reshape((1, 1, H, W))# register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中# 每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。# 第三个参数表示被注册的变量是否要加入state_dict中以保存下来self.register_buffer(name='mask', tensor=mask, persistent=False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来

在这里插入图片描述

  • 我们先实现残差块上图右部分的ResidualBlock,这里添加归一化
class ResidualBlock(nn.Module):"""残差块ResidualBlock"""def __init__(self, h, bn=True):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(2 * h, h, 1)self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv3 = nn.Conv2d(h, 2 * h, 1)self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()def forward(self, x):# 1、ReLU + 1×1 Conv + bny = self.relu(x)y = self.conv1(y)y = self.bn1(y)# 2、ReLU + 3×3 Conv(mask B) + bny = self.relu(y)y = self.conv2(y)y = self.bn2(y)# 3、ReLU + 1×1 Conv + bny = self.relu(y)y = self.conv3(y)y = self.bn3(y)# 4、残差连接y = y + xreturn y
  • 有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。
  • 注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数color_level。
class PixelCNN(nn.Module):def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):super().__init__()self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()self.residual_blocks = nn.ModuleList()for _ in range(n_blocks):self.residual_blocks.append(ResidualBlock(h, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):# 1、7 × 7 conv(mask A)x = self.conv1(x)x = self.bn1(x)# 2、Multiple residual blocksfor block in self.residual_blocks:x = block(x)x = self.relu(x)# 3、1 × 1 convx = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x
1.1.3.2 数据集及训练

准备好了模型代码,我们可以编写训练脚本了:

  • PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32
def get_dataloader(batch_size: int):dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',train=True,transform=ToTensor())return DataLoader(dataset, batch_size=batch_size, shuffle=True)def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):"""训练过程"""dataloader = get_dataloader(batch_size)model = model.to(device)optimizer = torch.optim.Adam(model.parameters(), 1e-3)loss_fn = nn.CrossEntropyLoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x, _ in dataloader:current_batch_size = x.shape[0]x = x.to(device)# 把训练集的浮点颜色值转换成[0, color_level-1]之间的整型标签y = torch.ceil(x * (color_level - 1)).long()y = y.squeeze(1)predict_y = model(x)loss = loss_fn(predict_y, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), model_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')if __name__ == '__main__':os.makedirs('work_dirs', exist_ok=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'# 需要注意的是:MNIST数据集的大部分像素都是0和255color_level = 8  # or 256# 1、创建PixelCNN模型model = PixelCNN(n_blocks=15, h=128, linear_dim=32, bn=True, color_level=color_level)# 2、模型训练model_path = f'work_dirs/model_pixelcnn_{color_level}.pth'train(model, device, model_path)# 3、采样sample(model, device, model_path, f'work_dirs/pixelcnn_{color_level}.jpg')        
1.1.3.3 采样
  • 在采样时,我们把x初始化成一个0张量。
  • 之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.
def sample(model, device, model_path, output_path, n_sample=1):"""把x初始化成一个0张量。循环遍历每一个像素,输入x,把预测出的下一个像素填入x"""model.eval()model.load_state_dict(torch.load(model_path))model = model.to(device)C, H, W = get_img_shape()  # (1, 28, 28)x = torch.zeros((n_sample, C, H, W)).to(device)with torch.no_grad():for i in range(H):for j in range(W):# 我们先获取模型的输出,再用softmax转换成概率分布output = model(x)prob_dist = F.softmax(output[:, :, i, j], -1)# 再用torch.multinomial从概率分布里采样出【1】个[0, color_level-1]的离散颜色值# 再除以(color_level - 1)把离散颜色转换成浮点[0, 1]pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)# 最后把新像素填入到生成图像中x[:, :, i, j] = pixel# 乘255变成一个用8位字节表示的图像imgs = x * 255imgs = imgs.clamp(0, 255)imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))imgs = imgs.detach().cpu().numpy().astype(np.uint8)cv2.imwrite(output_path, imgs)

1.2 多通道PixelCNN

如下图所示,作者假设RGB三个通道之间存在相互影响

  • 其中红色预测不受蓝色和绿色通道的影响,只受上下文影响
  • 绿色红色通道和上下文影响,但不受蓝色通道影响;
  • 蓝色通道受上下文、红色通道、绿色通道影响

在这里插入图片描述

更具体地,我们规定一个子像素只由它之前的子像素决定,生成图像时,我们一个子像素一个子像素地生成

  • 如下图所示,对于RGB图像,R子像素由它之前所有像素决定
  • G子像素由它的R子像素和之前所有像素决定,
  • B子像素由它的R、G子像素和它之前所有像素决定。

在这里插入图片描述

如下图所示,由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量

  • 即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。
  • 同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

在这里插入图片描述

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  • 对于通道间的约束,我们要在o, i两个维度上设置掩码,如下图左边所示。
    • 设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3
      • o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:o
      • i1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i
      • 序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。
    • 我们对输入通道组和输出通道组之间进行约束。
    • 对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3
    • 对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3
  • 如下图右边所示,对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

在这里插入图片描述

  • 下面给出三维掩码示意图方便理解:

在这里插入图片描述

2 Gated PixelCNN

2.1 Gated PixelCNN简述

  • 可以参考大神讲解:Gated PixelCNN (sergeiturukin.com)

  • PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,中心像素看不到右上角三个本应该能看到的像素。

在这里插入图片描述

  • 为此,PixelCNN论文的作者又发表了Conditional Image Generation with PixelCNN Decoders(16.06)。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。

在这里插入图片描述

  • 如下图所示,Gated PixelCNN使用了两种卷积,即垂直卷积和水平卷积,来分别维护一个像素上侧的信息和左侧的信息
    • 垂直卷积的结果只是一些临时量
    • 而水平卷积的结果最终会被网络输出
    • 使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

在这里插入图片描述

  • Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
  • 模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。
  • 垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。
  • 两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。
  • 此外,模块右侧还有一个残差连接。

在这里插入图片描述

2.2 Gated PixelCNN在MNIST数据集上的应用

2.2.1 创建模型

  • 首先,实现垂直卷积和水平卷积
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import osclass VerticalMaskConv2d(nn.Module):"""垂直卷积"""def __init__(self, *args, **kwags):super().__init__()self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2 + 1] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass HorizontalMaskConv2d(nn.Module):"""水平卷积"""def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_res
# 垂直卷积
tensor([[[[1., 1., 1.],[1., 1., 1.],[0., 0., 0.]]]])
# A类水平卷积
tensor([[[[0., 0., 0.],[1., 0., 0.],[0., 0., 0.]]]])
# B类水平卷积
tensor([[[[0., 0., 0.],[1., 1., 0.],[0., 0., 0.]]]])
  • 我们现在搭建Gated Block模块,这也是最难理解的一部分。
  • 可以参考的解释:https://segmentfault.com/a/1190000041189859?utm_source=sf-similar-article

在这里插入图片描述

  • # 这里比较难理解,通过对图像进行零填充并裁剪图像底部,可以确保垂直和水平堆栈之间的因果关系
    v_to_h = v[:, :, 0:-1]
    v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
    # 注意到,v和i相加的位置只差了一个单位。
    # 为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。
    # 这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。
    

在这里插入图片描述

  • 维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。
    • v会经过一个垂直掩码卷积和一个门激活函数。
    • h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。
class GatedBlock(nn.Module):def __init__(self, conv_type, in_channels, p, bn=True):super().__init__()self.conv_type = conv_typeself.p = pself.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, kernel_size=1)self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,1)self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_output_conv = nn.Conv2d(p, p, 1)self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()def forward(self, v_input, h_input):# v代表垂直卷积部分的结果v = self.v_conv(v_input)v = self.bn1(v)# Note: 重点代码# 为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位# 而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)v_to_h = v[:, :, 0:-1]v_to_h = F.pad(v_to_h, (0, 0, 1, 0))# 和h相加前,先经过 1×1 convv_to_h = self.v_to_h_conv(v_to_h)v_to_h = self.bn2(v_to_h)# 分为两份,经过tanh 和 sigmoidv1, v2 = v[:, :self.p], v[:, self.p:]v1 = torch.tanh(v1)v2 = torch.sigmoid(v2)v = v1 * v2# h代表水平卷积部分的结果h = self.h_conv(h_input)h = self.bn3(h)h = h + v_to_h# 分为两份,经过tanh 和 sigmoidh1, h2 = h[:, :self.p], h[:, self.p:]h1 = torch.tanh(h1)h2 = torch.sigmoid(h2)h = h1 * h2h = self.h_output_conv(h)h = self.bn4(h)# 在网络的第一层,每个数据是不能看到自己的。# 所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。if self.conv_type == 'B':h = h + h_inputreturn v, h
  • 最后,我们来用GatedBlock搭出Gated PixelCNN
  • Gated PixelCNN和PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。
class GatedPixelCNN(nn.Module):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__()self.block1 = GatedBlock('A', 1, p, bn)self.blocks = nn.ModuleList()for _ in range(n_blocks):self.blocks.append(GatedBlock('B', p, p, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(p, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):v, h = self.block1(x, x)for block in self.blocks:v, h = block(v, h)x = self.relu(h)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x

2.2.2 数据集、训练及采样

  • 数据集、训练及采样和PixelCNN一模一样,不再赘述。
def get_dataloader(batch_size: int):dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',train=True,transform=ToTensor())return DataLoader(dataset, batch_size=batch_size, shuffle=True)def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):"""训练过程"""dataloader = get_dataloader(batch_size)model = model.to(device)optimizer = torch.optim.Adam(model.parameters(), 1e-3)loss_fn = nn.CrossEntropyLoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x, _ in dataloader:current_batch_size = x.shape[0]x = x.to(device)# 把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的y = torch.ceil(x * (color_level - 1)).long()y = y.squeeze(1)predict_y = model(x)loss = loss_fn(predict_y, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), model_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')def get_img_shape():return (1, 28, 28)def sample(model, device, model_path, output_path, n_sample=1):"""把x初始化成一个0张量。循环遍历每一个像素,输入x,把预测出的下一个像素填入x"""model.eval()model.load_state_dict(torch.load(model_path))model = model.to(device)C, H, W = get_img_shape()  # (1, 28, 28)x = torch.zeros((n_sample, C, H, W)).to(device)with torch.no_grad():for i in range(H):for j in range(W):# 我们先获取模型的输出,再用softmax转换成概率分布output = model(x)prob_dist = F.softmax(output[:, :, i, j], -1)# 再用torch.multinomial从概率分布里采样出【1个】0~(color_level-1)的离散颜色值# 再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色)pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)# 最后把新像素填入生成图像x[:, :, i, j] = pixelimgs = x * 255imgs = imgs.clamp(0, 255)imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))imgs = imgs.detach().cpu().numpy().astype(np.uint8)cv2.imwrite(output_path, imgs)if __name__ == '__main__':os.makedirs('work_dirs', exist_ok=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'color_level = 8  # or 256# 1、创建GatedPixelCNN模型model = GatedPixelCNN(n_blocks=15, p=128, linear_dim=32, bn=True, color_level=color_level)# 2、模型训练model_path = f'work_dirs/model_gatedpixelcnn_{color_level}.pth'train(model, device, model_path, batch_size=1)# 3、采样sample(model, device, model_path, f'work_dirs/gatedpixelcnn_{color_level}.jpg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/24139.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式在芯片验证中的应用——迭代器

一、迭代器设计模式 迭代器设计模式(iterator)是一种行为设计模式, 让你能在不暴露集合底层表现形式 (列表、 栈和树等数据结构) 的情况下遍历集合中所有的元素。 在验证环境中的checker会收集各个monitor上送过来的transactions&#xff0…

设计软件有哪些?效果工具篇(2),渲染100邀请码1a12

这次我们继续介绍一些渲染效果和后期处理的工具。 1、Krakatoa Krakatoa是由Thinkbox Software开发的强大的粒子渲染器,可用于Autodesk 3ds Max等软件。它专注于处理大规模粒子数据,提供了高效的渲染解决方案,适用于各种特效、粒子系统和模…

TESSENT2024.1安装

一、安装过程参考Calibre安装过程(此处省略,不再赘述) 二、安装license管理器: SiemensLicenseServer_v2.2.1.0_Lnx64_x86-64.bin 三、Patch补丁: tessent安装目录和license管理安装目录,执行FlexNetLic…

企业必备技能导航栏的写法

创建一个导航栏是网页设计中的一个重要环节,它不仅有助于用户快速找到他们需要的信息,还能提升整个网站的用户体验。以下是一些基本步骤和技巧,可以帮助你快速制作一个高效且美观的导航栏: 确定导航栏位置:导航栏通常位…

C++:Traits编程技法在STL迭代器中的应用

文章目录 迭代器相应型别Traits(特性)编程技法——STL源代码门钥迭代器相应型别一:value_type迭代器相应型别二:difference_type迭代器相应型别三:reference_type迭代器相应型别四:pointer_type迭代器相应型别五:itera…

2 - 寻找用户推荐人(高频 SQL 50 题基础版)

2.寻找用户推荐人 考点: sql里面的不等于,不包含null -- null 用数字判断筛选不出来 select name from Customer where referee_id !2 OR referee_id IS NULL;

设置密码重要性!美国一配件制造商因忘设密码影响50 多万客户

1、Cox Biz 身份验证绕过漏洞使数百万台设备暴露于接管 美国一家领先宽带提供商cox的基础架构中存在 API 授权绕过漏洞,如果被利用攻击者不仅可以访问企业客户的个人身份信息 (PII),还可以访问 Wi-Fi 密码和连接设备上的信息&…

1501 - JUC高并发

须知少许凌云志,曾许人间第一流 看的是尚硅谷的视频做的学习总结,感恩老师,下面是视频的地址 传送门https://www.bilibili.com/video/BV1Kw411Z7dF 0.思维导图 1.JUC简介 1.1 什么是JUC JUC, java.util.concurrent工具包的简称…

STM32-呼吸灯仿真

目录 前言: 一.呼吸灯 二.跑马灯 三. 总结 前言: 本篇的主要内容是关于STM32-呼吸灯的仿真,包括呼吸灯,跑马灯的实现与完整代码,欢迎大家的点赞,评论和关注. 接上http://t.csdnimg.cn/mvWR4 既然已经点亮了一盏灯,接下来就可以做更多实验了, 一.呼吸灯 在上一个的基础上…

【一】apollo 环境配置

域控制器配置 google输入法安装 安装输入google pinyin法 sudo apt install fcitx-bin sudo apt install fcitx-table sudo apt-get install fcitx fcitx-googlepinyin -y 最后需要reboot 系统环境 修改文件夹名称为英文 export LANGen_US xdg-user-dirs-gtk-update 挂载硬…

2559. 统计范围内的元音字符串数(前缀和) o(n)时间复杂度

给你一个下标从 0 开始的字符串数组 words 以及一个二维整数数组 queries 。 每个查询 queries[i] [li, ri] 会要求我们统计在 words 中下标在 li 到 ri 范围内(包含 这两个值)并且以元音开头和结尾的字符串的数目。 返回一个整数数组,其中…

微前端之旅:探索Qiankun的实践经验

theme: devui-blue 什么是微前端? 微前端是一种前端架构方法,它借鉴了微服务的架构理念,将一个庞大的前端应用拆分为多个独立灵活的小型应用,每个应用都可以独立开发、独立运行、独立部署,再将这些小型应用联合为一个完…

淘宝天猫商品详情API接口详解

一、淘宝天猫商品详情API接口概述 淘宝天猫商品详情API接口是淘宝天猫开放平台提供的一项重要服务,它允许开发者通过API接口获取淘宝天猫商品的详细信息。这些信息包括但不限于商品标题、价格、描述、图片、销量、评价等。通过使用淘宝天猫商品详情API接口&#xf…

国密算法SM2的优势、原理和应用场景

随着信息化时代的到来,数据安全和网络空间的安全成为了国家安全的重要组成部分。密码学作为保障信息安全的关键技术,其重要性日益凸显。在这样的背景下,中国国家密码管理局推出了一系列自主的密码学算法,即国密算法,其…

12.【Orangepi Zero2】基于orangepi_Zero_2 Linux的智能家居项目

基于orangPi Zero 2的智能家居项目 需求及项目准备 语音接入控制各类家电,如客厅灯、卧室灯、风扇回顾二阶段的Socket编程,实现Sockect发送指令远程控制各类家电烟雾警报监测, 实时检查是否存在煤气泄漏或者火灾警情,当存在警情时…

SkyWalking之P0业务场景输出调用链路应用

延伸扩展:XX业务场景 路由标签打标、传播、检索 链路标签染色与传播 SW: SkyWalking的简写 用户请求携带HTTP头信息X-sw8-correlation “X-sw8-correlation: key1value1,key2value2,key3value3” 网关侧读取解析HTTP头信息X-sw8-correlation,然后通过SW…

探索未来制造,BFT Robotics引领潮流

“买机器人,上BFT” 在这个快速变化的时代,创新和效率是企业发展的关键。BFT Robotics,作为您值得信赖的合作伙伴,专注于为您提供一站式的机器人采购和自动化解决方案。 产品系列: 协作机器人:安全、灵活、…

Linux C语言:指针和指针变量

一、指针的作用 使程序简洁、紧凑、高效有效地表示复杂的数据结构动态分配内存能直接访问硬件能够方便的处理字符串得到多于一个的函数返回值 二、内存、地址和变量 1、内存地址 2、变量和地址 1)变量用来在程序中保存数据 比如: int k 58; //声明一个int变…

基于JSP技术的社区疫情防控管理信息系统

你好呀,我是计算机学长猫哥!如果有相关需求,文末可以找到我的联系方式。 开发语言:JSP 数据库:MySQL 技术:JSPJavaBeans 工具:MyEclipse、Tomcat、Navicat 系统展示 首页 用户注册与登录界…

2024-5-7 石群电路-26

2024-6-7,星期五,15:00,天气:阴转小雨,心情:晴。今天虽然是阴雨天,但是心情不能差哦,离答辩越来越近了,今天学完习好好准备准备ppt,加油学习喽~ 今日观看了石…