【机器学习】卷积神经网络(CNN)的特征数计算

文章目录

  • 基本步骤
  • 示例
  • 图解过程

基本步骤

在卷积神经网络(CNN)中,计算最后的特征数通常涉及到以下步骤:

  1. 确定输入尺寸

    首先,你需要知道输入数据的尺寸。对于图像数据,这通常是 (batch_size, channels, height, width)

  2. 应用卷积层

    在卷积操作过程中,图像与卷积核进行滑动窗口式的乘加运算,这会导致图像尺寸的变化。特征数会根据卷积核的数量和大小以及步长等因素发生变化。

    • in_channels:输入数据的通道数。
    • out_channels:卷积层产生的输出特征图的数量,即卷积核的数量。
    • kernel_size:卷积核(filter)的大小(FxF)(kernel_size的选择对模型的性能有很大影响,因为它决定了模型能够捕捉到的特征的尺度和复杂性。增大kernel_size可以捕获更大范围的特征,但可能会增加计算复杂性和过拟合的风险;减小kernel_size则可以关注更细节、局部的特征,但可能忽略掉一些重要的全局信息。因此,选择合适的kernel_sizeCNN设计中的一个重要环节)。
    • stride:卷积核在输入数据上滑动的步长。
    • padding:在输入数据边缘添加的零填充的数量。


    卷积层的输出尺寸可以通过以下公式计算(floor()是向下取整函数):

    output_height = floor((input_height - kernel_size + 2 * padding) / stride) + 1
    output_width = floor((input_width - kernel_size + 2 * padding) / stride) + 1
    

    特征数(或通道数)在卷积层后变为 out_channels

  3. 应用池化层

    池化层通常不会改变特征数,但会改变特征图的高度和宽度。

    池化层的输出尺寸可以通过以下公式计算:

    output_height = floor((input_height - kernel_size) / stride) + 1
    output_width = floor((input_width - kernel_size) / stride) + 1
    
  4. 重复以上步骤

    继续应用卷积层和池化层,每次更新特征图的尺寸和特征数。

  5. 全局平均池化或全连接层

    在某些情况下,网络可能包含全局平均池化层或全连接层,这些层可以进一步改变特征数。为了将这些特征图转换为一维向量以输入到全连接层,你需要将特征图的元素"展平"(flatten)。展平的过程是将所有元素按顺序排列成一个单一的向量。

    计算展平后的输入维度(in_features)的公式为:

    in_features = channels * height * width
    
  6. 最终特征数

    网络的最后一层之前的特征图的通道数就是最后的特征数。

示例

以下是一个简单的例子来说明如何计算最后特征图的尺寸:给定 RGB 图像 (batch_size=32,channels=3,height=60,width=90)

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv_block1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.conv_block2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.fc2 = nn.Sequential(nn.Linear(18816, 9408),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(9408, 4704),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(4704, 5))def forward(self, x):x = self.conv_block1(x)x = self.conv_block2(x)x = x.reshape(x.shape[0], -1)x = self.fc2(x)return x

在上述代码中,给定一个 RGB 图像 (batch_size=32,channels=3,height=60,width=90),我们将图像输入到 self.conv_block1self.conv_block2 进行处理。

首先,我们计算经过 self.conv_block1 后的特征数:

  • 输入数据有 3 个通道(RGB 图像)。
  • 第一个卷积层将输出通道数增加到 32

由于 kernel_size=3, stride=1, padding=1,即卷积核的大小为 3×3,步长为 1,填充为 1,我们可以计算新的特征图尺寸:

output_height = (60 - 3 + 2 * 1) / 1 + 1 = 60
output_width = (90 - 3 + 2 * 1) / 1 + 1 = 90
  • 经过 ReLU 激活函数后,特征数保持为 32
  • 第二个卷积层仍然保持 32 个输出通道,同上特征图的高度和宽度不变。
  • 再经过 ReLU 激活函数后,特征数仍为 32
  • 最后,最大池化层不会改变通道数,但会减小特征图的高度和宽度。

由于 nn.MaxPool2d(kernel_size=3, stride=2),即最大池化层的池化窗口的大小为 3×3 步长为 2,我们可以计算新的特征图尺寸:

output_height = (60 - 3) / 2 + 1 = 29
output_width = (90 - 3) / 2 + 1 = 44

所以,经过self.conv_block1后,特征图的尺寸为(1, 32, 29, 44),特征数为 32

接下来,我们将这个 32 通道的特征图输入到self.conv_block2

  • 第一个卷积层将输出通道数从 32 增加到 64,同上特征图的高度和宽度不变。
  • 经过 ReLU 激活函数后,特征数保持为 64
  • 第二个卷积层仍然保持 64 个输出通道,同上特征图的高度和宽度不变。
  • 再经过 ReLU 激活函数后,特征数仍为 64
  • 最后,最大池化层不会改变通道数,但会进一步减小特征图的高度和宽度。

同样地,最大池化层的池化窗口的大小为 3×3 步长为 2,我们可以计算新的特征图尺寸:

output_height = (29 - 3) / 2 + 1 = 14
output_width = (29 - 3) / 2 + 1 = 21

因此,经过 self.conv_block1self.conv_block2 后,最终的特征图的尺寸为 (32, 64, 14, 21)

nn.LinearPyTorch 中的一个全连接层(Fully Connected Layer),它用于执行线性变换。全连接层的输入和输出维度通常是由网络架构和数据的特性决定的。

nn.Linear 的第一个参数,即输入维度(input_featuresin_features

为了将这些特征图转换为一维向量以输入到全连接层,你需要将特征图的元素“展平”(flatten)。展平的过程是将所有元素按顺序排列成一个单一的向量。我们可以计算展平后新的特征数,即输入维度 (in_features)

in_features = 64 * 14 * 21 = 18816

第一个全连接层输出维度为 9408,再经过 ReLU 激活函数。

nn.DropoutPyTorch 库中的一种正则化技术的实现,常用于防止过拟合。在深度学习模型训练过程中,dropout 通过随机忽略(“丢弃”)一部分神经元的输出来降低模型的复杂性。这里 dropout 比例为 0.5,那么在训练过程中,每一步有 50% 的神经元输出会被随机设置为0。

同上过程,再来一次最后输出维度为 5,显然这是个 5-分类问题

图解过程

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/231735.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Webpack安装及使用

win系统 全局安装Webpack及使用 前提:使用Webpack必须安装node环境,建议使用nvm管理node版本。 1:查看自己电脑是否安装了node 2:npm install webpack版本号 -g 3:npm install webpack-cli -g -g:表示全局安装 4&…

龙芯loongarch64服务器编译安装gcc-8.3.0

前言 当前电脑的gcc版本为8.3.0,但是在编译其他依赖包的时候,出现各种奇怪的问题,会莫名其妙的中断编译。本地文章讲解如何自编译安装gcc,替换系统自带的gcc。 环境准备 下载页面:龙芯开源社区网站 - LoongArch GCC 8.3 交叉工具链 - 源码下载源码包名称如:loongson-gnu…

2023-12-18 最大二叉树、合并二叉树、二叉搜索树中的搜索、验证二叉搜索树

654. 最大二叉树 核心:记住递归三部曲,一般传入的参数的都是题目给好的了!把构造树类似于前序遍历一样就可!就是注意单层递归的逻辑! # Definition for a binary tree node. # class TreeNode: # def __init__(se…

强化产品联动:网关V7独家解决方案的三重优势

客户背景 某央企单位汇聚了众多业内优秀的工程师和科研人员,拥有先进的研发设施和丰富的研发经验,专注于为全球汽车行业提供创新和实用的解决方案。其研发成果不仅在国内市场上得到了广泛应用,也在国际市场上赢得了广泛的认可和赞誉。 客户需…

jconsole与jvisualvm

jconsole 环境变量配置好后 直接输入在cmd 输入jconsole 即可 jvisualvm cmd 输入jvisualvm jvisualvm 能干什么 监控内存泄露,跟踪垃圾回收,执行时内存、cpu 分析,线程分析… 运行:正在运行的 休眠:sleep 等待…

接口测试的工具(3)----postman+node.js+newman

1.安装newman:输入命令之后 一定注意 什么都不要操作 静静的等待结束就行了。 2.安装失败的对此尝试不行 在用下面的方法 解压一下就行了 3.验证是否成功 多次尝试是可以在线安装成功的

Unity中Shader URP最简Shader框架(ShaderGraph 转 URP Shader)

文章目录 前言一、 我们先了解一下 Shader Graph 怎么操作1、了解一下 Shader Graph 的面板信息2、修改Shader路径3、鼠标中键 或 Alt 鼠标左键 移动画布4、鼠标右键 打开创建节点菜单5、把ShaderGraph节点转化为 Shader 代码6、可以看出 URP 和 BuildIn RP 大体框架一致 二、…

隐私计算介绍

这里只对隐私计算做一些概念性的浅显介绍,作为入门了解即可 目录 隐私计算概述隐私计算概念隐私计算背景国外各个国家和地区纷纷出台了围绕数据使用和保护的公共政策国内近年来也出台了数据安全、隐私和使用相关的政策法规 隐私计算技术发展 隐私计算技术安全多方计…

C# WPF上位机开发(usb设备访问)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 目前很多嵌入式设备都支持usb访问,特别是很多mcu都支持高速usb访问。和232、485下个比较,usb的访问速度和它们基本不在一个…

unittest自动化测试框架讲解以及实战

为什么要学习unittest 按照测试阶段来划分,可以将测试分为单元测试、集成测试、系统测试和验收测试。单元测试是指对软件中的最小可测试单元在与程序其他部分相隔离的情况下进行检查和验证的工作,通常指函数或者类,一般是开发完成的。 单元…

工业数据的特殊性和安全防护体系探索思考

随着工业互联网的发展,工业企业在生产运营管理过程中会产生各式各样数据,主要有研发设计数据、用户数据、生产运营数据、物流供应链数据等等,这样就形成了工业大数据,这些数据需要依赖企业的网络环境和应用系统进行内外部流通才能…

【Python】—— NumPy基础及取值操作

NumPy基础及取值操作 第1关:ndarray对象第2关:形状操作第3关:基础操作第4关:随机数生成第5关:索引与切片 第1关:ndarray对象 任务描述 本关任务:根据本关所学知识,补全代码编辑器中…

react基于antd二次封装spin组件

目录 react基于antd二次封装spin组件组件使用组件效果 react基于antd二次封装spin组件 组件 import { Spin } from antd; import propTypes from "prop-types"; import React from react; import styleId from "styled-components"; // 使用 父div必须加…

【爬虫课堂】如何高效使用短效代理IP进行网络爬虫

目录 一、前言 二、代理IP的基本知识 三、短效代理IP的优势 四、高效使用短效代理IP的技巧 1. 多源获取代理IP 2. 质量筛选代理IP 3. 使用代理池 4. 定时更换代理IP 5. 失败重试机制 6. 监控和自动化 五、示例代码 六、结语 一、前言 网络爬虫是一种自动化程序&am…

MongoDB中的关系

本文主要介绍MongoDB中的关系。 目录 MongoDB的关系嵌入关系引用关系 MongoDB的关系 MongoDB是一个非关系型数据库,它使用了键值对的方式来存储数据。因此,MongoDB没有像传统关系型数据库中那样的表、行和列的概念。相反,MongoDB中的关系是通…

LLM之RAG实战(五)| 高级RAG 01:使用小块检索,小块所属的大块喂给LLM,可以提高RAG性能

RAG(Retrieval Augmented Generation,检索增强生成)系统从给定的知识库中检索相关信息,从而使其能够生成事实信息、上下文相关信息和特定领域的信息。然而,在有效检索相关信息和生成高质量响应方面,RAG面临…

【网络安全】-Linux操作系统—CentOS安装、配置

文章目录 准备工作下载CentOS创建启动盘确保硬件兼容 安装CentOS启动安装程序分区硬盘网络和主机名设置开始安装完成安装 初次登录和配置更新系统安装额外的软件仓库安装网络工具配置防火墙设置SELinux安装文本编辑器配置SSH服务 总结 CentOS是一个基于Red Hat Enterprise Linu…

美颜SDK是什么?视频美颜SDK在直播平台中的集成与接入教程详解

当下,主播们追求更加自然、精致的外观,而观众也期待在屏幕前欣赏到更为清晰、美丽的画面。为了满足这一需求,美颜SDK应运而生,成为直播平台的重要利器之一。 一、什么是美颜SDK? 通过美颜SDK,开发者可以…

Kotlin Multiplatform的现状—2023年网络研讨会

Kotlin Multiplatform的现状—2023年网络研讨会 在2023年,Kotlin Multiplatform因其开发、当前状态和未来潜力而受到了相当大的关注。随着越来越多的开发者对采用KMP进行跨平台解决方案表示兴趣,JetBrains在11月下旬推出了一系列网络研讨会作为回应。首…

“去 Android化”为何蔚然成风?

早在2008年时,国内市场诞生了第一批自研手机OS,由于种种缘由铩羽而归,“优化Android ”貌似成为了本土特色。而从2023年下半年开始掀起了一股"去安卓化"的热潮,像华为、小米、vivo等都不约而同的站在了同一战线。 “去…