【机器学习】卷积神经网络(CNN)的特征数计算

文章目录

  • 基本步骤
  • 示例
  • 图解过程

基本步骤

在卷积神经网络(CNN)中,计算最后的特征数通常涉及到以下步骤:

  1. 确定输入尺寸

    首先,你需要知道输入数据的尺寸。对于图像数据,这通常是 (batch_size, channels, height, width)

  2. 应用卷积层

    在卷积操作过程中,图像与卷积核进行滑动窗口式的乘加运算,这会导致图像尺寸的变化。特征数会根据卷积核的数量和大小以及步长等因素发生变化。

    • in_channels:输入数据的通道数。
    • out_channels:卷积层产生的输出特征图的数量,即卷积核的数量。
    • kernel_size:卷积核(filter)的大小(FxF)(kernel_size的选择对模型的性能有很大影响,因为它决定了模型能够捕捉到的特征的尺度和复杂性。增大kernel_size可以捕获更大范围的特征,但可能会增加计算复杂性和过拟合的风险;减小kernel_size则可以关注更细节、局部的特征,但可能忽略掉一些重要的全局信息。因此,选择合适的kernel_sizeCNN设计中的一个重要环节)。
    • stride:卷积核在输入数据上滑动的步长。
    • padding:在输入数据边缘添加的零填充的数量。


    卷积层的输出尺寸可以通过以下公式计算(floor()是向下取整函数):

    output_height = floor((input_height - kernel_size + 2 * padding) / stride) + 1
    output_width = floor((input_width - kernel_size + 2 * padding) / stride) + 1
    

    特征数(或通道数)在卷积层后变为 out_channels

  3. 应用池化层

    池化层通常不会改变特征数,但会改变特征图的高度和宽度。

    池化层的输出尺寸可以通过以下公式计算:

    output_height = floor((input_height - kernel_size) / stride) + 1
    output_width = floor((input_width - kernel_size) / stride) + 1
    
  4. 重复以上步骤

    继续应用卷积层和池化层,每次更新特征图的尺寸和特征数。

  5. 全局平均池化或全连接层

    在某些情况下,网络可能包含全局平均池化层或全连接层,这些层可以进一步改变特征数。为了将这些特征图转换为一维向量以输入到全连接层,你需要将特征图的元素"展平"(flatten)。展平的过程是将所有元素按顺序排列成一个单一的向量。

    计算展平后的输入维度(in_features)的公式为:

    in_features = channels * height * width
    
  6. 最终特征数

    网络的最后一层之前的特征图的通道数就是最后的特征数。

示例

以下是一个简单的例子来说明如何计算最后特征图的尺寸:给定 RGB 图像 (batch_size=32,channels=3,height=60,width=90)

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv_block1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.conv_block2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.fc2 = nn.Sequential(nn.Linear(18816, 9408),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(9408, 4704),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(4704, 5))def forward(self, x):x = self.conv_block1(x)x = self.conv_block2(x)x = x.reshape(x.shape[0], -1)x = self.fc2(x)return x

在上述代码中,给定一个 RGB 图像 (batch_size=32,channels=3,height=60,width=90),我们将图像输入到 self.conv_block1self.conv_block2 进行处理。

首先,我们计算经过 self.conv_block1 后的特征数:

  • 输入数据有 3 个通道(RGB 图像)。
  • 第一个卷积层将输出通道数增加到 32

由于 kernel_size=3, stride=1, padding=1,即卷积核的大小为 3×3,步长为 1,填充为 1,我们可以计算新的特征图尺寸:

output_height = (60 - 3 + 2 * 1) / 1 + 1 = 60
output_width = (90 - 3 + 2 * 1) / 1 + 1 = 90
  • 经过 ReLU 激活函数后,特征数保持为 32
  • 第二个卷积层仍然保持 32 个输出通道,同上特征图的高度和宽度不变。
  • 再经过 ReLU 激活函数后,特征数仍为 32
  • 最后,最大池化层不会改变通道数,但会减小特征图的高度和宽度。

由于 nn.MaxPool2d(kernel_size=3, stride=2),即最大池化层的池化窗口的大小为 3×3 步长为 2,我们可以计算新的特征图尺寸:

output_height = (60 - 3) / 2 + 1 = 29
output_width = (90 - 3) / 2 + 1 = 44

所以,经过self.conv_block1后,特征图的尺寸为(1, 32, 29, 44),特征数为 32

接下来,我们将这个 32 通道的特征图输入到self.conv_block2

  • 第一个卷积层将输出通道数从 32 增加到 64,同上特征图的高度和宽度不变。
  • 经过 ReLU 激活函数后,特征数保持为 64
  • 第二个卷积层仍然保持 64 个输出通道,同上特征图的高度和宽度不变。
  • 再经过 ReLU 激活函数后,特征数仍为 64
  • 最后,最大池化层不会改变通道数,但会进一步减小特征图的高度和宽度。

同样地,最大池化层的池化窗口的大小为 3×3 步长为 2,我们可以计算新的特征图尺寸:

output_height = (29 - 3) / 2 + 1 = 14
output_width = (29 - 3) / 2 + 1 = 21

因此,经过 self.conv_block1self.conv_block2 后,最终的特征图的尺寸为 (32, 64, 14, 21)

nn.LinearPyTorch 中的一个全连接层(Fully Connected Layer),它用于执行线性变换。全连接层的输入和输出维度通常是由网络架构和数据的特性决定的。

nn.Linear 的第一个参数,即输入维度(input_featuresin_features

为了将这些特征图转换为一维向量以输入到全连接层,你需要将特征图的元素“展平”(flatten)。展平的过程是将所有元素按顺序排列成一个单一的向量。我们可以计算展平后新的特征数,即输入维度 (in_features)

in_features = 64 * 14 * 21 = 18816

第一个全连接层输出维度为 9408,再经过 ReLU 激活函数。

nn.DropoutPyTorch 库中的一种正则化技术的实现,常用于防止过拟合。在深度学习模型训练过程中,dropout 通过随机忽略(“丢弃”)一部分神经元的输出来降低模型的复杂性。这里 dropout 比例为 0.5,那么在训练过程中,每一步有 50% 的神经元输出会被随机设置为0。

同上过程,再来一次最后输出维度为 5,显然这是个 5-分类问题

图解过程

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/231735.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Webpack安装及使用

win系统 全局安装Webpack及使用 前提:使用Webpack必须安装node环境,建议使用nvm管理node版本。 1:查看自己电脑是否安装了node 2:npm install webpack版本号 -g 3:npm install webpack-cli -g -g:表示全局安装 4&…

Wireshark在物联网中的应用

第一章:Wireshark基础及捕获技巧 1.1 Wireshark基础知识回顾 1.2 高级捕获技巧:过滤器和捕获选项 1.3 Wireshark与其他抓包工具的比较 第二章:网络协议分析 2.1 网络协议分析:TCP、UDP、ICMP等 2.2 高级协议分析:HTTP…

龙芯loongarch64服务器编译安装gcc-8.3.0

前言 当前电脑的gcc版本为8.3.0,但是在编译其他依赖包的时候,出现各种奇怪的问题,会莫名其妙的中断编译。本地文章讲解如何自编译安装gcc,替换系统自带的gcc。 环境准备 下载页面:龙芯开源社区网站 - LoongArch GCC 8.3 交叉工具链 - 源码下载源码包名称如:loongson-gnu…

2023楚慧杯 WEB方向 部分:(

1、eaaeval 查看源码能看见账号&#xff1a;username169&#xff0c;密码&#xff1a;password196提交这个用户密码可以跳转到页面/dhwiaoubfeuobgeobg.php 通过dirsearch目录爆破可以得到www.zip <?php class Flag{public $a;public $b;public function __construct(){…

2023-12-18 最大二叉树、合并二叉树、二叉搜索树中的搜索、验证二叉搜索树

654. 最大二叉树 核心&#xff1a;记住递归三部曲&#xff0c;一般传入的参数的都是题目给好的了&#xff01;把构造树类似于前序遍历一样就可&#xff01;就是注意单层递归的逻辑&#xff01; # Definition for a binary tree node. # class TreeNode: # def __init__(se…

【踩坑记录】pytorch 自定义嵌套网络时部分网络有梯度但参数不更新

问题描述 使用如下的自定义的多层嵌套网络进行训练&#xff1a; class FC1_bot(nn.Module):def __init__(self):super(FC1_bot, self).__init__()self.embeddings nn.Sequential(nn.Linear(10, 10))def forward(self, x):emb self.embeddings(x)return embclass FC1_top(nn…

强化产品联动:网关V7独家解决方案的三重优势

客户背景 某央企单位汇聚了众多业内优秀的工程师和科研人员&#xff0c;拥有先进的研发设施和丰富的研发经验&#xff0c;专注于为全球汽车行业提供创新和实用的解决方案。其研发成果不仅在国内市场上得到了广泛应用&#xff0c;也在国际市场上赢得了广泛的认可和赞誉。 客户需…

jconsole与jvisualvm

jconsole 环境变量配置好后 直接输入在cmd 输入jconsole 即可 jvisualvm cmd 输入jvisualvm jvisualvm 能干什么 监控内存泄露&#xff0c;跟踪垃圾回收&#xff0c;执行时内存、cpu 分析&#xff0c;线程分析… 运行&#xff1a;正在运行的 休眠&#xff1a;sleep 等待…

接口测试的工具(3)----postman+node.js+newman

1.安装newman&#xff1a;输入命令之后 一定注意 什么都不要操作 静静的等待结束就行了。 2.安装失败的对此尝试不行 在用下面的方法 解压一下就行了 3.验证是否成功 多次尝试是可以在线安装成功的

测试进程监控:确保产品质量的关键

引言&#xff1a; 在软件开发过程中&#xff0c;测试是确保产品质量的重要环节。为了提高测试效率和准确性&#xff0c;测试进程监控成为了不可或缺的工具。本文将介绍测试进程监控的各个方面&#xff0c;包括产品风险度量、缺陷度量源、测试用例&#xff08;或规程&#xff09…

Unity中Shader URP最简Shader框架(ShaderGraph 转 URP Shader)

文章目录 前言一、 我们先了解一下 Shader Graph 怎么操作1、了解一下 Shader Graph 的面板信息2、修改Shader路径3、鼠标中键 或 Alt 鼠标左键 移动画布4、鼠标右键 打开创建节点菜单5、把ShaderGraph节点转化为 Shader 代码6、可以看出 URP 和 BuildIn RP 大体框架一致 二、…

【Docker-2】在 Debian 上安装 Docker 引擎

在 Debian 上安装 Docker 引擎 要开始在 Debian 上使用 Docker 引擎&#xff0c;请确保满足先决条件&#xff0c;然后按照安装步骤操作。 先决条件 操作系统要求 要安装 Docker Engine&#xff0c;您需要以下 Debian 之一的 64 位版本 版本&#xff1a; Debian Bookworm 12…

隐私计算介绍

这里只对隐私计算做一些概念性的浅显介绍&#xff0c;作为入门了解即可 目录 隐私计算概述隐私计算概念隐私计算背景国外各个国家和地区纷纷出台了围绕数据使用和保护的公共政策国内近年来也出台了数据安全、隐私和使用相关的政策法规 隐私计算技术发展 隐私计算技术安全多方计…

C# WPF上位机开发(usb设备访问)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 目前很多嵌入式设备都支持usb访问&#xff0c;特别是很多mcu都支持高速usb访问。和232、485下个比较&#xff0c;usb的访问速度和它们基本不在一个…

C语言求n的阶乘(n!)

从键盘输入一个数&#xff0c;求出这个数的阶乘&#xff0c;即 n!。 1、算法思想 首先要清楚阶乘定义&#xff0c;所谓 n 的阶乘&#xff0c;就是从 1 开始乘以比前一个数大 1 的数&#xff0c;一直乘到 n&#xff0c;用公式表示就是&#xff1a;1234…(n-2)(n-1)nn! 具体的操…

unittest自动化测试框架讲解以及实战

为什么要学习unittest 按照测试阶段来划分&#xff0c;可以将测试分为单元测试、集成测试、系统测试和验收测试。单元测试是指对软件中的最小可测试单元在与程序其他部分相隔离的情况下进行检查和验证的工作&#xff0c;通常指函数或者类&#xff0c;一般是开发完成的。 单元…

进程间通讯-消息队列

介绍 消息队列是一种存放在内核中的数据结构&#xff0c;用于在不同进程之间传递消息。它基于先进先出&#xff08;FIFO&#xff09;的原则&#xff0c;进程可以将消息发送到队列中&#xff0c;在需要的时候从队列中接收消息。消息队列提供了一种异步通信的方式&#xff0c;使…

❤Mac上后端环境工具安装使用

❤Mac上后端环境工具安装使用 Cornerstone 使用 &#xff08;最好的SVN Mac软键&#xff09; 使用教程 安装 由于Cornerstone是收费的&#xff0c;因此你可以去网上下载破解版&#xff0c;直接安装即可。 配置远程仓库 首先&#xff0c;打开CornerStone&#xff0c;在界面…

工业数据的特殊性和安全防护体系探索思考

随着工业互联网的发展&#xff0c;工业企业在生产运营管理过程中会产生各式各样数据&#xff0c;主要有研发设计数据、用户数据、生产运营数据、物流供应链数据等等&#xff0c;这样就形成了工业大数据&#xff0c;这些数据需要依赖企业的网络环境和应用系统进行内外部流通才能…

19、商城系统(一):项目架构图,配置前端后台开发环境,构建git项目,导入 人人开源框架并前端后台启动

目录​​​​​​​ 一、项目架构图 二、配置环境 1.配置linux (1)复制linux环境