Resnet模型详解

1、Resnet是什么?

Resnet是一种深度神经网络架构,被广泛用于计算机视觉任务,特别是图像分类。它是由微软研究院的研究员于2015年提出的,是深度学习领域的重要里程碑之一。

2、网络退化问题

理论上来讲,随着网络的层数的增加,网络能够进行更加复杂的特征提取,可以取得更好的结果。但是实验发现深度网络出现了退化问题,如下图所示。网络深度增加时,网络准确度出现饱和,之后甚至还快速下降。而且这种下降不是因为过拟合引起的,而是因为在适当的深度模型上添加更多的层会导致了更高的训练误差,从而使其下降。

图1 网络深度对比(来源:Resnet的论文)

当你使用深度神经网络进行训练时,网络层可以被看作是一系列的函数堆叠,每个函数代表一个网络层的操作,这里我们就记作f(x)。在反向传播过程中,梯度是通过链式法则逐层计算得出的。假设每个操作的梯度都小于1,因为多个小于1的数相乘可能会导致结果变得更小。在神经网络中,随着反向传播的逐层传递,梯度可能会逐渐变得非常小,甚至接近于零,这就是梯度消失问题。

而如果经过网络层操作后的输出值大于1,那么反向传播时梯度可能会相应地增大。这种情况下,梯度爆炸问题可能会出现。梯度爆炸问题指的是在深度神经网络中,梯度逐渐放大,导致底层网络的参数更新过大,甚至可能导致数值溢出。

3、残差结构

在ResNet提出之前,所有的神经网络都是通过卷积层和池化层的叠加组成的。所以,Resnet对后面计算机视觉的发展影响是巨大的。

图2 残差结构(来源:Resnet的论文)

它这里完成的一个很简单的过程,我先举一个例子:

想象一张经过神经网络处理后的低分辨率图像。为了提高图像的质量,我们引入了一个创新的思想:将原始高分辨率图像与低分辨率图像之间的差异提取出来,形成了一个残差图像。这个残差图像代表了低分辨率图像与目标高分辨率图像之间的差异或缺失的细节。

​图3 残差图像

 然后,我们将这个残差图像与低分辨率图像相加,得到一个结合了低分辨率信息和残差细节的新图像。这个新图像作为下一个神经网络层的输入,使网络能够同时利用原始低分辨率信息和残差细节信息进行更精确的学习。

图4 残差+低分辨率图像

通过这种方式,我们的神经网络能够逐步地从低分辨率图像中提取信息,并通过残差图像的相加操作将遗漏的细节加回来。这使得网络能够更有效地进行图像恢复或其他任务,提高了模型的性能和准确性。

我相信我已经成功表达了残差结构的思想和操作过程。其实这个思想也并非是resnet创新的,在我们过去的其他领域中早已有这种思想,ResNet将这一思想引入了计算机视觉领域,并在深度神经网络中的训练中取得了重要突破。这种创新在一定程度上解决了深层神经网络训练中的梯度消失和梯度爆炸问题,使得网络能够更深更准确地学习特征和表示。

4、Resnet网络结构

(1)对于相同的输出特征图尺寸,层具有相同数量的滤波器

(2)当feature map大小降低一半时,feature map的数量增加一倍【过滤器(可以看作是卷积核的集合)的数量增加一倍】,这保持了网络层的复杂度。然后通过步长为2的卷积层直接执行下采样。

网络结构具体如下图所示:

8b612e130e2ad5ac04f7bc4f8e1ed7dd.png

图5 左为VGG-19,中为34个参数层的简单网络,右为34个参数层的残差网络

ResNet相比普通网络每两层间增加了短路机制,这就形成了残差学习,其中实线表示快捷连接,虚线表示feature map数量发生了改变。

有两种情况需要考虑:

(1)输入和输出具有相同的维度时(对应实线部分):

直接使用恒等快捷连接

12534cd86e604f1782361149f07c5eda.png

(2)维度增加(当快捷连接跨越两种尺寸的特征图时,它们执行时步长为2):

①快捷连接仍然执行恒等映射,额外填充零输入以增加维度。这样就不会引入额外的参数。

②用下面公式的投影快捷连接用于匹配维度(由1×1卷积完成)

c44e781d717d4b2b8686c2bf4129d23e.png

论文中也提供了更详细的结构,如下图所示:

5、使用Pytorch实现Resnet

本来是按照论文手写的代码,但用的时候发现维度不匹配,用不了预训练权重,所以这里就照着pytorch源码进行了修改。

import torch
import torchvision
import torch.nn as nn
import torchsummary
from torch.hub import load_state_dict_from_urlmodel_urls = {"resnet18" : "https://download.pytorch.org/models/resnet18-f37072fd.pth","resnet34" : "https://download.pytorch.org/models/resnet34-b627a593.pth","resnet50" : "https://download.pytorch.org/models/resnet50-0676ba61.pth","resnet101" : "https://download.pytorch.org/models/resnet101-63fe2227.pth","resnet152" : "https://download.pytorch.org/models/resnet152-394f9c45.pth",
}cfgs = {"resnet18": [2, 2, 2, 2],"resnet34": [3, 4, 6, 3],"resnet50": [3, 4, 6, 3],"resnet101": [3, 4, 23, 3],"resnet152": [3, 8, 36, 3],
}def conv1x1(in_planes, out_planes, stride = 1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=(1,1), stride=(stride,stride), bias=False)
def conv3x3(in_planes, out_planes, stride= 1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes,out_planes,kernel_size=(3,3),stride=(stride,stride),padding=dilation,groups=groups,bias=False,dilation=(dilation,dilation))
class BasicBlock(nn.Module):expansion = 1def __init__(self,inplanes: int,planes,stride = 1,downsample = None,groups =1,base_width = 64,dilation= 1,norm_layer= None,):super(BasicBlock,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dif groups != 1 or base_width != 64:raise ValueError("BasicBlock only supports groups=1 and base_width=64")if dilation > 1:raise NotImplementedError("Dilation > 1 not supported in BasicBlock")self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self,inplanes,planes,stride = 1,downsample = None,groups = 1,base_width = 64,dilation = 1,norm_layer = None,):super(Bottleneck,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.0)) * groupsself.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,layers,in_channels=3,num_classes = 1000,zero_init_residual = False,groups = 1,width_per_group = 64,replace_stride_with_dilation = None,):super(ResNet,self).__init__()norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.num=num_classesself.inplanes = 64self.dilation = 1self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=(7,7), stride=(2,2), padding=3, bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=False)self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=False)self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=False)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, self.num)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck) and m.bn3.weight is not None:nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]elif isinstance(m, BasicBlock) and m.bn2.weight is not None:nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]def _make_layer(self,block, planes, blocks, stride = 1,dilate = False,):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes,planes,groups=self.groups,base_width=self.base_width,dilation=self.dilation,norm_layer=norm_layer,))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet(in_channels, num_classes, mode='resnet50', pretrained=False):if mode == "resnet18" or mode == "resnet34":block = BasicBlockelse:block = Bottleneckmodel = ResNet(block, cfgs[mode], in_channels=in_channels, num_classes=num_classes)if pretrained:state_dict = load_state_dict_from_url(model_urls[mode], model_dir='./model', progress=True)  # 预训练模型地址if num_classes != 1000:num_new_classes = num_classesfc_weight = state_dict['fc.weight']fc_bias = state_dict['fc.bias']fc_weight_new = fc_weight[:num_new_classes, :]fc_bias_new = fc_bias[:num_new_classes]state_dict['fc.weight'] = fc_weight_newstate_dict['fc.bias'] = fc_bias_newmodel.load_state_dict(state_dict)return model

这种写法是按照先前VGG那样写的,这样有助于使用同一个model_urls和cfgs。

BasicBlock类中的init()函数是先定义网络架构,forward()的函数是前向传播,实现的功能就是残差块:

Bottleneck类是另一种blcok类型,同上,init()函数是预定义网络架构,forward函数是进行前向传播。该block中有三个卷积,分别是1x1,3x3,1x1,分别完成的功能就是维度压缩,卷积,恢复维度,所以bottleneck实现的功能就是对通道数进行压缩,再放大。

注意:这里的plane不再是输出的通道数,输出通道数应该就是plane*expansion,即4*plane。

在ss

  • resnet18: BasicBlock, [2, 2, 2, 2]
  • resnet34: BasicBlock, [3, 4, 6, 3]
  • resnet50: Bottleneck, [3, 4, 6, 3]
  • resnet101:  Bottleneck, [3, 4, 23, 3]
  • resnet152:  Bottleneck, [3, 8, 36, 3]

这个后面的结构是作者自己挑的一个参数,所以不用管它为什么。BasicBlock主要用于resnet18和34,Bottleneck用于resnet50,101和152。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/50885.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot的社区生活缴费系统/基于javaweb的水电缴费系统

摘 要 网络的广泛应用给生活带来了十分的便利。所以把社区生活缴费管理与现在网络相结合,利用java语言建设社区生活缴费系统,实现社区生活缴费管理的信息化。则对于进一步提高社区生活缴费管理发展,丰富社区生活缴费管理经验能起到不少的促进…

使用Python搭建服务器公网展示本地电脑文件

文章目录 1.前言2.本地http服务器搭建2.1.Python的安装和设置2.2.Python服务器设置和测试 3.cpolar的安装和注册3.1 Cpolar云端设置3.2 Cpolar本地设置 4.公网访问测试5.结语 1.前言 Python作为热度比较高的编程语言,其语法简单且语句清晰,而且python有…

UE4/5Niagara粒子特效之Niagara_Particles官方案例:1.5->2.3

目录 之前的文章: 1.5 Blend Attributes by Value 发射器更新 粒子生成 粒子更新 2.1 Static Beams ​编辑 发射器更新: 粒子生成 粒子更新 2.2 Dynamic Beams 没有开始模拟前的效果是: 开始模拟后的效果是: 发射器更新 …

python模拟登入某平台+破解验证码

概述 python模拟登录平台,遇见验证码识别!用最简单的方法seleniumda破解验证码,来自动登录平台 详细 python用seleniumxpath模拟登录破解验证码 先随便找个小说平台用户登陆 - 书海小说网用户登陆 - 书海小说网用户登陆 - 书海小说网 准…

【网络安全】跨站脚本(xss)攻击

跨站点脚本(也称为 XSS)是一种 Web 安全漏洞,允许攻击者破坏用户与易受攻击的应用程序的交互。它允许攻击者绕过同源策略,该策略旨在将不同的网站彼此隔离。跨站点脚本漏洞通常允许攻击者伪装成受害者用户,执行用户能够…

七夕学算法

目录 P1031 [NOIP2002 提高组] 均分纸牌 原题链接 : 题面 : 思路 : 代码 : P1036 [NOIP2002 普及组] 选数 原题链接 : 题面 : 思路 : 代码 : P1060 [NOIP2006 普及组] 开心的金明 原题链接 : 题面 : 思路 : 01背包例题 : 代码 : P1100 高低位交换 原题…

Istio入门体验系列——基于Istio的灰度发布实践

导言:灰度发布是指在项目迭代的过程中用平滑过渡的方式进行发布。灰度发布可以保证整体系统的稳定性,在初始发布的时候就可以发现、调整问题,以保证其影响度。作为Istio体验系列的第一站,本文基于Istio的流量治理机制,…

如何批量加密PDF文件并设置不同密码 - 批量PDF加密工具使用教程

如果你正在寻找一种方法来批量加密和保护你的PDF文件,批量PDF加密工具是一个不错的选择。 它是一个体积小巧但功能强大的Windows工具软件,能够批量给多个PDF文件加密和限制,包括设置打印限制、禁止文字复制,并增加独立的打开密码。…

从零开始的Hadoop学习(一) | 大数据概念、特点、应用场景、发展前景

1. 大数据概念 大数据(Big Data):指 无法在一定时间范围 内用常规软件工具进行捕捉、管理和处理的数据集合,是需要新处理模式才能具有更强的决策力、洞察发现力和流程优化能力的 海量、高增长率和多样化 的 信息资产。 大数据主要解决,海量…

【业务功能篇73】分布式ID解决方案

业界实现方案 1. 基于UUID2. 基于DB数据库多种模式(自增主键、segment)3. 基于Redis4. 基于ZK、ETCD5. 基于SnowFlake6. 美团Leaf(DB-Segment、zkSnowFlake)7. 百度uid-generator() 1.基于UUID生成唯一ID UUID:UUID长度128bit,32个16进制字符,占用存储空…

保护函数返回的利器——Linux Shadow Call Stack

写在前面 提到内核栈溢出的漏洞缓解,许多朋友首先想到的是栈内金丝雀(Stack Canary)。今天向大家介绍一项在近年,于Android设备中新增,且默默生效的安全机制——影子调用栈:SCS(Shadow Call St…

elementUI moment 年月日转时间戳 时间限制

changeStartTime(val){debuggerthis.startT val// this.startTime parseInt(val.split(-).join())this.startTime moment(val).unix() * 1000 //开始时间毫秒if(this.endTime){this.endTime moment(this.endT).unix() * 1000 //结束时间毫秒if(this.startTime - this.endTi…

​山东省图书馆典藏《乡村振兴战略下传统村落文化旅游设计》鲁图中大许少辉博士八一新书

​山东省图书馆《乡村振兴战略下传统村落文化旅游设计》鲁图中大许少辉博士八一新书

Python豆瓣爬虫(最简洁的豆瓣250爬虫,随机选择电影)

案例背景 电影才是世界艺术,所以我一直想看完豆瓣250,那么就重新拾起我的爬虫知识。 以前刚学爬虫那啥也不会,python语法都没弄清楚,现在不一样了,能用最为简洁的代码写出爬虫250的代码。 代码实现 导入包&#xff…

多模态(文本、图片)数据融合模型(含公开数据集、文献及开源代码汇总)

多模态&#xff08;文本、图片&#xff09;数据融合模型&#xff08;含公开数据集、文献及开源代码汇总&#xff09; <center>多模态模型的应用跑代码普遍存在的问题 <center>多模态公开数据集<center>文献及开源代码 多模态模型的应用 多模态模型的应用按照…

单片机 (一) 让LED灯 亮

一&#xff1a;硬件电路图 二&#xff1a;软件代码 #include "reg52.h"#define LED_PORT P2void main() {LED_PORT 0x01; // 0000 0001 D1 是灭的 } #include "reg52.h" 这个头文件的作用&#xff1a;包含52 系列单片机内部所有的功能寄存器 三&#…

使用Termux在安卓手机上搭建Hexo博客网站,并发布到公网访问

文章目录 1. 安装 Hexo2. 安装cpolar内网穿透3. 公网远程访问4. 固定公网地址 Hexo 是一个用 Nodejs 编写的快速、简洁且高效的博客框架。Hexo 使用 Markdown 解析文章&#xff0c;在几秒内&#xff0c;即可利用靓丽的主题生成静态网页。 下面介绍在Termux中安装个人hexo博客并…

缓存穿透、缓存击穿和缓存雪崩

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是爱发博客的嗯哼&#xff0c;爱好Java的小菜鸟 &#x1f525;如果感觉博主的文章还不错的话&#xff0c;请&#x1f44d;三连支持&#x1f44d;一下博主哦 &#x1f4dd;社区论坛&#xff1a;希望大家能加入社区共同进步…

【JVM 内存结构 | 程序计数器】

内存结构 前言简介程序计数器定义作用特点示例应用场景 主页传送门&#xff1a;&#x1f4c0; 传送 前言 Java 虚拟机的内存空间由 堆、栈、方法区、程序计数器和本地方法栈五部分组成。 简介 JVM&#xff08;Java Virtual Machine&#xff09;内存结构包括以下几个部分&#…

和鲸 ModelWhale 与中科可控多款服务器完成适配认证,赋能中国云生态

当前世界正处于新一轮技术革命及传统产业数字化转型的关键期&#xff0c;云计算作为重要的技术底座&#xff0c;其产业发展与产业规模对我国数字经济的高质量运行有着不可取代的推动作用。而随着我国数字上云、企业上云加快进入常规化阶段&#xff0c;云计算承载的业务应用越来…