分类神经网络2:ResNet模型复现

目录

ResNet网络架构

ResNet部分实现代码


ResNet网络架构

论文原址:https://arxiv.org/pdf/1512.03385.pdf

残差神经网络(ResNet)是由微软研究院的何恺明、张祥雨、任少卿、孙剑等人提出的,通过引入残差学习解决了深度网络训练中的退化问题,同时该网络结构特点主要表现为3点,超深的网络结构(超过1000层)、提出residual(残差结构)模块和使用Batch Normalization 加速训练(丢弃dropout)。ResNet网络的模型结构如下:

ResNet网络通过“跳跃连接”,将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。即神经网络学习到该层的参数是冗余的时候,它可以选择直接走这条“跳接”曲线(shortcut connection),跳过这个冗余层,而不需要再去拟合参数使得H(x)=F(x)=x。同时通过这种连接方式不仅保护了信息的完整性(避免卷积层堆叠存在的信息丢失),整个网络也只需要学习输入、输出差别的部分,这克服了由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题(梯度消失或梯度爆炸)。

残差模块如下图示:

残差结构有两种,常规残差和瓶颈残差常规残差:由2个3x3卷积层堆叠而成,当输入和输出维度一致时,可以直接将输入加到输出上,这相当于简单执行了同等映射,不会产生额外的参数,也不会增加计算复杂度(随着网络深度的加深,这种残差模块在实践中并不十分有效);瓶颈残差:依次由1x1 、3x3 、1x1个卷积层构成,这里1x1卷积,能够对通道数channel起到升维或降维的作用,从而令3x3 的卷积,以相对较低维度的输入进行卷积运算,提高计算效率。

ResNet网络的具体配置信息如下:

在构建神经网络时,首先采用了步长为2的卷积层进行图像尺寸缩减,即下采样操作,紧接着是多个残差结构,在网络架构的末端,引入了一个全局平均池化层,用于整合特征信息,最后是一个包含1000个类别的全连接层,并在该层后应用了softmax激活函数以进行多分类任务。值得注意的是,通过引入残差连接模块,其最深的网络结构达到了152层,同时在50层后均使用的是瓶颈残差结构。

ResNet部分实现代码

老样子,直接上代码,建议大家阅读代码时结合网络结构理解

import torch
import torch.nn as nn__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']class ConvBNReLU(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(ConvBNReLU, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass ConvBN(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(ConvBN, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)def forward(self, x):x = self.conv(x)x = self.bn(x)return xdef conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1):"""3x3 convolution with padding:捕捉局部特征和空间相关性,学习更复杂的特征和抽象表示"""return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_channels, out_channels, stride=1):"""1x1 convolution:实现降维或升维,调整通道数和执行通道间的线性变换"""return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(BasicBlock, self).__init__()self.convbnrelu1 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride)self.convbn1 = ConvBN(out_channels, out_channels, kernel_size=3)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = strideself.conv_down = nn.Sequential(conv1x1(in_channels, out_channels * self.expansion, self.stride),nn.BatchNorm2d(out_channels * self.expansion),)def forward(self, x):residual = xout = self.convbnrelu1(x)out = self.convbn1(out)if self.downsample:residual = self.conv_down(x)out += residualout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(Bottleneck, self).__init__()groups = 1base_width = 64dilation = 1width = int(out_channels * (base_width / 64.)) * groups   # wide = out_channelsself.conv1 = conv1x1(in_channels, width)       # 降维通道数self.bn1 = nn.BatchNorm2d(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = nn.BatchNorm2d(width)self.conv3 = conv1x1(width, out_channels * self.expansion)   # 升维通道数self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = strideself.conv_down = nn.Sequential(conv1x1(in_channels, out_channels * self.expansion, self.stride),nn.BatchNorm2d(out_channels * self.expansion),)def forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample:residual = self.conv_down(x)out += residualout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64):super(ResNet, self).__init__()self.inplanes = 64self.dilation = 1replace_stride_with_dilation = [False, False, False]self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2,dilate=replace_stride_with_dilation[0])self.layer3 = self._make_layer(block, 256, layers[2], stride=2,dilate=replace_stride_with_dilation[1])self.layer4 = self._make_layer(block, 512, layers[3], stride=2,dilate=replace_stride_with_dilation[2])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False):downsample = Falseif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = Truelayers = nn.ModuleList()layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes))return layersdef forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)for layer in self.layer1:x = layer(x)for layer in self.layer2:x = layer(x)for layer in self.layer3:x = layer(x)for layer in self.layer4:x = layer(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet18(num_classes, **kwargs):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)def resnet34(num_classes, **kwargs):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)def resnet50(num_classes, **kwargs):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)def resnet101(num_classes, **kwargs):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)def resnet152(num_classes, **kwargs):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)if __name__=="__main__":import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = resnet50(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Total params: 134,285,380

希望对大家能够有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/1510.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IO实现方式(同步阻塞、同步非阻塞、IO多路复用)

1. 同步阻塞IO 同步阻塞io在数据在数据拷贝到两个阶段都是阻塞的,即把socket的数据拷贝到内核缓冲区和把内核缓冲区的数据拷贝到用户态到应用程序缓冲区都是阻塞的。用户线程在这个期间不能处理其他任务。 优点:简单易用 缺点:为每一次io请…

VScode远程连接虚拟机提示: 无法建立连接:XHR failed.问题解决方案

一问题描述 在vscode下载插件Remote-SSH远程连接虚拟机时提示无法建立连接 二.最大嫌疑原因: 我也是在网上找了许久,发现就是网络原因,具体不知,明明访问别的网页没问题,就是连不上,然后发现下载vscode的…

【电赛】自制模块2——偏置变幅器

一、理论基础 模电学习笔记——集成运算放大器https://mp.csdn.net/mp_blog/creation/editor/134449862 运放单双电源转换/运放单双电源供电详解https://mp.csdn.net/mp_blog/creation/editor/135884117 通过改变R4或R5的阻值改变正弦波的振幅。 根据公式 A表示放大倍数。 …

MAC上如何将某个目录制作成iso格式磁盘文件,iso文件本质是什么?以及挂载到ParallelDesktop中?(hdiutil makehybrid )

背景 ParallelsDesktop没有安装ParallelsTools的无法共享目录,可以通过ParallelsDesktop提供CD磁盘的方式共享进去 命令 # 准备文档 mkdir mytestdir cp xxx mytestdir# 生成iso hdiutil makehybrid -o output.iso mytestdir -iso -joliethdiutil是MAC提供的磁盘…

Oracle中的视图

1- 什么是视图 视图是一个虚拟表 视图是由sql查询语句产生的 视图真实存在 但是不存储数据 视图中的数据 只是对 基表(源数据表) 中的数据的引用 总的来说 视图可以简化数据 用户,订单,物流 三个表进行关联 吧很复杂的sql查询语句存储成一个视图 …

Jmeter-非GUI模式下运行jmeter脚本-适用于服务器上持续集成测试

背景 大部分Jmeter脚本都是部署在Linux上运行,利用Jenkins做接口自动化,定时巡检任务。 执行命令 1.进入jmeter的目录,bin文件夹 cd C:\path\to\jmeter\bin2.运行脚本文件 jmeter -n -t D:\{脚本文件目录}\xxx.jmx -l D:\{脚本文件目录}…

信息系统项目管理师0061:架构设计(5信息系统工程—5.1软件工程—5.1.1架构设计)

第五章 信息系统工程 信息系统工程是用系统工程的原理、方法来指导信息系统建设与管理的一门工程技术学科,它是信息科学、管理科学、系统科学、计算机科学与通信技术相结合的综合性、交叉性、具有独特风格的应用学科。当前信息系统工程的主要任务是研究信息处理过程内在的规律…

Java中的BIO、NIO与AIO

1.概述 I/O 模型简单的理解:就是用什么样的通道进行数据的发送和接收,很大程度上决定了程序通信的性能。Java 共支持 3 种网络编程模型 I/O 模式:BIO、NIO、AIO。 2.Java BIO Java BIO(Blocking I/O):是传统的java io 编程&#…

密钥密码学(二)

原文:annas-archive.org/md5/b5abcf9a07e32fc6f42b907f001224a1 译者:飞龙 协议:CC BY-NC-SA 4.0 第十章:可变长度分数化 本章涵盖 基于摩尔斯电码的密码 混合字母和双字母 可变长度二进制码字 基于文本压缩的密码 本章涵盖…

【嵌入式】keil5安装(同时兼容C51和STM32)

最近在开发STM32的时候,安装Keil5,遇到STM32和C51的共存的问题,在网上找了很多方法,又遇到一些bug,最终还是弄好了。因此将处理的过程记录下来,希望对遇到相同问题的朋友一些启发。 1、下载安装包 Keil P…

新牛市新方向:探索加密货币生态的未来

序章:牛市来袭,新的探索 新的牛市来临,带来了加密货币世界的一次次惊喜。比特币、以太坊、Solana等生态系统在这场盛宴中展现出各自的独特魅力,带来了一场场引人入胜的探索之旅。让我们跟随着这些生态系统的脚步,一起…

基础算法前缀和与差分

前言 本次博客会介绍一维和二维的前缀和,以及一维二维差分的基本使用,尽量画图,多使用配合文字 使大家理解,希望有所帮助吧 一维前缀和 问题描述 这里有一个长度为n的数组,我们要算出【2,5】区间的元素和 暴力思…

Mogdb 5.0新特性:SQL PATCH绑定执行计划

前言 熟悉Oracle的dba都知道,生产系统出现性能问题时,往往是SQL走错了执行计划,紧急情况下,无法及时修改应用代码,dba可以采用多种方式针对于某类SQL进行执行计划绑定,比如SQL Profile、SPM、SQL Plan Base…

Linux——网络管理nmcli

nmcli 不能独立使用,需要对应的服务启动 1. NetworkManager.service 2. 网络配置和服务不相关 3. 通过 nmcl i 建立网络配置和网卡之前的映射关系 网卡 简称:nmcli d DEVICE :物理设备 TYPE: 物理设备类型 ethernet 以太网…

C++设计模式:适配器模式(十四)

1、定义与动机 定义:将一个类的接口转换成客户希望的另外一个接口。Adapter模式使得原本由于接口不兼容而不能一起工作的哪些类可以一起工作。 动机: 在软件系统中,由于应用环境的变化,常常需要将“一些现存的对象”放在新的环境…

强固型工业电脑在码头智能闸口、OCR(箱号识别)、集装箱卡车车载电脑行业应用

集装箱卡车车载电脑应用 背景介绍 针对码头集装箱卡车的调度运用, 结合码头TOS系统设计出了各种平台的车载电脑(VT系列)和车载LED显示屏(VLD系列),同时提供各种安装支架,把车载电脑固定到狭小的驾驶室中;同时提供了各种天线选择(…

【JVM常见问题总结】

文章目录 jvm介绍jvm内存模型jvm内存分配参数jvm堆中存储对象:对象在堆中创建分配内存过程 jvm 堆垃圾收集器垃圾回收算法标记阶段引用计数算法可达性分析算法 清除阶段标记清除算法复制算法标记压缩算法 实际jvm参数实战jvm调优jvm常用命令常用工具 jvm介绍 Java虚…

高速公路交通运输大数据平台解决方案

前言 交通运输行业面临着多重挑战。其管控困难,涉及广泛地理范围,导致监控成本高且难以及时响应;同时,行业内数据量大,地理信息数据繁多,缺乏高效的可视化工具来揭示数据规律并优化业务;货运和…

回溯算法-组合问题

回溯算法-组合问题 77. 组合 问题描述 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,4],[2,3],[1,2],[1,3],[1,4], ]示例 2&a…

05集合-CollectionListSet

Collection体系的特点、使用场景总结 如果希望元素可以重复,又有索引,索引查询要快? 用ArrayList集合, 基于数组的。(用的最多) 如果希望元素可以重复,又有索引,增删首尾操作快? 用LinkedList集合, 基于链表的。 如果希望增…