使用Pytorch从零开始构建Conditional PixelCNN

条件 PixelCNN

PixelCNN 是 PixelRNN 的卷积版本,它将图像中的像素视为一个序列,并在看到前面的像素后预测每个像素(定义如上和左,尽管这是任意的)。PixelRNN 是图像联合先验分布的自回归模型:
p ( x ) = p ( x 0 ) ∏ p ( x i ∣ x 0 , ⋯ , x i − 1 ) p(x) = p(x_0 ) ∏ p(x_i | x_0, \cdots,x_{i-1} ) p(x)=p(x0)p(xix0,,xi1)
PixelRNN 的训练速度很慢,因为循环无法并行化——即使是小图像也有数百或数千个像素,这对于 RNN 来说是一个相对较长的序列。用掩码卷积替换循环,使卷积滤波器仅看到上方和左侧的像素,从而实现更快的训练(图来自条件 PixelCNN 论文)。
在这里插入图片描述

然而,值得注意的是,最初的 PixelCNN 实现产生的结果比 PixelRNN 更差。在后续论文(使用 PixelCNN 解码器生成条件图像)中推测,结果降级的一个可能原因是 PixelCNN 中的 ReLU 激活与 LSTM 中的门控连接相比相对简单。Conditional PixelCNN 论文随后用门控激活取代了 ReLU:
y = t a n h ( W f ∗ x ) • σ ( W g ∗ x ) y = tanh (W f * x) • σ(W g * x) y=tanh(Wfx)σ(Wgx)
后续论文中提供的另一个可能的原因是,堆叠掩模卷积滤波器会导致盲点,无法捕获预测像素之上的所有像素(论文中的图):
在这里插入图片描述

PixelCNN 与 GAN

PixelCNN 和 GAN 是目前用于生成图像的两种深度学习模型。GAN 最近受到了很多关注,但在很多方面我发现它们的流行是没有根据的。

目前尚不清楚 GAN 实际上试图优化什么目标,因为训练目标的最小值(即愚弄鉴别器)将导致生成器重新创建所有训练图像和/或生成不一定类似于自然图像的对抗性示例。这反映在训练 GAN 的众所周知的困难以及无数的对其进行正则化的技巧上。让两个网络相互对抗以产生训练信号的想法很有趣,并且已经产生了许多好的论文(尤其是 CycleGAN),但我仍然不相信它们除了在社交媒体上发布华丽的帖子之外还有其他用途。

另一方面,PixelCNN 有很好的概率基础。这使得它们不仅可以通过对分布进行采样(从左到右,从上到下,遵循自回归定义)来生成图像,而且还意味着它们可以用于其他任务。例如:作为预筛选网络来检测域外或对抗性示例;用于检测训练集中的异常值;或估计测试中的不确定性。我将在下一篇文章中详细介绍其中一些扩展。

我很想知道是否有人尝试过将 PixelCNN 和 GAN 结合起来。也许 PixelCNN 可以用作解码器的前级或最后阶段(以一些更高级别的学习表示为条件),以避免 GAN 的一些训练困难。

实现

我的实现使用门控块,但为了快速实现,我决定放弃针对盲点问题的双流解决方案(将滤波器分为水平和垂直组件)。有代码可用于解决 Tensorflow 中的盲点问题,并且在 PyTorch 中重写它相当简单。这样,掩蔽就很简单:当前像素下方和右侧的所有内容在滤波器中都被清零,并且在第一层中,当前像素也在滤波器中设置为零。

class MaskedConv(nn.Conv2d):def __init__(self,mask_type,in_channels,out_channels,kernel_size,stride=1):"""mask_type: 'A' for first layer of network, 'B' for all others"""super(MaskedConv,self).__init__(in_channels,out_channels,kernel_size,stride,padding=kernel_size//2)assert mask_type in ('A','B')mask = torch.ones(1,1,kernel_size,kernel_size)mask[:,:,kernel_size//2,kernel_size//2+(mask_type=='B'):] = 0mask[:,:,kernel_size//2+1:] = 0self.register_buffer('mask',mask)def forward(self,x):self.weight.data *= self.maskreturn super(MaskedConv,self).forward(x)

门控 ResNet 块的实现稍微复杂一些:PixelCNN 在网络的两半之间有快捷连接,就像 U-Net 一样;PyTorch 允许模块的前向方法仅在输入是变量时才接受多个输入;由于网络前半部分的特征图不是变量,因此它们必须与其他输入(前一层的特征)连接起来。使用条件向量可以避免这种情况,因为它是一个变量(在本例中为类标签)。

class GatedRes(nn.Module):def __init__(self,in_channels,out_channels,n_classes,kernel_size=3,stride=1,aux_channels=0):super(GatedRes,self).__init__()self.conv = MaskedConv('B',in_channels,2*out_channels,kernel_size,stride)self.y_embed = nn.Linear(n_classes,2*out_channels)self.out_channels = out_channelsif aux_channels!=2*out_channels and aux_channels!=0:self.aux_shortcut = nn.Sequential(nn.Conv2d(aux_channels,2*out_channels,1),nn.BatchNorm2d(2*out_channels,momentum=0.1))if in_channels!=out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels,out_channels,1),nn.BatchNorm2d(out_channels,momentum=0.1))self.batchnorm = nn.BatchNorm2d(out_channels,momentum=0.1)def forward(self,x,y):# check for aux input from first half of net stacked into xif x.dim()==5:x,aux = torch.split(x,1,dim=0)x = torch.squeeze(x,0)aux = torch.squeeze(x,0)else:aux = Nonex1 = self.conv(x)y = torch.unsqueeze(torch.unsqueeze(self.y_embed(y),-1),-1)if aux is not None:if hasattr(self,'aux_shortcut'):aux = self.aux_shortcut(aux)x1 = (x1+aux)/2# split for gate (note: pytorch dims are [n,c,h,w])xf,xg = torch.split(x1,self.out_channels,dim=1)yf,yg = torch.split(y,self.out_channels,dim=1)f = torch.tanh(xf+yf)g = torch.sigmoid(xg+yg)if hasattr(self,'shortcut'):x = self.shortcut(x)return x+self.batchnorm(g*f)

我不确定在阅读原始论文时将批量归一化放在哪里,所以我将它放在我认为有意义的地方:在添加剩余连接之前。

实现这两个类后,整个网络就相对容易了。PyTorch 方案将所有内容定义为 的子类nn.Module,初始化所有层/操作/等。在构造函数中,然后在forward方法中将它们连接在一起可能会很混乱。如果您有大量快捷连接并且想要使用任意深度的循环对模型进行编码,则尤其如此。

注意:为了能够保存/恢复模型,您必须将图层存储在一个ModuleList而不是常规列表中。不过,附加和索引此列表在其他方面是相同的。

class PixelCNN(nn.Module):def __init__(self,in_channels,n_classes,n_features,n_layers,n_bins,dropout=0.5):super(PixelCNN,self).__init__()self.layers = nn.ModuleList()self.n_layers = n_layers# Up passself.input_batchnorm = nn.BatchNorm2d(in_channels,momentum=0.1)for l in range(n_layers):if l==0:  # start with normal convblock = nn.Sequential(MaskedConv('A',in_channels+1,n_features,kernel_size=7),nn.BatchNorm2d(n_features,momentum=0.1),nn.ReLU())else:block = GatedRes(n_features, n_features, n_classes)self.layers.append(block)# Down passfor _ in range(n_layers):block = GatedRes(n_features, n_features,n_classes,aux_channels=n_features)self.layers.append(block)# Last layer: project to n_bins (output is [-1, n_bins, h, w])self.layers.append(nn.Sequential(nn.Dropout2d(dropout),nn.Conv2d(n_features,n_bins,1),nn.LogSoftmax(dim=1)))def forward(self,x,y):# Add channel of ones so network can tell where padding isx = nn.functional.pad(x,(0,0,0,0,0,1,0,0),mode='constant',value=1)# Up passfeatures = []i = -1for _ in range(self.n_layers):i += 1if i>0:x = self.layers[i](x,y)else:x = self.layers[i](x)features.append(x)# Down passfor _ in range(self.n_layers):i += 1x = self.layers[i](torch.stack((x,features.pop())),y)# Last layeri += 1x = self.layers[i](x)assert i==len(self.layers)-1assert len(features)==0return x

MNIST 实际上是黑白的,因此我将标签离散为仅 4 个灰度级,以便计算交叉熵损失。在自然图像上,输出级别的数量显然需要更高。网络中的所有层都有 200 个特征。对于数据增强,我使用了 +/-5 度的随机旋转和最近邻采样。对于训练,我使用 Adam,学习率为 10 -4,dropout 率为 0.9。

更高的特征数量(比 MNIST 所需的特征数量更多)和更高的 dropout 是训练时间与正则化之间的权衡。这是一个在论文中很少提及的技巧,但有助于避免过度拟合——我只在一篇关于视频中动作识别训练的论文中看到过它,其中由于高维度与当前数据集大小,过度拟合是一个问题可用的。

我有一个 GTX1070 GPU,所以我没有运行任何类型的超参数优化:猜测合理的超参数并使模型工作的能力很大程度上说明了 Adam + 批量归一化 + dropout 的稳健性。学习率肯定可以更高,但这会产生更有趣的 GIF。

结果

在这里插入图片描述

上面的 gif 显示了整个训练过程中每个epochs后生成的一批 50 张图像(每类 5 个示例),从看似随机的涂鸦到类似于实际数字的东西。这是最佳epochs的结果:
在这里插入图片描述
这项工作的动机是看看条件 PixelCNN 是否也可以在类之间生成合理的示例。这是通过调节软标签而不是单热编码标签来完成的。

让我们尝试一下我所期望的容易混淆的数字对:(1,7), (3,8), (4,9), (5,6)
在这里插入图片描述
生成的类间示例并不像正常示例那样真实。模型可能需要一些额外的训练信号(例如来自分类器网络的教师强制)才能沿着图像流形进行插值。这有点令人失望,因为我曾希望生成类间示例可能允许使用学习的混合形式(而不是平均图像)。显然,进一步测试这个想法将需要更多的 GPU 来生成批量输入,所以无论如何,它目前超出了我的范围。

本文的完整代码可在Github代码库中查看。

本博文译自 jrbtaylor 的博客。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/168485.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于Vue和SpringBoot的食品生产管理系统

项目编号: S 044 ,文末获取源码。 \color{red}{项目编号:S044,文末获取源码。} 项目编号:S044,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 加工厂管理模块2.2 客户管理模块2.3…

【Proteus仿真】【STM32单片机】智能垃圾桶设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真STM32单片机控制器,使用报警模块、LCD1602液晶模块、按键模块、人体红外传感器、HCSR04超声波、有害气体传感器、SG90舵机等。 主要功能: 系统运行后&…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《基于Fisher时段划分的配电网源网荷储多时间尺度协调优化调控策略》

这个标题涉及到电力系统领域的一些关键概念和方法。让我们逐步解读: 基于Fisher时段划分: "基于"表示这个策略或方法的核心基础是某个特定的理论或技术。"Fisher时段划分"可能指的是使用Fisher信息矩阵进行时间划分。Fisher信息矩阵…

居家适老化设计第三十条---卫生间之坐便

以上产品图片均来源于淘宝 侵权联系删除 在居家适老化中,马桶是非常重要的设施之一,它能够提供方便、安全、舒适的上厕所体验。以下是一些居家适老化中常见的马桶设计和功能:1. 高度合适:为了方便老年人坐起和站起,马…

Da-transunet:将空间和通道双重关注与Transformer u-net相结合用于医学图像分割

DA-TRANSUNET: INTEGRATING SPATIAL AND CHANNEL DUAL ATTENTION WITH TRANSFORMER U-NET FOR MEDICAL IMAGE SEGMENTATION 1、方法1.1 模型1.2 双注意力模块(DA-Block)1.2.1 PAM( 位置注意力模块)1.2.2 CAM(通道注意力…

NX二次开发UF_CURVE_ask_int_parms_sc 函数介绍

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_ask_int_parms_sc Defined in: uf_curve.h int UF_CURVE_ask_int_parms_sc(tag_t int_curve_object, int * num_objects_set_1, tag_t * * object_set_1, int * num_object…

Swing程序设计(6)边界布局,网格布局

文章目录 前言一、布局介绍 1.边界布局2.网格布局3.网格组布局.总结 前言 Swing程序中还有两种方式边界布局,网格布局供程序员使用。这两种布局方式更能体现出软件日常制作的排列布局格式。 一、布局介绍 1.BorderLayout边界布局 语法:new BorderLayout …

laravel8安装多应用多模块(笔记三)

先安装laravel8 Laravel 安装(笔记一)-CSDN博客 一、进入项目根目录安装 laravel-modules composer require nwidart/laravel-modules 二、 大于laravel5需配置provider,自动生成配置文件 php artisan vendor:publish --provider"Nwid…

windows cmd执行远程长脚本

背景 有时候我们想在未进行一些环境设置,或者工具使用者电脑中执行一段初始化脚本,为了简化使用者的理解成本,通常给使用者一段代码执行初始化电脑中的设置,尤其是这段初始化脚本比较长的时候。 脚本制作者 比如将需要执行的命…

H5ke12--2--学生选课表格的编辑

方法1不可以修改的用label,如何按了哪一行 就会在下面有个文本显示可编辑的一行 方法2每一行后面都有一个编辑, 3对每一个修改,每一个td失去焦点都会有,直接到达我们服务器 注意 如果用span的每一个html元素都可以自己定义属性 Data-属性名,data-Address links也要给为span 1…

递归算法学习——二叉树的伪回文路径

1,题目 给你一棵二叉树,每个节点的值为 1 到 9 。我们称二叉树中的一条路径是 「伪回文」的,当它满足:路径经过的所有节点值的排列中,存在一个回文序列。 请你返回从根到叶子节点的所有路径中 伪回文 路径的数目。 示例…

软件设计中如何画各类图之二深入解析数据流图(DFD):系统设计与分析的关键视觉工具

目录 1 前言2 数据流图(DFD)的重要性3 数据流图的符号说明4 清晰的数据流图步骤4.1 确定系统边界4.2 识别数据流4.3 定义处理过程4.4 确认数据存储4.5 建立数据流动的连线4.6 细化和优化 5 数据流图的用途6 使用场景7 实际应用场景举例8 结语 1 前言 当…

使用 Python 和 NLTK 进行文本摘要

一、说明 文本摘要是一种自然语言处理技术,允许用户将大量文本总结为小块,而不会丢失任何重要信息。本文介绍NLP中使用Gensim和Sumy实现文本摘要的步骤。 二、为什么要总结文本? 互联网包含大量信息,而且每秒都在增加。文本摘要可…

鼠标点击位置获取几何体对象_vtkAreaPicker_vtkInteractorStyleRubberBandPick

开发环境: Windows 11 家庭中文版Microsoft Visual Studio Community 2019VTK-9.3.0.rc0vtk-example参考代码 demo解决问题:框选或者点选某一区域,并获取区域prop3D对象(红线内为有效区域,polydata组成的3d几何对象&a…

力扣刷题篇之排序算法

系列文章目录 前言 本系列是个人力扣刷题汇总,本文是排序算法。刷题顺序按照[力扣刷题攻略] Re:从零开始的力扣刷题生活 - 力扣(LeetCode) 这个之前写的左神的课程笔记里也有: 左程云算法与数据结构代码汇总之排序&am…

【前端】数据行点击选择

前言 【前篇文章】说了,我们公司的核心价值就是让人越来越懒,能怎么便捷就怎么便捷,主打一个简单实用又快捷,为了实现这个目标,我看成这个列表陷入了深思在想,要不要子表的数据加载在点击这个行时,就可以展示数据,这样就不用每次都要点那个小圆圈啦。 查资料 这显然…

2023.11.25-istio安全

目录 文章目录 目录本节实战1、安全概述2、证书签发流程1.签发证书2.身份认证 3、认证1.对等认证a.默认的宽容模式b.全局严格 mTLS 模式c.命名空间级别策略d.为每个工作负载启用双向 TLS 2.请求认证a.JWK 与 JWKS 概述b.配置 JWT 终端用户认证c.设置强制认证规则 关于我最后 本…

RevCol实战:使用RevCol实现图像分类任务(二)

文章目录 训练部分导入项目使用的库设置随机因子设置全局参数图像预处理与增强读取数据设置Loss设置模型设置优化器和学习率调整策略设置混合精度,DP多卡,EMA定义训练和验证函数训练函数验证函数调用训练和验证方法 运行以及结果查看测试完整的代码 在上…

「Verilog学习笔记」数据串转并电路

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 关于什么是Valid/Ready握手机制: 深入 AXI4 总线(一)握手机制 - 知乎 时序图含有的信息较多,观察时序图需要注意&#xff1a…

Redis常用操作及应用(一)

一、五种数据结构 二、String结构 1、字符串常用操作 SET key value //存入字符串键值对 MSET key value [key value ...] //批量存储字符串键值对 SETNX key value //存入一个不存在的字符串键值对 GET key //获取一个字符串键值 MGET key [ke…