【PyTorch单点知识】深入理解与应用转置卷积ConvTranspose2d模块

文章目录

      • 0. 前言
      • 1. 转置卷积概述
      • 2. `nn.ConvTranspose2d` 模块详解
        • 2.1 主要参数
        • 2.2 属性与方法
      • 3. 计算过程(重点)
        • 3.1 基本过程
        • 3.2 调整stride
        • 3.3 调整dilation
        • 3.4 调整padding
        • 3.5 调整output_padding
      • 4. 应用实例
      • 5. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

nn.ConvTranspose2d 模块是用于实现二维转置卷积(又称为反卷积)的核心组件。本文将详细介绍 ConvTranspose2d 的概念、工作原理、参数设置以及实际应用。

本文的说明参考了PyTorch的官方文档

1. 转置卷积概述

转置卷积(Transposed Convolution),有时也被称为“反卷积”(尽管严格来说它并不是真正意义上的卷积的逆运算),是一种特殊的卷积操作,常用于从较低分辨率的特征图上采样到较高分辨率的空间维度。

在诸如深度卷积生成对抗网络(DCGAN)和条件生成对抗网络(CGANs)等任务中,转置卷积被广泛用于将网络内部的紧凑特征(较小的特征)表示恢复为与原始输入尺寸相匹配或接近的(较大的特征)输出。

2. nn.ConvTranspose2d 模块详解

nn.ConvTranspose2d 是 PyTorch 中 torch.nn 模块的一部分,专门用于定义和实例化二维转置卷积层。其构造函数接受一系列参数来配置卷积行为:

2.1 主要参数
  1. in_channels (int) - 输入特征图的通道数,即前一层的输出通道数。

  2. out_channels (int) - 输出特征图的通道数,即本层产生的新特征通道数。

  3. kernel_size (inttuple) - 卷积核大小,通常是一个整数(当使用方形卷积核时)或包含两个整数的元组(分别对应卷积核的高度和宽度)。

  4. stride (inttuple, default=1) - 卷积步长,决定了卷积核在输入特征图上滑动的距离。与 kernel_size 类似,它可以是单个整数(对所有维度相同)或一个包含两个整数的元组。

  5. padding (inttuple, default=0) - 填充量,用于控制输出尺寸和保持边界信息。

  6. output_padding (inttuple, default=0) - 用于调整输出尺寸的额外填充量,仅应用于转置卷积。它在卷积计算后增加到输出边缘的额外像素数量。

  7. groups (int, default=1) - 分组卷积参数,当大于1时,输入和输出通道将被分成若干组,每组内的卷积相互独立。

  8. bias (bool, default=True) - 表示是否为该层添加可学习的偏置项。

  9. dilation (inttuple, default=1) - 卷积核元素之间的间距(膨胀率),控制卷积核中非零元素之间的距离。

  10. padding_mode (str , default=zeros) - 填充数据方式,zeros为全部填充0

  11. device (str , default=cpu) - 处理数据的设备

  12. dtype (str, default=None ) - 数据类型

2.2 属性与方法
  • .weight (Tensor) - 存储转置卷积核的权重,形状为 (out_channels, in_channels, kernel_size[0], kernel_size[1]),是可学习的模型参数。

  • .bias (Tensor) - 若 bias=True,则包含与每个输出通道关联的偏置项,形状为 (out_channels),也是可学习的参数。

  • .forward(input) - 接受输入张量 input,执行转置卷积运算并返回输出特征图。

3. 计算过程(重点)

输入输出图像一般为4维或3维,即[B, C, H, W]或[C, H, W],其中:

  • B:Batch_size,每批的样本数
  • C:channel,通道数
  • H, W:图像的高和宽

以图像高度H为例(宽度W同理),转置卷积的输出尺寸可以通过以下公式计算:

H o u t = ( H i n − 1 ) × stride − 2 × padding + dilation × ( kernel-size − 1 ) + output-padding + 1 H_{out}=(H_{in}-1) \times \text{stride} -2 \times \text{padding} + \text{dilation} \times (\text{kernel-size}-1) + \text{output-padding}+1 Hout=(Hin1)×stride2×padding+dilation×(kernel-size1)+output-padding+1

这个公式看起来比较复杂,下面我们通过实例来理解转置卷积的计算过程。

3.1 基本过程

输入原图size为[1, 2, 2],卷积核也size也为[1, 2, 2],其余参数如下:

in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False

计算过程:

在这里插入图片描述

容易看出,经历转置卷积后特征图会扩大,即上采样。使用代码验算:

import torchinput = torch.tensor([[[[0,1],[2,3]]]],dtype=torch.float32)ConvTrans = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],[ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))print(ConvTrans(input))

输出为:

tensor([[[[ 0.0000,  1.1000,  2.2000],[ 2.2000, 11.0000, 11.0000],[ 6.6000, 18.7000, 13.2000]]]], grad_fn=<ConvolutionBackward0>)
3.2 调整stride

把stride调整为2后,计算过程如下:

在这里插入图片描述
如果stride过大,则会在跳过的位置补0。例如上面的计算过程中,如果stride = 3输出则为:

在这里插入图片描述

注意,这里stride可以指定为tuple,即让横向和纵向的stride不一样,例如(1, 2),但其计算思路不变,这里直接用代码计算结果(懒得再画过程图了):

import torchinput = torch.tensor([[[[0,1],[2,3]]]],dtype=torch.float32)ConvTrans = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=(1,2), padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],[ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))print(ConvTrans(input))

输出为:

tensor([[[[ 0.0000,  0.0000,  1.1000,  2.2000],[ 2.2000,  4.4000,  6.6000, 11.0000],[ 6.6000,  8.8000,  9.9000, 13.2000]]]],grad_fn=<ConvolutionBackward0>)
3.3 调整dilation

这个过程非常简单,可以分为2步:

  1. 把卷积核进行dilation(爆炸)处理
  2. 进行3.1基本过程

即:
在这里插入图片描述
代码验算过程如下:

import torchinput = torch.tensor([[[[0,1],[2,3]]]],dtype=torch.float32)ConvTrans_dilation2 = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=2,bias=False)
ConvTrans_dilation2.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],[ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))print(ConvTrans_dilation2(input))ConvTrans_dilation1 = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans_dilation1.weight = torch.nn.Parameter(torch.tensor([[[[1.1, 0, 2.2],[0, 0, 0],[3.3, 0, 4.4]]]], dtype=torch.float32,requires_grad=True))   #对卷积核进行dilationprint(ConvTrans_dilation1(input))
print(ConvTrans_dilation2(input) == ConvTrans_dilation1(input))

输出为:

tensor([[[[ 0.0000,  1.1000,  0.0000,  2.2000],[ 2.2000,  3.3000,  4.4000,  6.6000],[ 0.0000,  3.3000,  0.0000,  4.4000],[ 6.6000,  9.9000,  8.8000, 13.2000]]]],grad_fn=<ConvolutionBackward0>)
tensor([[[[ 0.0000,  1.1000,  0.0000,  2.2000],[ 2.2000,  3.3000,  4.4000,  6.6000],[ 0.0000,  3.3000,  0.0000,  4.4000],[ 6.6000,  9.9000,  8.8000, 13.2000]]]],grad_fn=<ConvolutionBackward0>)
tensor([[[[True, True, True, True],[True, True, True, True],[True, True, True, True],[True, True, True, True]]]])
3.4 调整padding

这是一个下采样的过程,会减少输出size。具体计算方法也很简单:给输出数据减去padding。基于3.1基本过程举例说明padding = 1的情况如下:

在这里插入图片描述

3.5 调整output_padding

这个参数用于给最终输出补0,output_padding必须要比stride或者dilation小。需要注意的是output_padding补0只能补半圈,如下:

我也想不明白为什么不是补一整圈?

在这里插入图片描述

4. 应用实例

在实际使用中,nn.ConvTranspose2d 可以嵌入到神经网络结构中,用于实现上采样、特征图尺寸放大或生成与输入尺寸相似的输出。以下是一个简单的使用示例:

import torch
import torch.nn as nn# 定义一个包含转置卷积层的简单模型
class TransposedConvModel(nn.Module):def __init__(self, in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1, output_padding=0):super().__init__()self.conv_transpose = nn.ConvTranspose2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,output_padding=output_padding,bias=True)def forward(self, x):return self.conv_transpose(x)# 实例化模型并应用到输入数据
model = TransposedConvModel()
input_tensor = torch.randn(1, 32, 16, 16)  # (batch_size, in_channels, height, width)
output = model(input_tensor)
print("Output shape:", output.shape)

输出为:

Output shape: torch.Size([1, 64, 32, 32])

5. 总结

nn.ConvTranspose2d 是 PyTorch 中用于实现二维转置卷积的关键模块,它通过逆向的卷积操作实现了特征图的上采样和空间维度的扩大。

正确理解和配置其参数(如 kernel_sizestridepaddingoutput_padding 等),可以帮助开发者构建出适应特定任务需求的神经网络架构,特别是在图像生成、超分辨率、语义分割等需要从低分辨率特征恢复到高分辨率输出的应用场景中发挥关键作用。通过实践和调整这些参数,研究人员和工程师能够灵活地设计和优化基于转置卷积的深度学习模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835218.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker快速搭建NAS服务——FileBrowser

Docker快速搭建NAS服务——FileBrowser 文章目录 前言FileBrowser的搭建docker-compose文件编写运行及访问 总结 前言 本文主要讲解如何使用docker在本地快速搭建NAS服务&#xff0c;这里主要写如下两种&#xff1a; FileBrowser1&#xff1a;是一个开源的Web文件管理器&…

运行SpringBoot项目失败?异常显示Can‘t load IA 32-bit .dll on a AMD 64-bit platform,让我来看看~

原因是&#xff0c;我放入jdk的bin文件夹下的tcnative-1.dll文件是32位的&#xff0c;那么肯定是无法在AMD 64位平台上加载IA 32位.dll。但是网站上给出的都是32位呀&#xff0c;没有64位怎么办&#xff1a; 其实当我们把“tomcat-native-1.2.34-openssl-1.1.1o-win32-bin.zip”…

机器学习-如何为模型选择评估指标?

为机器学习模型选择评估指标是一个关键步骤&#xff0c;因为它直接关联到如何衡量模型的性能。以下是选择评估指标的一些建议&#xff1a; 1、理解问题类型&#xff1a; 分类问题&#xff1a;对于二分类问题&#xff0c;常见的评估指标包括准确率、精确率、召回率、F1分数、R…

对多重继承关系的父子抽象类中子类的方法进行测试时如何回避Mock父类中的Protected方法

标题的说法就比较绕口&#xff0c;但是这个具体的问题大家看了下面内容就明白了。 如果在自己工作中遇到类似问题时可以试试这个解决办法。如果您技术好的话&#xff0c;其实不仔细看也行的&#xff0c;哈哈。 假设你有以下的类结构&#xff0c;该如何使用junit5,cdi-unit,moc…

以无侵方式实现Deployment原地升级

如何以无侵方式实现Deployment原地升级&#xff1f; 本文将展示如何以无侵、原生的方式实现Deployment原地升级。 在文章末尾会提供shell脚本供大家参考。 本文的原地升级仅指镜像更新 本篇kubernetes版本为v1.27.3。 原地升级的概念以及OpenKruise的实现方式可以参考文章&a…

Linux下GraspNet复现流程

Linux&#xff0c;Ubuntu中GraspNet复现流程 文章目录 Linux&#xff0c;Ubuntu中GraspNet复现流程1.安装cuda和cudnn2.安装pytorch3.编译graspnetAPIReference &#x1f680;非常重要的环境配置&#x1f680; ubuntu 20.04cuda 11.0.1cudnn v8.9.7python 3.8.19pytorch 1.7.0…

十大排序算法(java实现)

注&#xff1a;本篇仅用来自己学习&#xff0c;大量内容来自菜鸟教程&#xff08;地址&#xff1a;1.0 十大经典排序算法 | 菜鸟教程&#xff09; 排序算法可以分为内部排序和外部排序&#xff0c;内部排序是数据记录在内存中进行排序&#xff0c;而外部排序是因排序的数据很大…

Microsoft Edge浏览器,便携增强版 v118.0.5993.69

01 软件介绍 Microsoft Edge浏览器&#xff0c;便携增强版&#xff0c;旨在无需更新组件的情况下提供额外的功能强化。这一增强版专注于优化用户体验和系统兼容性&#xff0c;具体包含以下核心功能的提升&#xff1a; 数据保存&#xff1a;通过优化算法增强了其数据保存能力&…

1707jsp电影视频网站系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 校园商城派送系统 是一套完善的web设计系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统采用web模式&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数…

Bugku Crypto 部分题目简单题解(三)

where is flag 5 下载打开附件 Gx8EAA8SCBIfHQARCxMUHwsAHRwRHh8BEQwaFBQfGwMYCBYRHx4SBRQdGR8HAQ0QFQ 看着像base64解码 尝试后发现&#xff0c;使用在线工具无法解密 编写脚本 import base64enc Gx8EAA8SCBIfHQARCxMUHwsAHRwRHh8BEQwaFBQfGwMYCBYRHx4SBRQdGR8HAQ0QFQ tex…

【嵌入式DIY实例】-DDS信号生成器

DDS信号生成器 文章目录 DDS信号生成器1、AD9805介绍2、硬件准备与接线3、代码实现在本文中,将详细介绍如何使用AD9850来搭建一个简易的DDS(Direct Digital signal )信号生成器。 1、AD9805介绍 AD9850是一款高度集成的器件,采用先进的DDS技术,内置一个高速、高性能数模转…

powershell@管道符过滤的顺序问题@powershell管道符如何工作

文章目录 select 和 where谁先执行powershell管道符stop-service 为例查看文档中的典型参数介绍stop-process为例介绍管道符传参是怎么工作的Id参数InputObject 参数Name参数额外的试验反面例子应用:get-process 和stop-process配合 select 和 where谁先执行 在执行筛选时&…

每日一练 | 华为认证真题练习 - OSPF 协议基础

★ 题目 关于OSPF&#xff08;开放最短路径优先&#xff09;邻居状态机的描述&#xff0c;以下哪项是正确的&#xff1f; A. Attempt 状态只在 NBMA&#xff08;非广播多路访问&#xff09;网络中出现 B. Attempt 状态只在 NBMA 和 P2MP&#xff08;点对多点&#xff09;网络…

Unity构建详解(12)——自动构建

【前言】 自动构建是指整个构建流程不需要人工操作&#xff0c;只需要输入启动构建指令即可获取构建结果。实现这样的自动构建需要满足以下条件&#xff1a; 支持命令行参数启动 我们不可能每次构建时都打开Unity去手动点击构建&#xff0c;必须支持通过命令行启动Unity自动执…

java.lang.NoSuchMethodException: com.ruoyi.web.controller.test.bean.HeadTeacher

软件开发过程中使用Java反射机制时遇到了下面的问题 com.ruoyi.web.controller.test.bean.HeadTeacher4b9af9a9 com.ruoyi.web.controller.test.bean.HeadTeacher4b9af9a9java.lang.NoSuchMethodException: com.ruoyi.web.controller.test.bean.HeadTeacher.<init>(java…

【软考高项】三十八、风险管理7个过程

一、规划风险管理 1、定义、作用 定义&#xff1a;定义如何实施项目风险管理活动的过程作用&#xff1a;确保风险管理的水平、方法和可见度与项目风险程度相匹配&#xff0c;与对组织和其他干系人的重要程度相匹配 2、输入 项目管理计划 项目章程 项目文件 干系人登记册…

C语言从头学04——介绍占位符和输出格式

占位符、输出格式都是与 printf() 相关的&#xff0c;当然其它函数也有用到占位符的。这里先介绍它们在 printf() 的使用。 一、先介绍占位符&#xff0c;所谓“占位符”通俗讲就是先占个位置&#xff0c;后边再找具体值(参数)代入进行显示的一种方法。先用一个例子说明…

【一刷《剑指Offer》】面试题 17:合并两个排序的链表

力扣对应题目链接&#xff1a;21. 合并两个有序链表 - 力扣&#xff08;LeetCode&#xff09; 核心考点&#xff1a;链表合并。 一、《剑指Offer》内容 二、分析题目 这道题的解题思路有很多&#xff1a; 可以一个一个节点的归并。可以采用递归完成。 三、代码 1、易于理解的…

Java 多线程补充

线程池 Java线程池是一种能够有效管理线程资源的机制&#xff0c;它可以显著提高应用性能并降低资源消耗。 线程池的主要优点包括&#xff1a; 资源利用高效&#xff1a;通过重用已存在的线程&#xff0c;减少了频繁创建和销毁线程带来的系统开销。响应速度提升&#xff1a;…

智慧公厕,小民生里的“大智慧”!

公共厕所是城市社会生活的基础设施&#xff0c;而智慧公厕则以其独特的管理模式为城市居民提供更优质的服务。通过智能化的监测和控制系统&#xff0c;智慧公厕实现了厕位智能引导、环境监测、资源消耗监测、安全防范管理、卫生消杀设备、多媒体信息交互、自动化控制、自动化清…