PyTorch Conv2d 前向传递中发生了什么?


在这里插入图片描述
💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。
在这里插入图片描述

  • 推荐:「stormsha的主页」👈,持续学习,不断总结,共同进步,为了踏实,做好当下事儿~
  • 专栏导航
    • Python面试合集系列:Python面试题合集,剑指大厂
    • GO基础学习笔记系列:记录博主学习GO语言的笔记,该笔记专栏尽量写的试用所有入门GO语言的初学者
    • 数据库系列:详细总结了常用数据库 mysql 技术点,以及工作中遇到的 mysql 问题等
    • 运维系列:总结好用的命令,高效开发
    • 算法与数据结构系列:总结数据结构和算法,不同类型针对性训练,提升编程思维

非常期待和您一起在这个小小的网络世界里共同探索、学习和成长。💝💝💝 ✨✨ 欢迎订阅本专栏 ✨✨

💖The Start💖点点关注,收藏不迷路💖

📒文章目录

  • 1、概述
  • 2、Conv2d 核心参数
    • 2.1、什么是卷积核(kernel)?
    • 2.2、可训练参数(Trainable Parameters)和偏置(Bias)
    • 2.3、输入通道和输出通道的数量(Number of Input and Output Channels)
    • 2.4、卷积核大小(Kernel size)
    • 2.5、步长(Strides)
    • 2.6、填充(Padding)
    • 2.7、膨胀(Dilation)
    • 2.8、组(Groups)
    • 2.9、输出通道大小(Output Channel Size)
  • 3、总结


1、概述

随着人功智能的发展,涌现出了很多深度学习的库和平台,如Tensorflow、Keras、Pytorch、Caffe或Theano,在我们的日常开发中为我们提供很多帮助,基于这些深度学习库的应用程序层出不穷也让我们感到惊叹。每个开发者都有自己最喜欢的框架,它们的共同点是易于使用且可根据需要进行配置使我们的工作变得简单。但我们还是需要了解这些工具可用的论据是什么,以便更好地利用这些框架赋予我们的所有功能。

在这篇文章中,我将尝试列出所有这些参数。如果你想了解它们对计算时间、可训练参数的数量以及卷积输出通道大小的影响,那么这篇文章适合你。
Input Shape : (3, 7, 7) — Output Shape : (2, 3, 3) — K : (3, 3) — P : (1, 1) — S : (2, 2) — D : (2, 2) — G : 1

Input Shape : (3, 7, 7) — Output Shape : (2, 3, 3) — K : (3, 3) — P : (1, 1) — S : (2, 2) — D : (2, 2) — G : 1

2、Conv2d 核心参数

这篇文章的部分内容将根据以下参数进行讲解。这些参数可以在 Pytorch Conv2d 模块的文档中找到

  • in_channels(int)- 输入图像的通道数
  • out_channels(int)- 卷积生成的通道数
  • kernel_size(int 或 tuple)- 卷积核的大小
  • stride(int 或 tuple,可选)- 卷积的步长。默认:1
  • padding(int 或 tuple,可选)- 添加到输入两侧的零填充。默认:0
  • dilation(int 或 tuple,可选)- 核元素之间的间距。默认:1
  • groups(int,可选)- 从输入通道到输出通道的块连接数。默认:1
  • bias(bool,可选)- 如果为 True,则在输出中添加可学习的偏置。默认:True

最后,我们将掌握根据参数和输入通道大小计算输出通道大小的所有关键点。

2.1、什么是卷积核(kernel)?

输入图像和内核之间的卷积过程

输入图像和kernel之间的卷积过程

先介绍一下 kernel(或卷积矩阵)是什么。 kernel 描述了我们将要在一个输入图像上进行卷积操作的滤波器。简单来说, kernel 会在整个图像上移动,从左到右,从上到下,通过应用卷积运算。这个操作的输出被称为过滤图像。

卷积积

卷积积(Convolution product)

在这里插入图片描述

Input shape : (1, 9, 9) — Output Shape : (1, 7, 7) — K : (3, 3) — P : (0, 0) — S : (1, 1) — D : (1, 1) — G : 1

举一个非常基本的例子,让我们想象一个3乘3的卷积核对一个9乘9的图像进行过滤。然后,这个卷积核(Convolution Kernel)会在整个图像上移动,以捕捉图像中所有相同大小的方块(3乘3)。卷积积(Convolution product)是一种元素级(或点积)的乘法。这个结果的总和就是输出(或过滤后)图像上的像素值。

如果你对滤波器和卷积矩阵还不熟悉,那么我强烈建议你花更多的时间来理解卷积核(Convolution Kernel)。它们是二维卷积层的核心。

2.2、可训练参数(Trainable Parameters)和偏置(Bias)

可训练参数,也被称为“参数”,是在网络训练过程中将被更新的所有参数。在Conv2d中,可训练的元素是构成卷积核的值。所以对于我们的3乘3卷积核,我们有3*3=9个可训练参数。
卷积积(Convolution Product)与偏置(Bias)

卷积积(Convolution Product)与偏置(Bias)

为了更完整,我们可以包括偏置或不包含。偏置的作用是被添加到卷积积的总和中。这个偏置也是一个可训练参数,这使得我们3乘3卷积核的可训练参数数量上升到10个。

2.3、输入通道和输出通道的数量(Number of Input and Output Channels)

请添加图片描述

Input Shape: (1, 7, 7) — Output Shape : (4, 5, 5) — K : (3, 3) — P : (0, 0) — S : (1, 1) — D : (1, 1) — G : 1

使用层级结构的优势在于能够同时执行类似的操作。换句话说,如果我们想对一个输入通道应用4个相同大小的不同滤波器,那么我们将得到4个输出通道。这些通道是4个不同滤波器的结果。所以来自于4个不同的卷积核

在这里插入图片描述
随着卷积核数量的增加,参数的数量也会线性增加。因此,它也与所需的输出通道数量成线性关系。同样需要注意的是,计算时间也与输入通道的大小和卷积核的数量成正比。

在这里插入图片描述

参数图中的曲线是相同的

同样的原则也适用于输入通道的数量。让我们考虑一个使用RGB编码的图像的情况。这个图像有3个通道:红、蓝和绿。我们可以决定使用相同大小的滤波器在这三个通道上提取信息,以获得四个新的通道。因此,这个操作在三个通道上是相同的,用于获得四个输出通道。

在这里插入图片描述

Input Shape: (3, 7, 7) — Output Shape : (4, 5, 5) — K : (3, 3) — P : (0, 0) — S : (1, 1) — D : (1, 1) — G : 1

每个输出通道是过滤后的输入通道的总和。对于4个输出通道和3个输入通道,每个输出通道是3个过滤后的输入通道的总和。换句话说,卷积层由4*3=12个卷积核组成。
在这里插入图片描述
参数的数量和计算时间与输出通道的数量成正比。这是因为每个输出通道都与与其他通道不同的卷积核相关联。对于输入通道的数量也是如此。计算时间和参数数量会按比例增长。

在这里插入图片描述

2.4、卷积核大小(Kernel size)

到目前为止,所有的例子都是使用3乘3大小的卷积核。事实上,选择它的大小完全取决于你。你可以创建一个具有11或1919大小的卷积层。

并不是必须使用正方形的卷积核。可以选择具有不同高度和宽度的卷积核。这在信号图像分析中经常出现。如果我们知道我们想要扫描一个信号或声音的图像,那么我们可能更喜欢使用5*1大小的卷积核。如下图所示:
在这里插入图片描述

Input Shape: (3, 7, 9) — Output Shape : (2, 3, 9) — K : (5, 2) — P : (0, 0) — S : (1, 1) — D : (1, 1) — G : 1

你会注意到所有大小都由奇数定义。定义一个偶数的卷积核大小也是可以接受的。但在实践中,这很少做。通常选择奇数大小的卷积核,因为在中心像素周围有对称性。
在这里插入图片描述
由于卷积层的所有(经典)可训练参数都在卷积核中,所以参数的数量随着卷积核大小的增加而线性增长。计算时间也成比例变化。

2.5、步长(Strides)

默认情况下,卷积核从左到右、从上到下逐个像素进行移动。但这种移动也可以改变。通常用于对输出通道进行降采样。例如,使用步长为(1, 3),滤波器在水平方向上每3个像素移动一次,在垂直方向上每1个像素移动一次。这将产生水平方向上降采样3倍的输出通道。
在这里插入图片描述

Input Shape: (3, 9, 9) — Output Shape : (2, 7, 3) — K : (3, 3) — P : (0, 0) — S : (1, 3) — D : (1, 1) — G : 1

在这里插入图片描述
步长对参数数量没有影响,但计算时间会随着步长的增加而线性减少。

2.6、填充(Padding)

填充是指在对输入通道进行卷积滤波之前,添加到输入通道边缘的像素数量。通常,填充像素被设置为零。输入通道被扩展。
在这里插入图片描述

Input Shape : (2, 7, 7) — Output Shape : (1, 7, 7) — K : (3, 3) — P : (1, 1) — S : (1, 1) — D : (1, 1) — G : 1

当您希望输出通道的大小等于输入通道的大小时,这非常有用。简单来说,当卷积核为3*3时,输出通道的大小在每个方向上减小一个像素。为了解决这个问题,我们可以使用1个像素的填充。

在这里插入图片描述

参数图中的曲线是相同的。

因此,填充对参数数量没有影响,但会产生与填充大小成正比的额外计算时间。但通常来说,填充相对于输入通道的大小来说通常足够小,可以认为对计算时间没有影响。

2.7、膨胀(Dilation)

膨胀可以看作是卷积核的宽度。默认情况下等于1,它对应于卷积过程中卷积核在输入通道上的每个像素之间的偏移量。
在这里插入图片描述

Input Shape: (2, 7, 7) — Output Shape : (1, 1, 5) — K : (3, 3) — P : (1, 1) — S : (1, 1) — D : (4, 2) — G : 1

在GIF图中有点夸张,但如果我们以(4, 2)的膨胀为例,那么卷积核在输入通道上的感受野在垂直方向上会扩大4 * (3 -1)=8个像素,在水平方向上会扩大2 * (3-1)=4个像素(对于一个3乘3的卷积核)。

在这里插入图片描述

参数图中的曲线是相同的。

就像填充一样,膨胀对参数数量没有影响,对计算时间的影响也非常有限。

2.8、组(Groups)

在某些特定情况下,组可以非常有用。例如,如果我们有多个连接的数据源。当没有必要将它们相互依赖地处理时,输入通道可以分组并独立处理。最后,输出通道在结束时连接在一起。

如果有2个输入通道和4个输出通道,并且有2个组。那么这就像将输入通道分成两个组(每个组中有1个输入通道),并通过一个输出通道数量减半的卷积层。然后输出通道被连接在一起。
请添加图片描述

Input Shape : (2, 7, 7) — Output Shape : (4, 5, 5) — K : (3, 3) — P : (2, 2) — S : (2, 2) — D : (1, 1) — G : 2

需要注意的是,组的数量必须能够整除输入通道的数量和输出通道的数量(公因数)。
在这里插入图片描述
因此,参数的数量被组的数量所除。至于使用Pytorch的计算时间,算法针对组进行了优化,因此应该减少计算时间。然而,也应该考虑到必须将组的形成和输出通道的连接的计算时间相加。

2.9、输出通道大小(Output Channel Size)

有了所有参数的知识,就可以根据输入通道的大小计算输出通道的大小。
在这里插入图片描述

3、总结

所有计算时间测试都是使用Pytorch在我的GPU(GeForce GTX 960M)上进行的,如果你想自己运行它们或进行其他测试,Gitee仓库地址:https://gitee.com/stormsha/conv2d_demo

以上所有的GIF图像都是由Python生成,源码地址:Gitee仓库地址:https://gitee.com/stormsha/conv2d_demo

参考

Deep Learning Tutorial, Y. LeCun

Documentation torch.nn, Pytorch

Convolutional Neural Networks, cs231n

Convolutional Layers, Keras


❤️❤️❤️本人水平有限,如有纰漏,欢迎各位大佬评论批评指正!😄😄😄

💘💘💘如果觉得这篇文对你有帮助的话,也请给个点赞、收藏、分享下吧,非常感谢!👍 👍 👍

🔥🔥🔥道阻且长,行则将至,让我们一起加油吧!🌙🌙🌙

💖The End💖点点关注,收藏不迷路💖

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2838.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Servlet对象的生命周期

1.什么是Servlet的生命周期? Servlet对象什么时候被创建,神魔时候被销毁。 Servlet对象创建了几个? Servlet对象的生命周期表示:一个 Servlet对象从出生在最后死亡,整个过程是怎样的。 Servlet对象由随来维护? Servlet对象的…

[Java、Android面试]_19_单例模式(高频问题)

本人今年参加了很多面试,也有幸拿到了一些大厂的offer,整理了众多面试资料,后续还会分享众多面试资料。 整理成了面试系列,由于时间有限,每天整理一点,后续会陆续分享出来,感兴趣的朋友可关注收…

设计模式:解释器模式

定义 解释器模式(Interpreter Pattern)是一种行为型设计模式,它给定一个语言,定义它的文法的一种表示,并定义一个解释器,这个解释器使用该表示来解释语言中的句子。简单来说,它主要用于某些特定…

Springboot的日常操作技巧

文章目录 1、自定义横幅2、容器刷新后触发方法自定义3、容器启动后触发方法自定义**CommandLineRunner**ApplicationRunner 不定时增加 参考文章 1、自定义横幅 简单就一点你需要把banner.text放到classpath 路径下 ,默认它会找叫做banner的文件,各种格式…

spring的bean创建流程源码解析

文章目录 IOC 和 DIBeanFactoryApplicationContext实现的接口1、BeanFactory接口2、MessageSource 国际化接口3、ResourcePatternResolver,资源解析接口4、EnvironmentCapable接口,用于获取环境变量,配置信息5、ApplicationEventPublisher 事…

使用扩展卡尔曼滤波器进行包裹测量的状态估计

此示例说明如何使用扩展卡尔曼滤波器算法对涉及圆形包裹角度测量的 3D 跟踪进行非线性状态估计。对于目标跟踪,传感器通常采用球形框架来报告物体的方位角、距离和仰角位置。该组的角度测量值在一定范围内报告。例如,报告的方位角范围为- 180∘ 到180∘或…

Leetcode55LeetCode45

题目 55. 跳跃游戏 思路 一看跳跃,自然想到动态规划。去看了题解发现可以将该问题进行转化,记录每个下标能达到的最远距离,要是这个最远距离超过了数组长度则说明能够到达终点。真的很巧妙!但是最开始自己写,想的是用…

UE5 GAS开发P34 游戏效果理论

GameplayEffects Attributes(属性)和Gameplay Tags(游戏标签)分别代表游戏中实体的特性和标识。 Attributes(属性):Attributes是用来表示游戏中实体的特性或属性的值,例如生命值、…

【网络通信】TCP三次握手、四次挥手

TCP(传输控制协议)是一种面向连接的、可靠的、基于字节流的传输层通信协议。在TCP/IP协议族中,TCP协议负责在两个网络节点之间建立可靠的连接,并保证数据包的顺序传输和数据的完整性。 1.TCP三次握手 TCP三次握手(Thr…

【工具】录屏软件Captura安装使用及ffmpeg下载配置

开启技术视频创作,录屏软件林林总总,适合的、习惯的最好。 录屏软件Captura的使用及ffmpeg下载配置 1.Captura下载、安装2.FFmpeg下载、配置3.Captura屏幕录制试用、录制视频效果 1.Captura下载、安装 Captura主要是一个免费开源的录屏软件&#xff0c…

系统架构设计精华知识

数据流风格:适合于分阶段做数据处理,交互性差,包括:批处理序列、管理过滤器。调用/返回风格:一般系统都要用到,包括:主程序/子程序,面向对象,层次结构(分层越…

20240330-1-词嵌入模型w2v+tf-idf

Word2Vector 1.什么是词嵌入模型? 把词映射为实数域向量的技术也叫词嵌⼊ 2.介绍一下Word2Vec 谷歌2013年提出的Word2Vec是目前最常用的词嵌入模型之一。Word2Vec实际是一种浅层的神经网络模型,它有两种网络结构,分别是连续词袋&#xff…

54位大咖演讲精华! 中国生成式AI大会圆满收官,TOP50企业榜单揭晓

54位大咖演讲精华! 中国生成式AI大会圆满收官,TOP50企业榜单揭晓© 由 红板报 提供 智东西4月19日报道,为期两天的2024中国生成式AI大会,今日在京圆满收官。 54位产学研投嘉宾代表全程干货爆棚,报名咨询人数逾52…

操作符不存在:sde.st_geometry ^ !sde.st_geometry建议 SQL函 数st_intersects在内联inlining期间

操作符不存在:sde.st_geometry ^ !sde.st_geometry建议 SQL函 数st_intersects在内联inlining期间 问题:最近在使用SQL图形处理函数处理图形时,莫名奇妙报如下错误,甚是费解 于是开始四处"寻医问药" 1、nav…

G1、CMS垃圾回收期专题

共同点 非阻塞 使用三色标记法 初始标记、并发标记、重新标记、并发清理 (初始标记、重新标记需要stop world) CMS垃圾回收器 缺点 浮动垃圾 有垃圾碎片 关注停顿时间,使用了效率最高的标记清除算法 G1垃圾回收器 缺点 需要配置高&…

MySQL表级锁——技术深度+1

引言 本文是对MySQL表级锁的学习,MySQL一直停留在会用的阶段,需要弄清楚锁和事务的原理并DEBUG查看。 PS:本文涉及到的表结构均可从https://github.com/WeiXiao-Hyy/blog中获取,欢迎Star! MySQL表级锁 MySQL中表级锁主要有表锁…

【Java EE】 SpringBoot配置文件

文章目录 🍀配置文件的作用🌴SpringBoot配置文件🍃配置文件的格式🌳properties 配置文件说明🌸properties基本语法🌸读取配置文件🌸properties 缺点分析 🌲yml 配置文件说明&#x1…

Docker基本管理和虚拟化

一、docker的发展历史 https://www.cnblogs.com/rongba/articles/14782624.htmlhttps://www.cnblogs.com/rongba/articles/14782624.html 二、docker的概述 Docker是一个开源的应用容器引擎,基于go语言开发并遵循了apache2.0协议开源。 Docker是在Linux容器里运行…

centos7 Nginx一键安装自动化脚本

离线环境下 centos7 Nginx一键安装自动化脚本 本文介绍了一个 Bash 脚本,可用于自动化安装 Nginx 服务器。该脚本简化了安装过程,省去了手动配置的繁琐步骤。 脚本功能特点: 依赖检查和安装: 自动检查并安装 Nginx 所需的依赖包…

轨迹跟踪控制导读

系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 TODO:写完再整理 文章目录 系列文章目录前言一、小车底盘运动学模型介绍二、路径跟踪方法:PID控制算法实现路径跟踪--飞思卡尔的方法三、路径跟踪方法&#x…