【图像分类】【深度学习】【Pytorch版本】ResNet模型算法详解

【图像分类】【深度学习】【Pytorch版本】 ResNet模型算法详解

文章目录

  • 【图像分类】【深度学习】【Pytorch版本】 ResNet模型算法详解
  • 前言
  • ResNet讲解
    • Deep residual learning framework(深度残差学习框架)
    • 残差结构(Residuals)
    • ResNet模型结构
  • ResNet Pytorch代码
  • 完整代码
  • 总结


前言

ResNet是微软研究院的He, Kaiming等人在《Deep Residual Learning for Image Recognition【CVPR-2016】》【论文地址】一文中提出的模型,通过残差模块发现并解决深层网络的退化问题,使更深的卷积神经网络的训练成为可能,大大提升神经网络深度。

网络的退化问题:随着网络层数的增加,模型性能反而变得更差的现象。


ResNet讲解

神经网络将一个端到端的多层模型中的低/中/高级特征以及分类器整合起来,特征的等级可以通过所堆叠层的数量(深度)来丰富,模型的深度发挥着至关重要的作用。
训练一个更好的网络是否和堆叠更多的层一样简单呢?解决这一问题的障碍是梯度消失/梯度爆炸 (退化问题),这阻碍了模型的收敛。传统解决办法通过归一初始化(normalized initialization) 和**中间归一化(intermediate normalization)**在很大程度上解决了这一问题,它使得数十层的网络在反向传播的随机梯度下降上能够收敛。
但是当更深的网络能够开始收敛时又暴露了新的退化问题:随着网络深度的增加,准确率达到饱和后迅速下降。这种下降并不是由过拟合引起的,并且在深度适当的模型上添加更多的网络层会导致更高的训练误差。

论文中提到了一种构建神经网络的理论(思路):假设浅层网络已经取得不错的效果,新增加的网络层啥也不干,只是拟合一个恒等映射(identity mapping),就是新增网络层的输出就拟合新增网络层的输入,这样构建的深层网络至少不应该比没有新增网络层的浅层网络的训练误差要高。该理论在当时没能用实验进行证明。

Deep residual learning framework(深度残差学习框架)

依据:一个更深的模型不应当产生比它的浅层版本更高的训练错误率。
为了解决退化问题,作者在该论文中提出了“深度残差学习框架”的网络。在该框架中,每个堆叠网络块(building block) 拟合残差映射(Residual mapping)

相较于残差结构(Plaint net),将以前的结构称之为非残差结构(Residual net)。

网络块由若干卷积层组成,非残差结构直接拟合网络块期望的基础映射(Underlying mapping) H ( X ) H(X) H(X),基础映射即将当前网络块的输入与输出之间的映射。残差结构同样是拟合网络块的基础映射 H ( X ) H(X) H(X),但网络块新增了一条shortcut分支表示恒等映射/投影(identity) X X X,网络块的主分支(堆叠的若干卷积层)拟合的就是残差映射 F ( X ) = H ( X ) − X F(X)=H(X)-X F(X)=H(X)X,没有像非残差结构一样直接对 H ( X ) H(X) H(X)进行拟合。

恒等快捷连接既不增加额外的参数也不增加计算复杂度

在极端情况下,假设 X X X是最优的,不需要后续网络块处理,残差结构将主分支 F ( X ) F(X) F(X)置为零比非残差结构将 H ( X ) H(X) H(X)直接拟合成原始输入的 X X X更容易,即使 F ( X ) F(X) F(X)不可能得到理想的置零,但是接近于零,也足以缓解了退化问题。

论文实验证明,优化残差映射比优化原始的基础映射容易很多,残差结构确实比其他结构更易学习!

残差结构(Residuals)

正是因为提出了残差模块,可以搭建更深的网络,论文中提到了俩种残差结构,左边的图主要是针对层数较少的网络使用的残差结构,右边的图是针对层数较多的网络使用的残差结构。

残差结构在主路上经过一系列的卷积层之后输出的特征图再与输入特征图进行一个相加的操作,主分支与shortcut分支的特征图在相同的维度上做相加的操作,之后通过激活函数输出,这就要求主分支与shortcut的输出特征图形状必须完全一致。

1×1卷积是为了升维和降维

随着网络层深度增加时,输出的特征图的尺寸会减小而通道数会增加,部分残差结构的输入特征图和输出特征图的形状不再保持一致,为了应对这种情况,会在shortcut分支上新增一层1×1卷积层,目的是为了残差结构的主分支输出特征图和shortcut分支输出特征图的形状再次保持一致。

ResNet模型结构

下图是原论文给出的关于ResNet模型结构的详细示意图:
Resnet在图像分类中分为两部分:backbone部分: 主要由残差结构、卷积层和池化层(汇聚层)组成,分类器部分:由全连接层组成 。
论文通过前文提到的两种残差结构的堆叠搭建了5种网络,分别为Resnet-18、Resnet-34、Resnet-50、Resnet-101和Resnet-152。
下图以Resnet-34为例,根据Vgg的启发构建卷积神经网络:

对于输出特征图尺寸相同的层,具有相同数量的卷积核;当特征图尺寸减半则卷积核数量加倍以保持每层网络计算的时间复杂度;主要通过步长为2的卷积层直接执行下采样。

当输入和输出具有相同形状(实线shortcut),可以直接使用恒等快捷连接;当输入和输出具有不同形状时(虚线shortcut),新增一层1×1卷积层用于配准形状。


ResNet Pytorch代码

残差结构BasicBlock(小网络): 卷积层+BN层+激活函数

class BasicBlock(nn.Module):# 小网络的输入输出的channel是保存一致的expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()# 第一层卷积层self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()# 第二层卷积层self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)# 第一层卷积层out = self.conv1(x)out = self.bn1(out)out = self.relu(out)# 第二层卷积层out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return out

残差结构Bottleneck(大网络): 卷积层+BN层+激活函数

class Bottleneck(nn.Module):# 大网络的输出channel是输入channel的4倍expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(Bottleneck, self).__init__()# 第一层卷积层self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=1, stride=1, bias=False)  # 压缩channelself.bn1 = nn.BatchNorm2d(out_channel)# 第二层卷积层self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(out_channel)# 第三层卷积层self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # 扩展channelself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)# 第一层卷积层out = self.conv1(x)out = self.bn1(out)out = self.relu(out)# 第二层卷积层out = self.conv2(out)out = self.bn2(out)out = self.relu(out)# 第三层卷积层out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out

downsample(输入和输出形状不同): 卷积层+BN层

downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))

完整代码

import torch.nn as nn
import torch
from torchsummary import summaryclass BasicBlock(nn.Module):# 小网络的输入输出的channel是保存一致的expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(BasicBlock, self).__init__()# 第一层卷积层self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()# 第二层卷积层self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)# 第一层卷积层out = self.conv1(x)out = self.bn1(out)out = self.relu(out)# 第二层卷积层out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):# 大网络的输出channel是输入channel的4倍expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(Bottleneck, self).__init__()# 第一层卷积层self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=1, stride=1, bias=False)  # 压缩channelself.bn1 = nn.BatchNorm2d(out_channel)# 第二层卷积层self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(out_channel)# 第三层卷积层self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # 扩展channelself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)# 第一层卷积层out = self.conv1(x)out = self.bn1(out)out = self.relu(out)# 第二层卷积层out = self.conv2(out)out = self.bn2(out)out = self.relu(out)# 第三层卷积层out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000):super(ResNet, self).__init__()self.in_channel = 64self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 第一组残差块组self.layer1 = self._make_layer(block, 64, blocks_num[0])# 第二组残差块组self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)# 第三组残差块组self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)# 第四组残差块组self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = None# 残差块组中的第一个残差块'''第一个残差块组是不执行下采样的,但是小网络的输入输出是channel一致,大网络输出是输入channel的四倍因此shortcut是否需要卷积层的判断条件不再是stride,而是channelchannel不一致则是大网络的Bottleneck,第一个残差块需要卷积层调整维度'''if stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride))self.in_channel = channel * block.expansion# 残差块组中的剩余残差块for _ in range(1, block_num):layers.append(block(self.in_channel,channel))return nn.Sequential(*layers)def forward(self, x):# backbone主干网络部分# resnet34为例# N x 3 x 224 x 224x = self.conv1(x)# N x 64 x 112 x 112x = self.bn1(x)# N x 64 x 112 x 112x = self.relu(x)# N x 64 x 112 x 112x = self.maxpool(x)# N x 64 x 56 x 56# 第一组残差块组x = self.layer1(x)# N x 64 x 56 x 56# 第二组残差块组x = self.layer2(x)# N x 128 x 28 x 28# 第三组残差块组x = self.layer3(x)# N x 256 x 14 x 14# 第四组残差块组x = self.layer4(x)# N x 512 x 7 x 7# 分类器部分x = self.avgpool(x)# N x 512 x 1 x 1x = torch.flatten(x, 1)# N x 512x = self.fc(x)# N x 1000return xdef resnet18(num_classes=1000):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)def resnet34(num_classes=1000):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)def resnet50(num_classes=1000):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)def resnet101(num_classes=1000):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)def resnet152(num_classes=1000):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = resnet34().to(device)summary(model, input_size=(3, 224, 224))

summary可以打印网络结构和参数,方便查看搭建好的网络结构。


总结

尽可能简单、详细的介绍了残差结构的原理和在卷积神经网络中的作用,讲解了ResNet模型的结构和pytorch代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/153082.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【练习】检测U盘并自动复制内容到电脑的软件

软件作用: 有U盘插在电脑上后,程序会检测到U盘的路径。 自己可以提前设置一个保存复制文件的路径或者使用为默认保存的复制路径(默认为桌面,可自行修改)。 检测到U盘后程序就会把U盘的文件复制到电脑对应的…

PyTorch微调终极指南1:预训练模型调整

如今,在训练深度学习模型时,通过根据自己的数据微调预训练模型来进行迁移学习(transfer learning)已成为首选方法。 通过微调这些模型,我们可以利用他们的专业知识并使它们适应我们的特定任务,从而节省宝贵…

【miniQMT实盘量化4】获取实时行情数据

前言 上篇,我们介绍了如何获取历史数据,有了历史数据,我们可以进行分析和回测。但,下一步,我们更需要的是实时数据,只有能有效的监控实时行情数据,才能让我们变成市场上的“千里眼,…

从0开始学习JavaScript--深入探究JavaScript类型化数组

JavaScript类型化数组是一种特殊的数组类型,引入了对二进制数据的更底层的操作。这种数组提供了对内存中的二进制数据直接进行读写的能力,为处理图形、音频、视频等大规模数据提供了高效的手段。本文将深入探讨JavaScript类型化数组的基本概念、常见类型…

场景交互与场景漫游-对象选取(8-2)

对象选取示例的代码如程序清单8-11所示: /******************************************* 对象选取示例 *************************************/ // 对象选取事件处理器 class PickHandler :public osgGA::GUIEventHandler { public:PickHandler() :_mx(0.0f), _my…

48. 旋转图像

给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1: 输入:matrix [[1,2,3],[4,5,6],[7,8,9]] 输出&…

用平板当电脑副屏(spacedesk)双端分享

文章目录 1.准备工作2.操作流程1. 打开spacedesk点击2. 勾选USB Cable Android3. 用数据线连接移动端和pc端,选择仅充电4. 打开安装好的spacedesk 记得在win系统中设置扩展显示器: 1.准备工作 下载软件spacedesk Driver Console pc端: 移动…

macos苹果电脑清理软件有哪些?cleanmymac和腾讯柠檬哪个好

MacOS是一款优秀的操作系统,但是随着使用时间的增加,它也会产生一些不必要的垃圾文件,占用磁盘空间和内存资源,影响系统的性能和稳定性。为了保持MacOS的清洁和高效,我们需要使用一些专业的清理软件来定期扫描和清除这…

Pandas数据集的合并与连接merge()方法_Python数据分析与可视化

数据集的合并与连接 merge()解析merge()的主要参数 merge()解析 merge()可根据一个或者多个键将不同的DataFrame连接在一起,类似于SQL数据库中的合并操作。 数据连接的类型 一对一的连接: df1 pd.DataFrame({employee: [Bob, Jake, Lisa, Sue], grou…

【Linux】:体系结构与进程概念

朋友们、伙计们,我们又见面了,本期来给大家解读一下有关Linux体系结构和进程的知识点,如果看完之后对你有一定的启发,那么请留下你的三连,祝大家心想事成! C 语 言 专 栏:C语言:从入…

腾讯云轻量数据库1核1G性能测评、租用费用和详细介绍

腾讯云轻量数据库服务采用腾讯云自研的新一代云原生数据库 TDSQL-C,融合了传统数据库、云计算与新硬件技术的优势,100%兼容 MySQL,实现超百万级 QPS 的高吞吐,128TB 海量分布式智能存储,保障数据安全可靠。腾讯云百科t…

机器人制作开源方案 | 智能照科植物花架

作者:付菲菲、于海鑫、王子敏单位:黑河学院指导老师:索向峰、李岩 1. 概述 1.1设计背景​ 随着时代的发展,城市化脚步加快、城市人口密度越来越大、城市生活节奏快压力大作息难成规律。城市建筑建筑面积迅速增加、而绿…

Leetcode—5.最长回文子串【中等】

2023每日刷题(三十五) Leetcode—5.最长回文子串 中心扩展法算法思想 可以使用一种叫作“中心扩展法”的算法。由回文的性质可以知道,回文一定有一个中心点,从中心点向左和向右所形成的字符序列是一样的,并且如果字符…

Vue移动 HTML 元素到指定位置 teleport 标签

teleport 标签&#xff1a;用于将组件中的 HTML 元素移动到任意的位置。 使用 teleport 标签移动 HTML 元素&#xff1a; <!-- 将 teleport 中的内容移动到 body 标签中 --> <teleport to"body"><div><h3>我是第三层组件的标题</h3>…

如何使用http来获取thingsbord中的设备数据

背景 有个读者问我,他想做tb的二次开发,想要通过一个接口来查询设备的遥测数据。 于是我给他写了这篇文章。 具体实现 由于他使用的是cloud版本,于是我使用cloud来做演示 文档的接口 https://thingsboard.cloud/swagger-ui/#/telemetry-controller/getTimeseriesUsing…

V100 GPU服务器安装CUDA教程

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

Linux C 线程间同步机制

线程间同步机制 概述保护机制互斥锁创建互斥锁  pthread_mutex_init加锁  pthread_mutex_lock解锁  pthread_mutex_unlock删除锁  pthread_mutex_destroy 条件变量创建条件变量  pthread_cond_init激活条件变量  pthread_cond_signal等待条件变量  pthread_cond_…

python实现炫酷的屏幕保护程序

shigen日更文章的博客写手&#xff0c;擅长Java、python、vue、shell等编程语言和各种应用程序、脚本的开发。记录成长&#xff0c;分享认知&#xff0c;留住感动。 上次的文章如何实现一个下班倒计时程序的阅读量很高&#xff0c;觉得也很实用酷炫&#xff0c;下边是昨天的体验…

QT5 MSVC2017 64bit配置OpenCV4.5无需编译与示范程序

环境&#xff1a;Windows 10 64位 Opencv版本&#xff1a;4.5 QT&#xff1a;5.14 QT5 MSVC2017配置OpenCV 版本参考&#xff1a; opencv msvc c对应版本 1.安装MSVC2017&#xff08;vs2017&#xff09; 打开Visual Studio Installer&#xff0c;点击修改 选择vs2017生成工…

java使用 TCP 的 Socket API 实现客户端服务器通信

一&#xff1a;什么是 Socket(套接字) Socket 套接字是由系统提供于网络通信的技术, 是基于 TCP/IP 协议的网络通信的基本操作&#xff0c;要进行网络通信, 需要有一个 socket 对象, 一个 socket 对象对应着一个 socket 文件, 这个文件在 网卡上而不是硬盘上, 所以有了 sokcet…