torch实现Gated PixelCNN

文章目录

  • PixelCNN
  • Gated PixelCNN

PixelCNN

import torch
import torch.nn as nn
import torch.nn.functional as F# Pixel CNNclass MaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2, :] = 1mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass ResidualBlock(nn.Module):def __init__(self, h, bn=True):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(2 * h, h, 1)self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv3 = nn.Conv2d(h, 2 * h, 1)self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()def forward(self, x):y = self.relu(x)y = self.conv1(y)y = self.bn1(y)y = self.relu(y)y = self.conv2(y)y = self.bn2(y)y = self.relu(y)y = self.conv3(y)y = self.bn3(y)y = y + xreturn yclass PixelCNN(nn.Module):def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):super().__init__()self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()self.residual_blocks = nn.ModuleList()for _ in range(n_blocks):self.residual_blocks.append(ResidualBlock(h, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):x = self.conv1(x)x = self.bn1(x)for block in self.residual_blocks:x = block(x)x = self.relu(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x

Gated PixelCNN

class VerticalMaskConv2d(nn.Module):def __init__(self, *args, **kwags):super().__init__()self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2 + 1] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass HorizontalMaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass GatedBlock(nn.Module):def __init__(self, conv_type, in_channels, p, bn=True):super().__init__()self.conv_type = conv_typeself.p = pself.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,1)self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_output_conv = nn.Conv2d(p, p, 1)self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()def forward(self, v_input, h_input):v = self.v_conv(v_input)v = self.bn1(v)v_to_h = v[:, :, 0:-1]v_to_h = F.pad(v_to_h, (0, 0, 1, 0))v_to_h = self.v_to_h_conv(v_to_h)v_to_h = self.bn2(v_to_h)v1, v2 = v[:, :self.p], v[:, self.p:]v1 = torch.tanh(v1)v2 = torch.sigmoid(v2)v = v1 * v2h = self.h_conv(h_input)h = self.bn3(h)h = h + v_to_hh1, h2 = h[:, :self.p], h[:, self.p:]h1 = torch.tanh(h1)h2 = torch.sigmoid(h2)h = h1 * h2h = self.h_output_conv(h)h = self.bn4(h)if self.conv_type == 'B':h = h + h_inputreturn v, hclass GatedPixelCNN(nn.Module):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__()self.block1 = GatedBlock('A', 1, p, bn)self.blocks = nn.ModuleList()for _ in range(n_blocks):self.blocks.append(GatedBlock('B', p, p, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(p, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):v, h = self.block1(x, x)for block in self.blocks:v, h = block(v, h)x = self.relu(h)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/102960.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA NIO深入剖析

4.1 Java NIO 基本介绍 Java NIO(New IO)也有人称之为 java non-blocking IO是从Java 1.4版本开始引入的一个新的IO API,可以替代标准的Java IO API。NIO与原来的IO有同样的作用和目的,但是使用的方式完全不同,NIO支持面向缓冲区的、基于通道的IO操作。NIO将以更加高效的方…

MySQL数据生成工具mysql_random_data_load

在看MySQL文章的时候偶然发现生成数据的工具,此处直接将软件作者的文档贴了过来,说明了使用方式及下载地址 Random data generator for MySQL Many times in my job I need to generate random data for a specific table in order to reproduce an is…

深眸科技自研AI视觉分拣系统,实现物流行业无序分拣场景智慧应用

在机器视觉应用环节中,物体分拣是建立在识别、检测之后的一个环节,通过机器视觉系统对图像进行处理,并结合机械臂的使用实现产品分类。 通过引入视觉分拣技术,不仅可以实现自动化作业,还能提高生产线的生产效率和准确…

Paddle安装

Paddle安装参考 docs/tutorials/INSTALL_cn.md PaddlePaddle/PaddleDetection - Gitee.comhttps://gitee.com/paddlepaddle/PaddleDetection/blob/release/2.6/docs/tutorials/INSTALL_cn.md # 不指定版本安装paddle-gpu python -m pip install paddlepaddle-gpu# 测试安装 …

使用 Eziriz .NET Reactor 对c#程序加密

我目前测试过好几个c#加密软件。效果很多时候是加密后程序执行错误,或者字段找不到的现象 遇到这个加密软件用了一段时间都很正常,分享一下使用流程 破解版本自行百度。有钱的支持正版,我用的是 Eziriz .NET Reactor 6.8.0 第一步 安装 Ezi…

【JVM--StringTable字符串常量池】

文章目录 1. String 的基本特性2. 字符串拼接操作3. intern()的使用4. StringTable 的垃圾回收 1. String 的基本特性 String 声明为 final 的,不可被继承String 实现了 Serializable 接口:表示字符串是支持序列化的。String 实现了 Comparable 接口&am…

发行版兴趣小组季度动态:Anolis OS 支持大热 AI 软件栈,引入社区合作安全修复流程

发行版兴趣小组(Special Interest Group) :旨在为龙蜥社区构建、发布和维护一个稳定的操作系统发行版。 秋天的季节,发行版兴趣小组在 AI、安全、国产 OS 领域同样也是硕果累累。一起来看一下第三季度发行版兴趣小组的成果总结有…

【ppt技巧】ppt里的图片如何提取出来?

之前分享过如何将PPT文件导出成图片,今天继续分享PPT技巧,如何提取出PPT文件里面的图片。 首先,我们将PPT文件的后缀名,修改为rar,将文件改为压缩包文件 然后我们将压缩包文件进行解压 最好是以文件夹的形式解压出来…

【FreeRTOS】【STM32】02 FreeRTOS 移植

基于 [野火]《FreeRTOS%20内核实现与应用开发实战—基于STM32》 正点原子《STM32F429FreeRTOS开发手册_V1.2》 准备 基础工程,例如点灯 FreeRTOS 系统源码 FreeRTOS 移植 上一章节已经说明了Free RTOS的源码文件在移植时所需要的,FreeRTOS 为我们提供…

物联网AI MicroPython传感器学习 之 CCS811空气质量检测传感器

学物联网,来万物简单IoT物联网!! 一、产品简介 通过CCS811传感器模块可以测量环境中TVOC(总挥发性有机物质)浓度和eCO2(二氧化碳)浓度,作为衡量空气质量(IAQ)的指标。 引脚定义 VCC:3.3VGND&…

亚马逊,速卖通,敦煌产品测评补单攻略:低成本、高安全实操指南

随着电商平台的发展和消费者对产品质量的要求提升,测评补单成为了商家们提升销售和用户口碑的关键环节。然而,如何在保持成本低廉的同时确保操作安全,一直是卖家们面临的挑战。今天林哥分享一些实用的技巧和策略,帮助卖家们产品的…

精品Python的美食推荐系统厨房点餐订餐

《[含文档PPT源码等]精品Python的美食推荐系统》该项目含有源码、文档、PPT、配套开发软件、软件安装教程、项目发布教程等! 软件开发环境及开发工具: 开发语言:python 使用框架:Django 前端技术:JavaScript、VUE.…

Linux系统卡顿处理记录(Debian)

问题现象描述 现象linux操作系统卡顿(就是很慢),但是系统任然能够使用。 文章一步步的排查并且定位问题。 排查步骤 1. 使用top命令查看CPU是否占用过高。(未发现)排除问题 2. 使用df -h查看硬盘是否被占满。&#…

竞赛 深度学习 opencv python 实现中国交通标志识别

文章目录 0 前言1 yolov5实现中国交通标志检测2.算法原理2.1 算法简介2.2网络架构2.3 关键代码 3 数据集处理3.1 VOC格式介绍3.2 将中国交通标志检测数据集CCTSDB数据转换成VOC数据格式3.3 手动标注数据集 4 模型训练5 实现效果5.1 视频效果 6 最后 0 前言 🔥 优质…

Kubernetes使用OkHttp客户端进行网络负载均衡

在一次内部Java服务审计中,我们发现一些请求没有在Kubernetes(K8s)网络上正确地实现负载均衡。导致我们深入研究的问题是HTTP 5xx错误率的急剧上升,由于CPU使用率非常高,垃圾收集事件的数量很多以及超时,但…

ctfshow-ssti

web361 名字就是考点,所以注入点就是name 先测试一下存不存在ssti漏洞 利用os模块,脚本 查看一下子类的集合 ?name{{.__class__.__base__.__subclasses__()}} 看看有没有os模块,查找os 利用这个类,用脚本跑他的位置 import …

LeetCode(力扣)416. 分割等和子集Python

LeetCode416. 分割等和子集 题目链接代码 题目链接 https://leetcode.cn/problems/partition-equal-subset-sum/ 代码 class Solution:def canPartition(self, nums: List[int]) -> bool:sum 0dp [0]*10001for num in nums:sum numif sum % 2 1:return Falsetarget …

arm实验

设置按键中断,按键1按下,LED亮,再次按下,灭 按键2按下,蜂鸣器叫,再次按下,停 按键3按下,风扇转,再次按下,停 头文件 #ifndef __CTRL_KEY_H__ #define __CT…

axios的请求中断和请求重试

请求中断 场景:1、假如一个页面接口太多、或者当前网络太卡顿、这个时候跳往其他路由,当前页面可以做的就是把请求中断掉(优化)2、假如当前接口调取了第一页数据,又调去了第二页的数据,当我们调取第二页数…

Learning Sample Relationship for Exposure Correction 论文阅读笔记

这是中科大发表在CVPR2023的一篇论文,提出了一个module和一个损失项,能够提高现有exposure correction网络的性能。这已经是最近第三次看到这种论文了,前两篇分别是CVPR2022的ENC(和这篇文章是同一个一作作者)和CVPR20…