神经网络基础 | 给定条件下推导对应的卷积层参数

神经网络基础 | 给定条件下推导对应的卷积层参数

在这里插入图片描述

按照 PyTorch 文档中 给定的设置:

H o u t = ⌊ H i n + 2 × padding [ 0 ] − dilation [ 0 ] × ( kernel_size [ 0 ] − 1 ) − 1 stride [ 0 ] + 1 ⌋ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor Hout=stride[0]Hin+2×padding[0]dilation[0]×(kernel_size[0]1)1+1

W o u t = ⌊ W i n + 2 × padding [ 1 ] − dilation [ 1 ] × ( kernel_size [ 1 ] − 1 ) − 1 stride [ 1 ] + 1 ⌋ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor Wout=stride[1]Win+2×padding[1]dilation[1]×(kernel_size[1]1)1+1

其中涉及到了几个卷积的参数:

  • 输入的尺寸, i i i H i n H_{in} Hin W i n W_{in} Win
  • 输出的尺寸, o o o H o u t H_{out} Hout W i n W_{in} Win
  • 对应维度上的单侧 padding, p p p padding \text{padding} padding
  • 对应维度上的扩张率, d d d dilation \text{dilation} dilation
  • 卷积核尺寸, k k k kernel_size \text{kernel\_size} kernel_size
  • 卷积滑动步长, s s s stride \text{stride} stride

一般来说,H 方向和 W 方向的参数是一样的,所以后续的介绍中仅考虑单一 H 方向。

另外,该运算规则对于 nn.Unfold 这类操作同样是满足的。

已知: i , o , k , d i, o, k, d i,o,k,d

这里的 o o o 其实也可以理解为沿着指定方向,在特定参数约束下的实际执行计算的窗口数量。

由于涉及到扩张率 d d d,所以我们应该直接考虑等效的卷积核 k ′ = ( k − 1 ) × d + 1 k' = (k-1) \times d + 1 k=(k1)×d+1。注意,这里的 k ′ k' k 仅用来表示滑窗大小,而并非表示实际的参与计算的元素数量。实际参与计算的依然只有 k k k 个元素。前者可以用于计算实际的滑窗次数,而后者在这里的推导中并不需要考虑。

已知实际窗口 k ′ k' k,我们可以获得滑窗步长 s = ⌈ i − k ′ o − 1 ⌉ s = \left \lceil \frac{i-k'}{o-1} \right \rceil s=o1ik。这里要注意向上取整的操作,由于实际步长需要为整数,所以这里如果除不尽的话需要凑到整数,但是又不能向下取整,因为向下取整会导致滑窗无法完全覆盖所有输入数据,而向上取整,则可以尽可能充分的覆盖整个轴向的数据,而多出来的部分,则可以通过 padding 策略来进行补齐。于是我们也由此可以获得整体的 padding 数,即 ⌈ s × ( o − 1 ) + k ′ − i ⌉ \left \lceil s \times (o-1) + k' - i \right \rceil s×(o1)+ki。也就是通过新的 stride 和等效的 kernel_size,重新计算一次输入尺寸,多出来的部分就是 padding 的数量。

关于 padding 的计算实际上需要考虑框架实际的需求,对于 PyTorch 而言,Conv2d 和 Unfold 都是针对 H 和 W 两个方向的两侧同时进行相同的 padding 操作,也就是说,左右各自对应的 p p p 是一样的。上下也是类似。所以我们这里提供的 p p p 应该是单侧的 padding 值,而通过 stride 直接作差获得的是单轴上的总 padding 数 p t o t a l p_{total} ptotal。所以需要取一半。此时又面临了向上取整还是向下取整的问题。考虑这个问题我们就得了解卷积操作究竟是如何对待 padding 的。实际上,padding 后的输入,在卷积时,如果最后一个窗口内的元素数量不够,那么这个窗口就会被舍弃,也就不会赋到输出变量里。所以只要输入 padding 后在最后一个滑窗位置之后的位置上凑不够一个新的滑窗,那么其就是等价的。所以我们对获得的 padding 直接除以 2 并向上取整即可: p ′ = ⌈ p t o t a l 2 ⌉ p' = \left \lceil \frac{p_{total}}{2} \right \rceil p=2ptotal

我们将代码整理下,可以得到用于计算这些量的函数:

@lru_cache()
def get_unfold_params_v0(width, num_kernels, kernel_size=8, dilation=1):real_kernel_size = (kernel_size - 1) * dilation + 1if width <= real_kernel_size:padding = math.ceil((real_kernel_size - width) / 2)assert width + padding <= real_kernel_size <= width + 2 * paddingstride = padding + 1num_kernels = 1else:stride = math.ceil((width - real_kernel_size) / (num_kernels - 1))if stride == 1:num_kernels = width - real_kernel_size + 1padding = math.ceil((stride * (num_kernels - 1) + real_kernel_size - width) / 2)params = dict(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)if _DEBUG:print(dict(width=width, num_kernels=num_kernels, **params))if not (stride >= 1or padding >= 0or width <= stride * (num_kernels - 1) + real_kernel_size <= width + padding * 2):raise ValueError(f"valid params does not exist for {dict(width=width, num_kernels=num_kernels, **params)}")return params, width, num_kernels

这里额外考虑了:

  • 当输入宽度小于指定核大小时,这个时候直接 padding 就行,但是同时需要修改送入的 num_kernels 参数。当然,如果要严格限定,那这里可以改为报错即可。
  • 当输入宽度大于指定核大小时,此时需要考虑超出的量
    • 如果超出的量并不能大于 num_kernels - 1,此时虽然向上取整后结果为 1,但是每个未取整的 stride 的真实值是小于 1 的。这样使用取整更新过后 stride 计算时,会造成不必要的 padding。所以此时我们更新下 o o o,也就是将其直接设为 s = 1 s=1 s=1 的情况下对应的结果,此时的 padding 数量为 0。当然,如果这里并不想更改 o o o,那么直接计算 padding 即可。
  • 对输出的参数进行一个简单的约束:
    • s s s 要大于等于 1;
    • p p p 要大于等于 0;
    • i ≤ s × ( o − 1 ) + k ′ ≤ i + 2 p i \le s \times (o - 1) + k' \le i +2p is×(o1)+ki+2p。最大为 padding 后的尺寸,最小则为原始尺寸。

已知: i , k , s , d i, k, s, d i,k,s,d

这里的输入不再限定输出的尺寸,而是提供了初始的步长约束 s s s

同样的,先计算真实的卷积核大小 k ′ = ( k − 1 ) × d + 1 k' = (k-1) \times d + 1 k=(k1)×d+1

按照一般情况,输入核小于输入尺寸,此时我们可以计算得到的对应的输出尺寸: o = ⌈ i − k ′ s + 1 ⌉ o = \left \lceil \frac{i-k'}{s} + 1 \right \rceil o=sik+1

由于输出尺寸并不是严格使用输入的 s s s 计算获得的,这里涉及到了一个取整的过程,所以实际上对应的 stride 也发生了改变,我们有必要依此对 stride 进行一下更新: s ′ = ⌈ i − k ′ o − 1 ⌉ s'= \left \lceil \frac{i-k'}{o-1} \right \rceil s=o1ik

输出尺寸得到后就该计算单侧 padding 的大小了,这里同样使用向上取整: ⌈ k ′ + s ′ × ( o − 1 ) − i 2 ⌉ \left \lceil \frac{k'+s' \times (o-1) - i}{2} \right \rceil 2k+s×(o1)i

对应的代码为:

@lru_cache()
def get_unfold_params_v1(width, kernel_size=8, stride=8, dilation=1):real_kernel_size = (kernel_size - 1) * dilation + 1if width <= real_kernel_size:padding = math.ceil((real_kernel_size - width) / 2)assert width + padding <= real_kernel_size <= width + 2 * paddingstride = padding + 1num_kernels = 1else:num_kernels = math.ceil((width - real_kernel_size) / stride) + 1if num_kernels == 1:stride = width - real_kernel_sizepadding = 0else:stride = math.ceil((width - real_kernel_size) / (num_kernels - 1))padding = math.ceil((real_kernel_size + stride * (num_kernels - 1) - width) / 2)params = dict(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)if _DEBUG:print(dict(width=width, num_kernels=num_kernels, **params))if not (stride >= 1or padding >= 0or width <= stride * (num_kernels - 1) + real_kernel_size <= width + padding * 2):raise ValueError(f"valid params does not exist for {dict(width=width, num_kernels=num_kernels, **params)}")return params, width, num_kernels

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/68995.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

欧拉(Euler 22.03)安装ProxySQL

下载离线安装包 proxysql-2.0.8-1-centos7.x86_64.rpm 链接: https://pan.baidu.com/s/1R-SJiVUEu24oNnPFlm9wRw 提取码: sa2w离线安装proxysql yum localinstall -y proxysql-2.0.8-1-centos7.x86_64.rpm 启动proxysql并检查状态 systemctl start proxysql 启动proxysql syste…

Sharding-JDBC 5.4.1+SpringBoot3.4.1+MySQL8.4.1 使用案例

最近在升级 SpringBoot 项目&#xff0c;原版本是 2.7.16&#xff0c;要升级到 3.4.0 &#xff0c;JDK 版本要从 JDK8 升级 JDK21&#xff0c;原项目中使用了 Sharding-JDBC&#xff0c;版本 4.0.0-RC1&#xff0c;在升级 SpringBoot 版本到 3.4.0 之后&#xff0c;服务启动失败…

WPS计算机二级•幻灯片的基础操作

听说这是目录哦 PPT的正确制作步骤&#x1f6e3;️认识PPT界面布局&#x1f3dc;️PPT基础操作 快捷键&#x1f3de;️制作PPT时 常用的快捷技巧&#x1f3d9;️快速替换PPT的 文本字体&#x1f303;快速替换PPT 指定文本内容&#x1f305;能量站&#x1f61a; PPT的正确制作步…

easyexcel读取写入excel easyexceldemo

1.新建springboot项目 2.添加pom依赖 <name>excel</name> <description>excelspringboot例子</description><parent> <groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId&…

Neural networks 神经网络

发展时间线 基础概念 多层神经网络结构 神经网络中一个网络层的数学表达 TensorFlow实践 创建网络层 神经网络的创建、训练与推理 推理 推理可以理解为执行一次前向传播 前向传播 前向传播直观数学表达 前向传播直观数学表达的Python实现 前向传播向量化实现 相关数学知识…

AR智慧点巡检系统探究和技术方案设计

一、项目背景 随着工业生产规模的不断扩大和设备复杂度的提升&#xff0c;传统的人工点巡检方式效率低下、易出错&#xff0c;难以满足现代化企业对设备运行可靠性和安全性的要求。AR&#xff08;增强现实&#xff09;技术的发展为点巡检工作带来了新的解决方案&#xff0c;通…

AI如何帮助解决生活中的琐碎难题?

引言&#xff1a;AI已经融入我们的日常生活 你有没有遇到过这样的情况——早上匆忙出门却忘了带钥匙&#xff0c;到了公司才想起昨天的会议资料没有打印&#xff0c;或者下班回家还在纠结晚饭吃什么&#xff1f;这些看似微不足道的小事&#xff0c;往往让人疲惫不堪。而如今&a…

用Python绘制一只懒羊羊

目录 一、准备工作 二、Turtle库简介 三、绘制懒羊羊的步骤 1. 导入Turtle库并设置画布 2. 绘制头部 3. 绘制眼睛 4. 绘制嘴巴 5. 绘制身体 6. 绘制四肢 7. 完成绘制 五、运行代码与结果展示 六、总结 在这个趣味盎然的技术实践中,我们将使用Python和Turtle图形…

form表单row中的col排列错位混乱

如图所示 form表单新增时排列整齐 编辑时就混乱 具体原因未知 解决方法&#xff1a;在el-row中加入样式 style"flex-wrap: wrap; display: flex" <el-row style"flex-wrap: wrap; display: flex">

OpenCV:高通滤波之索贝尔、沙尔和拉普拉斯

目录 简述 什么是高通滤波&#xff1f; 高通滤波的概念 应用场景 索贝尔算子 算子公式 实现代码 特点 沙尔算子 算子公式 实现代码 特点 拉普拉斯算子 算子公式 实现代码 特点 高通滤波器的对比与应用场景 相关阅读 OpenCV&#xff1a;图像滤波、卷积与卷积核…

error Parsing error: invalid-first-character-of-tag-name vue/no-parsing-error

标签的第一个字符不符合 HTML 或 Vue 的语法要求 [0] Module Warning (from ./node_modules/eslint-loader/index.js): [0] [0] /Users/dgq/Downloads/cursor/spid-admin/src/views/tools/fake-strategy/components/identify-list.vue [0] 87:78 error Parsing error: in…

在Unity中使用大模型进行离线语音识别

文章目录 1、Vosk下载下载vosk-untiy-asr下载模型在项目中使用语音转文字音频转文字2、whisper下载下载unity项目下载模型在unity中使用1、Vosk 下载 下载vosk-untiy-asr Github链接:https://github.com/alphacep/vosk-unity-asr 进不去Github的可以用网盘 夸克网盘链接:h…

【c语言日寄】Vs调试——新手向

【作者主页】siy2333 【专栏介绍】⌈c语言日寄⌋&#xff1a;这是一个专注于C语言刷题的专栏&#xff0c;精选题目&#xff0c;搭配详细题解、拓展算法。从基础语法到复杂算法&#xff0c;题目涉及的知识点全面覆盖&#xff0c;助力你系统提升。无论你是初学者&#xff0c;还是…

双指针+前缀和习题(一步步讲解)

前言&#xff1a;如果解决下面这几道题有些问题&#xff0c;或者即使看了我画的过程图也不理解的可以去看看我的上一篇文章&#xff0c;有可能会对你有帮助。 一、《数值元素的目标和》---来自AcWing 数组元素的目标和 给定两个升序排序的有序数组 A和 B&#xff0c;以及一个…

单调栈详解

文章目录 单调栈详解一、引言二、单调栈的基本原理1、单调栈的定义2、单调栈的维护 三、单调栈的应用场景四、使用示例1、求解下一个更大元素2、计算柱状图中的最大矩形面积 五、总结 单调栈详解 一、引言 单调栈是一种特殊的栈结构&#xff0c;它在栈的基础上增加了单调性约束…

分布式光纤应变监测是一种高精度、分布式的监测技术

一、土木工程领域 桥梁结构健康监测 主跨应变监测&#xff1a;在大跨度桥梁的主跨部分&#xff0c;如悬索桥的主缆、斜拉桥的斜拉索和主梁&#xff0c;分布式光纤应变传感器可以沿着这些关键结构部件进行铺设。通过实时监测应变情况&#xff0c;能够精确捕捉到车辆荷载、风荷…

《安富莱嵌入式周报》第349期:VSCode正式支持Matlab调试,DIY录音室级麦克风,开源流体吊坠,物联网在军工领域的应用,Unicode字符压缩解压

周报汇总地址&#xff1a;嵌入式周报 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬汉嵌入式论坛 - Powered by Discuz! 视频版&#xff1a; 《安富莱嵌入式周报》第349期&#xff1a;VSCode正式支持Matlab调试&#xff0c;DIY录音室级麦克风…

Pyside6(PyQT5)中的QTableView与QSqlQueryModel、QSqlTableModel的联合使用

QTableView 是QT的一个强大的表视图部件&#xff0c;可以与模型结合使用以显示和编辑数据。QSqlQueryModel、QSqlTableModel 都是用于与 SQL 数据库交互的模型,将二者与QTableView结合使用可以轻松地展示和编辑数据库的数据。 QSqlQueryModel的简单应用 import sys from PySid…

uniapp+Vue3(<script setup lang=“ts“>)模拟12306城市左右切换动画效果

效果图&#xff1a; 代码&#xff1a; <template><view class"container"><view class"left" :class"{ sliding: isSliding }" animationend"resetSliding">{{ placeA }}</view><view class"center…

VUE elTree 无子级 隐藏展开图标

这4个并没有下级节点&#xff0c;即它并不是叶子节点&#xff0c;就不需求展示前面的三角展开图标! 查阅官方文档如下描述&#xff0c;支持bool和函数回调处理&#xff0c;这里咱们选择更灵活的函数回调实现。 给el-tree结构配置一下props&#xff0c;注意&#xff01; :pr…