【深度学习入门篇 ⑦】PyTorch池化层

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


池化层 (Pooling) 降低维度,缩减模型大小,提高计算速度. 即: 主要对卷积层学习到的特征图进行下采样(SubSampling)处理 。

  • 通过下采样,我们可以提取出特征图中最重要的特征,同时忽略掉一些不重要的细节。
  • 上采样是指增加数据(图像)的尺寸;通常用于图像的分割、超分辨率重建或生成模型中,以便将特征图恢复到原始图像的尺寸或更大的尺寸。 

池化层

池化包含最大池化和平均池化,有一维池化,二维池化,三维池化,在这里以二维池化为例

最大池化

最大池化就是求一个区域中的最大值,来代替该区域。

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

假设输入的尺寸是(𝑁,𝐶,𝐻,𝑊),输出尺寸是(𝑁,𝐶,𝐻𝑜𝑢𝑡,𝑊𝑜𝑢𝑡),kernel_size是(𝑘𝐻,𝑘𝑊),可以写成下面形式 :

其中,输入参数 kernel_sizestridepaddingdilation可以是

  • 一个 int :代表长宽使用同样的参数
  • 两个int组成的元组:第一个int用在H维度,第二个int用在W维度
import torch
import torch.nn as nn
#长宽一致的池化,核尺寸为3x3,池化步长为2
ml = nnMaxPool2d(3, stride=2)
#长宽不一致的池化
m2 = nn.MaxPool2d((3,2), stride=(2,1))
input = torch.randn(4,3,24,24)
output1 = m1( input)
output2 = m2( input)
print( "input.shape = " ,input.shape)
print( "output1.shape = " , output1.shape)
print( "output2.shape = " , output2.shape)

 输出:

input.shape = torch.size([4,3,24,24])
output1.shape = torch. size([4,3,11,11])
output2.shape = torch.size([4,3,11,23])
平均池化

平均池化就是用一个区域中的平均数来代替本区域

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

import torch
import torch.nn as nn
#长宽一致的池化,核尺寸为3x3,池化步长为2
ml = nn. AvgPool2d( 3, stride=2)
#长宽不一致的池化
m2 = nn. AvgPool2d(( 3,2), stride=(2,1) )
input = torch.randn(4,3,24,24)
output1 = m1( input)
output2 = m2( input)
print("input.shape = ",input. shape)
print("output1.shape = " , output1.shape)
print( "output2.shape = ", output2.shape)
  • randn是生成形状为[batch_size, channels, height, width] 

输出:

input.shape = torch.size([4,3,24,24])
output1.shape = torch.size([4,3,11,11])
output2.shape = torch.size([4,3,11,23])

BN层

BN,即Batch Normalization,是对每一个batch的数据进行归一化操作,可以使得网络训练更稳定,加速网络的收敛。

import torch
import torch.nn as nn
#批量归一化层(具有可学习参数)
m_learnable = nn. BatchNorm2d(100)
#批量归一化层(不具有可学习参数)
m_non_learnable = nn.BatchNorm2d(100,affine=False)
#随机生成输入数据
input = torch.randn(20,100,35,45)
#应用具有可学习参数的批量归一化层
output_learnable = m_learnable(input)
#应用不具有可学习参数的批量归一化层
output_non_learnable = m_non_learnable(input)
print( "input.shape = ", input.shape)
print( "output_learnable.shape = ", output_learnable.shape)
print( "output_non_learnable.shape = ", output_non_learnable.shape)

 输出:

input.shape = torch.size([20,100,35,45])
output_learnable.shape = torch.size( [20,100,35,45])
output_non_learnable.shape = torch.size([20,100,35,45])

常见的层就是上面提到的这些,如果这些层结构被反复调用,我们可以将其封装成一个个不同的模块。

案例:复现LeNet

LeNet结构,使用PyTorch进行复现,卷积核大小5x5,最大池化层,核大小2x2

import torch
import torch.nn as nn
from torchsummary import summary
class LeNet( nn . Module):def _init_( self,num_classes=10):super(Leet, self)._init__()self.conv1 = nn.conv2d( in_channels=3,out_channels=6,kernel_size=5)self.pool1 = nn. MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)self.pool2 = nn. MaxPool2d(kernel_size=2)self.conv3 = nn.conv2d(in_channels=16,out_channels=120, kernel_size=5)self.fc1 = nn.Linear(in_features=120,out_features=84)self.fc2 = nn.Linear(in_features=84,out_features=10)def forward(self, x):#通过卷积层、ReLU和池化层x = self.conv1(x)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.conv3(x)x = x.view( -1,120)x = self.fc1(x)x = self.fc2(x)return x
#创建网络实例
num_classes = 10
net = LeNet( num_classes)#创建一个输入
batch_size = 4
input_tensor = torch.randn(batch_size,3,32,32)
# 假设输入是32x32的RGB图像
#将输入Tensor传递给网络
output = net(input_tensor)
# #显示输出Tensor的形状
print(output.shape)
summary(net,(3,32,32))

Sequential: 顺序容器

Sequential属于顺序容器。模块将按照在构造函数中传递的顺序从上到下进行运算。

使用OrderedDict,可以进一步对传进来的层进行重命名。

#使用sequential来创建小模块,当有输入进来,会从上到下依次经过所有模块
model = nn. Sequential(
nn.conv2d(1,20,5),nn.ReLu() ,
nn.conv2d(20,64,5),nn.ReLU()
)
#使用orderedDict,可以对传进来的模块进行命名,实现效果同上
from collections import orderedDict
model = nn. sequential ( orderedDict([( 'conv1 ', nn.Conv2d( 1,20,5)),( 'relu1 ', nn.ReLU( ) ),( 'conv2 ', nn.conv2d(20,64,5)),( 'relu2 ', nn.ReLU())
]))

除此之外,还可以用 ModuleList和 ModuleDict 来存放子模块,但是用的不多,掌握了上面的内容就足够了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/46955.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Home Assistant在windows环境安装

Home Assistant是什么? Home Assistant 是一个开源的智能家居平台,旨在通过集成各种智能设备和服务,提供一个统一的、可自定义的家庭自动化解决方案。它可以允许用户监控、控制和自动化家中的各种设备,包括灯光、温度、安全系统、…

02-Redis未授权访问漏洞

免责声明 本文仅限于学习讨论与技术知识的分享,不得违反当地国家的法律法规。对于传播、利用文章中提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,本文作者不为此承担任何责任,一旦造成后果请自行承担&…

IDEA快速生成项目树形结构图

下图用的IDEA工具,但我觉得WebStorm 应该也可以 文章目录 进入项目根目录下,进入cmd输入如下指令: 只有文件夹 tree . > list.txt 包括文件夹和文件 tree /f . > list.txt 还可以为相关包路径加上注释

ROS-机械臂——从零构建机器人模型

URDF建模 URDF URDF,全称为 Unified Robot Description Format(统一机器人描述格式),是一种用于描述机器人几何结构和运动学属性的标准文件格式。URDF 文件通常用于机器人模拟、路径规划、控制算法开发和可视化等领域&#xff0c…

React学习笔记03-----手动创建和运行

一、项目创建与运行【手动】 react-scripts集成了webpack、bable、提供测试服务器 1.目录结构 public是静态目录,提供可以供外部直接访问的文件,存放不需要webpack打包的文件,比如静态图片、CSS、JS src存放源码 (1&#xff09…

十大经典排序算法(1)——冒泡排序

一、算法简述 冒泡排序(Bubble Sort)是一种简单直观的暴力枚举式排序算法。它重复地遍历要排序数组,每次比较两个相邻元素,如果顺序错误就把他们交换过来。直到数组已经按照顺序排列,冒泡算法之所以叫做“冒泡”&…

公司想无偿裁员,同事赖着不走

关注卢松松,会经常给你分享一些我的经验和观点。 这招好像也不错! 事情是这样的:某公司准备把成本高的员工都裁掉,主要包含研发部和程序员,总共18个人,准备裁掉10人,因为他们工资开的太高了,…

HTML+CSS+JS井字棋(来自动下棋)

井字棋 自动下棋 玩家先下&#xff0c;计算机后下 源码在图片后面 点赞❤️收藏⭐️关注&#x1f60d; 效果图 源代码 <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <title>Tic Tac Toe Game</tit…

释放DOE的能量,快速确定最佳工艺设置,节省时间、成本和资源

您是否希望降低成本、提高生产效率&#xff0c;并最大限度地减少行业对环境的影响&#xff1f; 所有行业&#xff0c;尤其是钢铁、铝、水泥和石化等能源密集型行业&#xff0c;都面临着应对这些挑战的持续压力。供应链压力、可持续发展、严格的监管环境、日益增长的消费者预期…

【Linux】权限的管理和Linux上的一些工具

文章目录 权限管理chgrpchownumaskfile指令sudo指令 目录权限粘滞位Linux中的工具1.软件包管理器yum2.rzsz Linux开发工具vim 总结 权限管理 chgrp 功能&#xff1a;修改文件或目录的所属组 格式&#xff1a;chgrp [参数] 用户组名 文件名 常用选项&#xff1a;-R 递归修改文…

股指期货与股票抛空机制的区别是什么?

在投资的世界里&#xff0c;有两种看似相似&#xff0c;实则大有不同的玩法&#xff1a;股指期货和股票抛空。让我们用通俗易懂的话来搞清楚这两者的区别。 股票抛空&#xff1a;借来卖出&#xff0c;期待低价买回 想象一下&#xff0c;你看到市场上有只股票&#xff0c;觉得…

基于STM32设计的超声波测距仪(微信小程序)(186)

基于STM32设计的超声波测距仪(微信小程序)(186) 文章目录 一、前言1.1 项目介绍【1】项目功能介绍【2】项目硬件模块组成1.2 设计思路【1】整体设计思路【2】ESP8266工作模式配置1.3 项目开发背景【1】选题的意义【2】可行性分析【3】参考文献1.4 开发工具的选择1.5 系统框架图…

Latte: Latent Diffusion Transformer for Video Generation

文章目录 AbstractIntroductionMethodology潜在扩散模型的初步研究Latte的模型变体Latte的实验验证潜在视频片段的patch embeddingTimestep-class information injectionTemporal positional embedding通过学习策略增强视频生成 Experiments Abstract Latte首先从输入的视频提…

成像光谱遥感技术中的AI革命:ChatGPT

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境&#xff0c;是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型&#xff0c;在理解和生成人类语言方面表现出了非凡的能力&#xff0c;ChatGPT在遥感中的应用&#xff0c;人工智能在…

太速科技-FMC207-基于FMC 两路QSFP+光纤收发子卡

FMC207-基于FMC 两路QSFP光纤收发子卡 一、板卡概述 本卡是一个FPGA夹层卡&#xff08;FMC&#xff09;模块&#xff0c;可提供高达2个QSFP / QSFP 模块接口&#xff0c;直接插入千兆位级收发器&#xff08;MGT&#xff09;的赛灵思FPGA。支持利用Spartan-6、Virtex-6、Kin…

PTA - 接收n个关键字参数

接收n个以关键字形式传入的参数&#xff0c;按格式输出。 函数接口定义&#xff1a; def print_info (**keyargs) 提示&#xff1a;keyargs为可变参数&#xff0c;其可接受若干个关键字形式的实参值&#xff0c;并将接收到的值组装为一个字典。 裁判测试程序样例&#xff1…

数据库:redis练习题

1、安装redis&#xff0c;启动客户端、验证。 redis-server redis-cli 2、string类型数据的命令操作&#xff1a; &#xff08;1&#xff09; 设置键值&#xff1a; set mykey "haha" &#xff08;2&#xff09; 读取键值&#xff1a; get mykey &#xff08;3&…

MSSQL Server运维常用SQL命令

1、数据库连接数 select name, state, state_desc from sys.databases; 查询结果&#xff1a; 2、数据库状态 select name, state, state_desc from sys.databases; 查询结果&#xff1a; 3、数据文件状态 select a.name, b.physical_name, b.state, b.state_desc from sy…

03MFC画笔/画刷/画椭圆/圆/(延时)文字

文章目录 画实心矩形自定义画布设计及使用连续画线及自定义定义变量扇形画椭圆/圆输出颜色文本定时器与定时事件画实心矩形 自定义画布设计及使用 连续画线及自定义定义变量 扇形 画椭圆/圆 输出颜色文本

尚品汇-(十六)

目录 商品详情功能开发 &#xff08;1&#xff09;搭建service-item &#xff08;2&#xff09;获取sku基本信息与图片信息 &#xff08;3&#xff09;获取分类信息&#xff08;查看三级分类&#xff09; 商品详情功能开发 &#xff08;1&#xff09;搭建service-item 点…