计算CNN卷积层和全连接层的参数量

计算CNN卷积层和全连接层的参数量

先前阅读

  • CNN Explainer
  • A Comprehensive Guide to Convolutional Neural Networks — the ELI5 way

本文主旨意在搞明白2个问题:
第一个问题
一个卷积操作,他的参数,也就是我们要训练的参数,也就是我们说的权重,有多少个? 看到一个nn.Conv()函数,就能知道有多少个,它由那些因子决定的?
参数量是由以下3个因子决定的:

  • 卷积核大小(HxW)
  • 卷积核维度(D)
  • 卷积核有多少个

则卷积层的参数量为 卷积核大小(HxW) * 卷积核维度(D) * 卷积核有多少个

第二个问题
一个全连接操作,参数又有多少个?它由那些因子决定的?

  • 输入大小为 N
  • 输出大小为 M

则全连接层的参数量为 N×M

计算CNN卷积层的参数量

案例1

在这里插入图片描述

动态演示
请添加图片描述

看上图案例1的计算,输入图像为 5x5x1, 卷积核3x3x1, 输出3x3x1;

思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》1
  • 卷积核有多少个 ==》1

参数量为 3x3x1x1 = 9个

案例2

在这里插入图片描述
看上图案例2的计算,输入图像为 H1xW1x3, 卷积核3x3x3, 输出H2xW2x1;
思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》3
  • 卷积核有多少个 ==》1

参数量为 3x3x3x1 = 27个

从上面的两个案例可以看出, 参数量与输入图像的HxW没有关系, 参数量与输出图像的HxW也没有关系。

案例3

VGG-16为例,conv1-1,第一层
输入224x224x3, 输出是224x224x64,卷积核3x3
思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》3
  • 有多少个卷积核 ==》64

卷积核的维度是多少? 是由输入图像的维度决定,这里是3
卷积核的个数是多少? 是由输出图像的维度决定,这里是64

所以参数量 = 3x3x卷积核维度x卷积核个数 = 3x3x3x64 = 27个

Pytorch代码辅助理解

代码

nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

案例3中的卷积操作如下:

nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

参数量计算: = kernel_size * kernel_size * in_channels(卷积核维度) * out_channels(卷积核个数) = 3 * 3 * 3 * 64

stride=1, padding=0, 这两个会影响到输出的HxW,上文已经提到和我们要计算的参数量无关。

最后,补上偏置参数,
每个卷积核都加个偏置 ,所以总得参数量:
参数量计算: = kernel_size * kernel_size * in_channels(卷积核维度) * out_channels(卷积核个数) + bias(=卷积核个数) = 3 * 3 * 3 * 64+64

计算FC全连接层的参数量

先看一段代码,这是我们经常看到的一段代码,先把x解析到1x9的维度,再做全连接操作

self.fc = nn.Linear(9, 4)x = x.view(-1, 9) # 把x,解析到1x9的维度,这一个操作是没有权重的
x = self.fc(x) # 做全连接操作

上面的代码对应的操作图,如下
在这里插入图片描述
图片来源 | Fully Connected Layer vs. Convolutional Layer: Explained

红色框的参数,就是我们要找的权重参数,有多少个?
思考问题?

  • 输入大小为 N = 9
  • 输出大小为 M =4

计算参数量 = 9x4 = 36个

再看对应的连接图
在这里插入图片描述
上图中的每一条连接线(橙色和蓝色的线),都有一个权重参数,共36条,所以有36个参数。

最后,补上偏置参数,
偏置参数数量: 每个输出节点有一个偏置项(bias),因此偏置参数的数量等于输出节点的数量,即 M=4
所以,总的参数数量为N×M+M = 40,即 M 为输出节点数量,N 为输入节点数量。

END


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/648593.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

快速添加Android seLinux权限

selinux 权限问题中90%的场景都是在补足缺少的权限,下面的通用方法主要用来解决我们在日志中获取到 avc denied 的问题: 首先获取avc的打印信息,可以通过 logcat | grep avc 获取,假设有如下日志: type1400 audit(0.…

常见逻辑漏洞

挖掘重点: 业务流程和HTTP/HTTPS请求篡改 支付漏洞和越权漏洞是金融业务中常见的 支付漏洞 (1) 密码重置 验证码直接在HTTP响应中返回; 验证码未绑定用户,没和手机号和邮箱号做匹配验证; 未校验用户字段值,改自…

Navigation 2 学习01 介绍及安装及运行示例

Navigation 2 是什么 Nav2 是 ROS 导航 的综合控制服务,类似人类的小脑控制人类的行走及身体平衡,Nav2 针对移动和地面机器人提供支持的自动驾驶车辆的相同类型的技术,经过优化和改造。该项目旨在找到一种安全的方法,使移动机器人…

nginx离线部署-aarch64架构

nginx离线部署-aarch64架构 服务器环境: 架构:aarch64, 系统:Red Hat (CentOS 7) nginx 1.24 需要准备这些: 可以先尝试安装 Nginx 安装NGINX 内网是没有网络的需要使用 RPM 包安装 gcc, g…

USART通讯

提示:文章 文章目录 前言一、背景二、 2.1 2.2 总结 前言 前期疑问: 1、一开始没有搞明白到底是USART还是UART。 2、其中还涉及到一个同步的概念。同步就是是否有时钟线同步。USART是串口同步异步发送接收器。USART没有时钟线是怎么实现同步的。 3、…

服务器上面安装nodejs react

1、nvm管理nodejs 2、修改端口 /node_modules/react-scripts/scripts/start.js // 这是start.js部分源码 const DEFAULT_PORT parseInt(process.env.PORT, 10) || 3000; const HOST process.env.HOST || 0.0.0.0;// 将3000修改自己需要的端口号 const DEFAULT_PORT parseIn…

KMP字符串匹配算法

介绍: KMP算法是一种改进的字符串匹配算法,由D.E.Knuth,J.H.Morris和V.R.Pratt提出的,因此人们称它为克努特—莫里斯—普拉特操作(简称KMP算法)。KMP算法的核心是利用匹配失败后的信息,尽量减少…

绘制太极图 - 使用 PyQt

大家好!今天我们将一起来探讨一下如何使用PyQt,这是一个强大的Python库,来绘制一个传统的太极图。这个图案代表着古老的阴阳哲学,而我们的代码将以大白话的方式向你揭示它的奥秘。 PyQt:是什么鬼? 首先&a…

C# 中的接口

简介 官方说明:接口定义协定。 实现该协定的任何 class 或 struct 必须提供接口中定义的成员的实现。 接口可为成员定义默认实现。 它还可以定义 static 成员,以便提供常见功能的单个实现。 从 C# 11 开始,接口可以定义 static abstract 或 …

【AI-Pos系列】DeepLabCut 学习

level: nature neuroscience author:Alexander Mathis  1,2, Pranav Mamidanna1 , Kevin M. Cury3 , Taiga Abe3 , Venkatesh N. Murthy  2 , Mackenzie Weygandt Mathis  1,4,8* and Matthias Bethge1,5,6,7,8 date: 2018 keyword: quantifying behavior; pose estimation;…

超级胶水(第十一届蓝桥杯)

题目 小明有 n n n颗石子,按顺序摆成一排。他准备用胶水将这些石子粘在一起。 每颗石子有自己的重量,如果将两颗石子粘在一起,将合并成一颗新的石子,重量是这两颗石子的重量之和。 为了保证石子粘贴牢固,粘贴两颗石…

架构师之路(十六)计算机网络(传输层)

前置知识(了解):计算机基础。 作为架构师,我们所设计的系统很少为单机系统,因此有必要了解计算机和计算机之间是怎么联系的。局域网的集群和混合云的网络有啥区别。系统交互的时候网络会存在什么瓶颈。 既然网络层已经…

.net访问oracle数据库性能问题

问题: 生产环境相同的inser语句在别的非.NET程序相应明显快于.NET程序,执行时间相差比较大,影响正常业务运行,测试环境反而正常。 问题详细诊断过程 问题初步判断诊断过程: 查询插入慢的sql_id 检查对应的执行计划…

直播间流程解析基础

通过用户心理需求引导用户行为 贯穿内容和产品牵引想要和需要 直播间内流程解析 分为播前准备、开播暖场、产品介绍、穿插活动、结尾预告 (1)直播间内流程解析----播前准备 (2)直播间内流程解析----开播暖场 (3&…

互联网加竞赛 基于机器视觉的银行卡识别系统 - opencv python

1 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的银行卡识别算法设计 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng…

电涌保护器(SPD)、后备保护器(SCB)、断路器(CB)的区别与应用

随着现代电力系统的不断发展,电力设备的保护显得愈发重要。其中,电涌保护器(SPD)、后备保护器(SCB)和断路器(CB)是三种常见的保护设备,但它们各自具有不同的功能和特点。…

【渗透测试】借助PDF进行XSS漏洞攻击

简介 在平时工作渗透测试一个系统时,常常会遇到文件上传功能点,其中大部分会有白名单或者黑名单机制,很难一句话木马上传成功,而PDF则是被忽略的一个点,可以让测试报告更丰富一些。 含有XSS的PDF制作步骤 1. 编辑器…

论文阅读《thanking frequency fordeepfake detection》

项目链接:https://github.com/yyk-wew/F3Net 这篇论文从频域的角度出发,提出了频域感知模型用于deepfake检测的模型 整体架构图: 1.FAD: 频域感知分解,其实就是利用DCT变换,将空间域转换为频域&#xff…

element+vue 之 v-limit 按钮操作权限

1.新建一个permission.js文件 import store from /storeexport default {inserted: function (el, binding) {const { perms: limits } store.state.userconst { value: params } bindingif (!limits.length) returnif (params && Array.isArray(params)) {if (!limi…

如何利用PyTorch?

上一篇文章介绍了“What is PyTorch?”,本篇文章探讨一下“How to use PyTorch?” 1、PyTorch PyTorch 是一个开源机器学习库,基于 Torch 库开发,主要由 Facebook 的人工智能研究实验室(FAIR)研发。它是一个强大且灵…