SqueezeNet 一维,二维网络复现 pytorch 小白易懂版

SqueezeNet

时隔一年我又开始复现神经网络的经典模型,这次主要复的是轻量级网络全家桶,轻量级神经网络旨在使用更小的参数量,无限的接近大模型的准确率,降低处理时间和运算量,这次要复现的是轻量级网络的非常经典的一个模型SqueezeNet,它由美国加州大学伯克利分校的研究团队开发,并于2016年发布。


文章链接: https://arxiv.org/pdf/1602.07360.pdf?source=post_page---------------------------

看懂这篇文章需要的基础知识

  1. 了解python语法基础
  2. 了解深度学习基本原理
  3. 知道什么是卷积层池化层激活函数层softmanx层
  4. 熟悉卷积层池化层需要的参数
  5. 需要了解pytorch模型的基本构成

我记得去年的这个时候,好像GPT还没被特别广泛的使用,还没到一键就能直接输出写好的模型的这一个步骤,那为什么还要看博客这类的文章呢,应该是因为毕竟GPT他还是靠着已有的资料进行读取,他不能图文并茂的给你写一个一定好用的大型模型,不然直接把论文甩给他让他复现就好了,所以还是打算写一下,然后简单画点图然后给之后的学弟学妹们留一点遗产。

SqueezeNet 的模型结构

下面是原论文给出的模型结构
在这里插入图片描述
原文中给出了三种模型,分别是第一个基础模型,以及第二个和第三个带有残差分支的模型,其中卷积池化分支我们都有了解,这里新的东西就是这个Fire层,那就先从这个Fire层开始介绍

Fire层

作者说他的SqueezeNet网络为什么可以有更小的参数量,主要由于用了下面这个叫Fire层的东西,Fire层分两部分

  • 一部分是Squeeze层其实就是卷积核大小为1×1的一个卷积层
  • 另一部分呢是expend层他实际上是卷积核大小为1×1和卷积核大小为卷积层和3×3输出的一个拼接

下面是原论文中对Fire模型的详细描述
在这里插入图片描述
在这里插入图片描述
那如果要实现一维的那就把3×3的卷积核改成1×3的
加上激活函数,其实现代码应该是这样的,接下来详细介绍里面的参数。

  • in_channels 指Fire模块的输入通道数,也是就每个Fire模块的squeeze卷积层的输入通道数
  • squeeze_channels 指的是squeeze层的输出通道数
  • expand1x1_channels 指的是expand层中卷积核大小为1×1的卷积层的输出通道数
  • expand1x3_channels 指的是expand层中卷积核大小为1×2的卷积层的输出通道数
class FireModule(torch.nn.Module):def __init__(self, in_channels, squeeze_channels, expand1x1_channels, expand1x3_channels):super(FireModule, self).__init__()self.squeeze = torch.nn.Conv1d(in_channels, squeeze_channels, kernel_size=1)self.relu = torch.nn.ReLU(inplace=True)self.expand1x1 = torch.nn.Conv1d(squeeze_channels, expand1x1_channels, kernel_size=1)self.expand1x3 = torch.nn.Conv1d(squeeze_channels, expand1x3_channels, kernel_size=3, padding=1)def forward(self, x):x = self.squeeze(x)x = self.relu(x)out1x1 = self.expand1x1(x)out1x3 = self.expand1x3(x)out = torch.cat([out1x1, out1x3], dim=1)return self.relu(out)

基础知识补充: torch.cat 将向量在某一个维度上拼接

import torch
# Create two tensors
out1x1 = torch.tensor([[1, 2, 3], [1, 2, 3]])
out1x3 = torch.tensor([[4, 5, 6], [7, 8, 9]])# Concatenate the tensors along the second dimension (dim=1)
out = torch.cat([out1x1, out1x3], dim=1)
print(out)
# tensor([[1, 2, 3, 4, 5, 6],
#         [1, 2, 3, 7, 8, 9]])
out = torch.cat([out1x1, out1x3], dim=0)
print(out)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

那有了Fire层模块之后就可以开始搭建我们的模型,那在搭建的过程中,各个层的参数如何设置呢,原文中给了如下表
在这里插入图片描述

  • 第一列Layer name/type 指的是层的名称和类型
  • 第二列Output size 指的是输出尺寸
  • 第三列是filter size/stride (if not a fire layer)滤波器(卷积核/池化核)的大小(不包含Fire层)
  • 第四列depth 卷积层的深度,可以无视掉,没什么用
  • 第五-第七 给的就是Fire 层的参数了

再后面的是稀疏性字节大小还有修剪前后的参数大小,这部分不用过于关注,可能要多提一下的就是这个稀疏性sparsity,他指的是卷积层里选择多少参数一直为0,但是并没有详细说具体是怎么实现的,然后我也去搜了一下,需要用一些正则化的东西才可以,这个问题我打算再详细理解一下,暂时我们都默认稀疏性是100,不再为了稀疏性降低参数量实现额外复杂的工作.

根据参数和结构实现代码

一维

import torch
from torchsummary import summary
class FireModule(torch.nn.Module):def __init__(self, in_channels, squeeze_channels, expand1x1_channels, expand1x3_channels):super(FireModule, self).__init__()self.squeeze = torch.nn.Conv1d(in_channels, squeeze_channels, kernel_size=1)self.relu = torch.nn.ReLU(inplace=True)self.expand1x1 = torch.nn.Conv1d(squeeze_channels, expand1x1_channels, kernel_size=1)self.expand1x3 = torch.nn.Conv1d(squeeze_channels, expand1x3_channels, kernel_size=3, padding=1)def forward(self, x):x = self.squeeze(x)x = self.relu(x)out1x1 = self.expand1x1(x)out1x3 = self.expand1x3(x)out = torch.cat([out1x1, out1x3], dim=1)return self.relu(out)class SqueezeNet(torch.nn.Module):def __init__(self,in_channels,classes):super(SqueezeNet, self).__init__()self.features = torch.nn.Sequential(# conv1torch.nn.Conv1d(in_channels, 96, kernel_size=7, stride=2),torch.nn.ReLU(inplace=True),# maxpool1torch.nn.MaxPool1d(kernel_size=3, stride=2),# Fire2FireModule(96, 16, 64, 64),# Fire3FireModule(128, 16, 64, 64),# Fire4FireModule(128, 32, 128, 128),# maxpool4torch.nn.MaxPool1d(kernel_size=3, stride=2),# Fire5FireModule(256, 32, 128, 128),# Fire6FireModule(256, 48, 192, 192),# Fire7FireModule(384, 48, 192, 192),# Fire8FireModule(384, 64, 256, 256),# maxpool8torch.nn.MaxPool1d(kernel_size=3, stride=2),# Fire9FireModule(512, 64, 256, 256))self.classifier = torch.nn.Sequential(# conv10torch.nn.Conv1d(512, classes, kernel_size=1),torch.nn.ReLU(inplace=True),# avgpool10torch.nn.AdaptiveAvgPool1d((1)))def forward(self, x):x = self.features(x)x = self.classifier(x)x = torch.flatten(x, 1)return xif __name__ == "__main__":# 创建一个SqueezeNet实例model = SqueezeNet(in_channels=3,classes=10)# model = FireModule(96,16,64,64)# 打印模型结构summary(model=model, input_size=(3, 224), device='cpu')

二维

import torch
from torchsummary import summary
class FireModule(torch.nn.Module):def __init__(self, in_channels, squeeze_channels, expand1x1_channels, expand3x3_channels):super(FireModule, self).__init__()self.squeeze = torch.nn.Conv2d(in_channels, squeeze_channels, kernel_size=1)self.relu = torch.nn.ReLU(inplace=True)self.expand1x1 = torch.nn.Conv2d(squeeze_channels, expand1x1_channels, kernel_size=1)self.expand3x3 = torch.nn.Conv2d(squeeze_channels, expand3x3_channels, kernel_size=3, padding=1)def forward(self, x):x = self.squeeze(x)x = self.relu(x)out1x1 = self.expand1x1(x)out3x3 = self.expand3x3(x)out = torch.cat([out1x1, out3x3], dim=1)return self.relu(out)class SqueezeNet(torch.nn.Module):def __init__(self,in_channels,classes):super(SqueezeNet, self).__init__()self.features = torch.nn.Sequential(# conv1torch.nn.Conv2d(in_channels, 96, kernel_size=7, stride=2),torch.nn.ReLU(inplace=True),# maxpool1torch.nn.MaxPool2d(kernel_size=3, stride=2),# Fire2FireModule(96, 16, 64, 64),# Fire3FireModule(128, 16, 64, 64),# Fire4FireModule(128, 32, 128, 128),# maxpool4torch.nn.MaxPool2d(kernel_size=3, stride=2),# Fire5FireModule(256, 32, 128, 128),# Fire6FireModule(256, 48, 192, 192),# Fire7FireModule(384, 48, 192, 192),# Fire8FireModule(384, 64, 256, 256),# maxpool8torch.nn.MaxPool2d(kernel_size=3, stride=2),# Fire9FireModule(512, 64, 256, 256))self.classifier = torch.nn.Sequential(# conv10torch.nn.Conv2d(512, classes, kernel_size=1),torch.nn.ReLU(inplace=True),# avgpool10torch.nn.AdaptiveAvgPool2d((1,1)))def forward(self, x):x = self.features(x)x = self.classifier(x)x = torch.flatten(x, 1)return xif __name__ == "__main__":# 创建一个SqueezeNet实例model = SqueezeNet(in_channels=3,classes=10)# model = FireModule(96,16,64,64)# 打印模型结构summary(model=model, input_size=(3, 224, 224), device='cpu')

结束

对于SqueezeNet的第二个和第三个模型,我先把其他的轻量级网络都复现完之后我再回来写一下,对于入门来说先实现个基础版本就够用了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/116569.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Node.js】暴露自定义响应头和预检请求的时机

1. 暴露自定义响应头 // server.js app.post(/api/user/hello, (req, res) > {res.setHeader(Access-Control-Allow-Origin, *)// 权限设置(如果有个多,用 ,隔开),暴露给前端res.setHeader(Access-Control-expose-…

muduo源码学习base——Atomic(原子操作与原子整数)

Atomic(原子操作与原子整数) 前置知识AtomicIntegerTget()getAndAdd()getAndSet() 关于原子操作实现无锁队列(lock-free-queue) 前置知识 happens-before: 用来描述两个操作的内存可见性 如果操作 X happens-before 操作 Y,那么 X 的结果对于…

React hooks介绍及使用

介绍: React hooks 是 React 16.8 版本引入的新特性,它允许你在无需编写类组件的情况下,能够使用状态和其他 React 特性。它是基于函数组件的,使得函数组件也能够拥有类组件的状态和生命周期等特性,同时减少了处理一些…

有奖招募——2023年度清华社“荐书官”活动今日开始了!

又到“1024程序员节”了,维护网络世界稳定和平的程序员大大们,辛苦了!生活难免有bug,来给彼此个hug~ 过完1024,这一年也快要结束了,岁末回顾又要提上日程。很多人都有整理年度书单的习惯,那么这…

架构风格区别-架构案例(五十九)

管道-过滤器和仓库的区别? 独立的数据仓库,处理流独立,处理数据用连接仓库工具数据与处理在一起,改动的话需要重启系统需要仓库工具与仓库连接,数据与处理分离,性能差可以支持并发连接访问仓库&#xff0c…

隐藏微信网页右上角的按钮、在微信网页中获取用户的网络状态,支付等

1.隐藏微信网页右上角按钮 <script type"text/javascript">document.addEventListener(WeixinJSBridgeReady,function onBridgeReady() {// 通过下面这个API隐藏右上角按钮WeixinJSBridge.call(hideOptionMenu); });document.addEventListener(WeixinJSBridge…

【经验分享】如何构建openGauss开发编译提交一体化环境

前文 本文适合对openGauss源代码有好奇心的爱好者&#xff0c;那么一个友好的openGauss源代码环境应该是怎么样的。openGauss的开发环境是如何设置的&#xff1f;openGauss的编译环境是如何构建的&#xff1f;如何向openGauss提交代码&#xff0c;笔者集合官方和几位博主实践提…

为什么要学习python?

Python是一种广泛使用的编程语言&#xff0c;它的简洁易读以及强大的功能使得它成为了许多人喜爱的编程语言之一。无论是初学者还是有经验的开发者&#xff0c;学习Python都是非常有价值的。在本篇博文中&#xff0c;我们将探讨学习Python的一些重要原因&#xff0c;并提供一些…

2023年【危险化学品生产单位主要负责人】考试报名及危险化学品生产单位主要负责人模拟考试题

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 危险化学品生产单位主要负责人考试报名考前必练&#xff01;安全生产模拟考试一点通每个月更新危险化学品生产单位主要负责人模拟考试题题目及答案&#xff01;多做几遍&#xff0c;其实通过危险化学品生产单位主要负…

cpp中this和*this区别

大家好&#xff0c;我叫徐锦桐&#xff0c;个人博客地址为www.xujintong.com。平时记录一下学习计算机过程中获取的知识&#xff0c;还有日常折腾的经验&#xff0c;欢迎大家访问。 this&#xff1a;是返回当前对象的地址&#xff08;指向当前对象的指针&#xff09;。 *this&a…

用nodejs爬虫台湾痞客邦相册

情景:是这样的,我想保存一些喜欢的小伙伴的照片,一张张保存太慢了, 所以我写了个js,放在国外服务器爬,国内的自己解决~ 使用方法 1.点相册随便一张, 复制url, 这张开始接下来的图片都会保存 /*** 2023年10月23日 22:58:44* 支持解析痞客邦相册* 只需要复制相册第一张图片的ur…

新款模块上线实现SIP模块与扩拨电话之间打点与喊话功能 IP矿用电话模块SV-2800VP

新款模块上线实现SIP模块与扩拨电话之间打点与喊话功能 IP矿用电话模块SV-2800VP 一、简介 SV-2800VP系列模块是我司设计研发的一款用于井下的矿用IP音频传输模块&#xff0c;可用此模块打造一套低延迟、高效率、高灵活和多扩展的IP矿用广播对讲系统&#xff0c;亦可对传统煤…

嵌入式实时操作系统的设计与开发 (启动内核学习)

RTOS的引导模式 RTOS的引导是指将操作系统装入内存并开始执行的过程。 时间限制主要包括&#xff1a;系统要求快速启动和系统启动后要求程序能实时运行。 空间限制主要包括&#xff1a;Flash等非易失性存储空间限制和RAM等易失性存储空间限制。 通常不可能同时满足两种要求&a…

Linux 下安装配置部署MySql8.0

一 . 准备工作 MySQL安装包&#xff1a;在官网下载需要的版本&#xff0c;这里我用的版本是 MySQL 8.0.34 https://dev.mysql.com/downloads/mysql/ 本次linux机器使用的是阿里云ECS实例 二 . 开始部署 1. 将安装包上传至服务器 解压到当前文件夹 tar -zxvf mysql-8.0.34…

SAP HANA Time Zone设置

通常对于MINICHECK中检查出来的Timezone时区设置问题&#xff0c;可以通过以下方式进行修改 对于ABAP系统 修改HANA 参数即可 • indexserver.ini -> [global] -> timezone_default_data_client_name 000 • indexserver.ini -> [global] -> timezone_default_da…

蓝桥杯每日一题2023.10.21

后缀表达式 - 蓝桥云课 (lanqiao.cn) 题目描述 题目分析 30分解法&#xff1a;要求出最大的结果就需要加的数越大&#xff0c;减的数越小&#xff0c;以此为思路简单列举即可 #include<bits/stdc.h> using namespace std; typedef long long ll; const int N 2e5 10…

微信小程序设计之主体文件app-json-pages

一、新建一个项目 首先&#xff0c;下载微信小程序开发工具&#xff0c;具体下载方式可以参考文章《微信小程序开发者工具下载》。 然后&#xff0c;注册小程序账号&#xff0c;具体注册方法&#xff0c;可以参考文章《微信小程序个人账号申请和配置详细教程》。 在得到了测…

【2021集创赛】Digilent杯二等奖:基于FPGA的动态视觉感知融合的运动目标检测系统

杯赛题目&#xff1a;Diligent杯&#xff1a;基于FPGA开源软核的硬件加速智能平台 参赛组别&#xff1a;A组 设计任务&#xff1a; 利用业界主流软核处理器(仅限于Cortex-M系列及 RISC-V系列)在限定的DIGILENT官方FPGA平台上构建SoC片上系统&#xff0c;在 SoC中添加面向智能应…

Python数据挖掘 | 升级版自动查核酸

&#x1f4d5;作者简介&#xff1a;热爱跑步的恒川&#xff0c;致力于C/C、Java、Python等多编程语言&#xff0c;热爱跑步&#xff0c;喜爱音乐的一位博主。 &#x1f4d7;本文收录于恒川的日常汇报系列&#xff0c;大家有兴趣的可以看一看 &#x1f4d8;相关专栏C语言初阶、C…

GoLong的学习之路(五)语法之数组

书接上回&#xff0c;上回书说到&#xff0c;循环语句&#xff0c;在go中循环语句的少了whlie这个关键词&#xff0c;但是与之for可以改这个改这个特点。并且在终止关键词中&#xff0c;又有标签可以方便&#xff0c;停止。这次说数组 文章目录 Array(数组)数组的初始化方法一方…