【李沐论文精读】Resnet精读

论文地址:Deep Residual Learning for Image Recognition

参考:撑起计算机视觉半边天的ResNet【论文精读】、ResNet论文逐段精读【论文精读】、【李沐论文精读系列】

一、导论

深度神经网络的优点:可以加很多层把网络变得特别深,然后不同程度的层会得到不同等级的feature,比如低级的视觉特征或者是高级的语义特征。但是学一个好的网络,就是简简单单的把所有网络堆在一起就行了吗?如果这样,网络做深就行了。

提出问题:随着网络越来越深,梯度就会出现爆炸或者消失

  • 解决办法就是:1、在初始化的时候要做好一点,就是权重在随机初始化的时候,权重不要特别大也不要特别小。2、在中间加入一些normalization,包括BN(batch normalization)可以使得校验每个层之间的那些输出和他的梯度的均值和方差相对来说比较深的网络是可以训练的,避免有一些层特别大,有一些层特别小。使用了这些技术之后是能够训练(能够收敛),虽然现在能够收敛了,但是当网络变深的时候,性能其实是变差的(精度会变差)
  • 文章提出出现精度变差的问题不是因为层数变多了,模型变复杂了导致的过拟合,而是因为训练误差也变高了(overfitting是说训练误差变得很低,但是测试误差变得很高),训练误差和测试误差都变高了,所以他不是overfitting。虽然网络是收敛的,但是好像没有训练出一个好的结果

深入讲述了深度增加了之后精度也会变差

  • 考虑一个比较浅一点的网络和他对应的比较深的版本(在浅的网络中再多加一些层进去),如果钱的网络效果还不错的话,神的网络是不应该变差的:深的网络新加的那些层,总是可以把这些层学习的变成一个identity mapping(输入是x,输出也是x,等价于可以把一些权重学成比如说简单的n分之一,是的输入和输出是一一对应的),但是实际情况是,虽然理论上权重是可以学习成这样,但是实际上做不到:假设让SGD去优化,深层学到一个跟那些浅层网络精度比较好的一样的结果,上面的层变成identity(相对于浅层神经网络,深层神经网络中多加的那些层全部变成identity),这样的话精度不应该会变差,应该是跟浅层神经网络是一样的,但是实际上SGD找不到这种最优解。
  • 这篇文章提出显式地构造出一个identity mapping,使得深层的神经网络不会变的比相对较浅的神经网络更差,它将其称为deep residual learning framework。
  • 要学的东西叫做H(x),假设现在已经有了一个浅的神经网络,他的输出是x,然后要在这个浅的神经网络上面再新加一些层,让它变得更深。新加的那些层不要直接去学H(x),而是应该去学H(x)-xx是原始的浅层神经网络已经学到的一些东西,新加的层不要重新去学习,而是去学习学到的东西和真实的东西之间的残差,最后整个神经网络的输出等价于浅层神经网络的输出x和新加的神经网络学习残差的输出之和,将优化目标从H(x)转变成为了H(x)-x

  • 上图中最下面的红色方框表示所要学习的H(x)
  • 蓝色方框表示原始的浅层神经网络
  • 红色阴影方框表示新加的层
  • o表示最终整个神经网络的输出
  • 这样的好处是:只是加了一个东西进来,没有任何可以学的参数,不会增加任何的模型复杂度,也不会使计算变得更加复杂,而且这个网络跟之前一样,也是可以训练的,没有任何改变。

下面这张图就对应了上一张图的简笔画。

二、related work

残差连接如何处理输入和输出的形状是不同的情况

  • 第一个方案是在输入和输出上分别添加一些额外的0,使得这两个形状能够对应起来然后可以相加
  • 第二个方案是之前提到过的全连接怎么做投影,做到卷积上,是通过一个叫做1*1的卷积层,这个卷积层的特点是在空间维度上不做任何东西,主要是在通道维度上做改变。所以只要选取一个1*1的卷积使得输出通道是输入通道的两倍,这样就能将残差连接的输入和输出进行对比了。在ResNet中,如果把输出通道数翻了两倍,那么输入的高和宽通常都会被减半,所以在做1*1的卷积的时候,同样也会使步幅为2,这样的话使得高宽和通道上都能够匹配上。

implementation中讲了实验的一些细节

  • 把短边随机的采样到256和480(AlexNet是直接将短边变成256,而这里是随机的)。随机放的比较大的好处是做随机切割,切割成224*224的时候,随机性会更多一点
  • 将每一个pixel的均值都减掉了
  • 使用了颜色的增强(AlexNet上用的是PCA,现在我们所使用的是比较简单的RGB上面的,调节各个地方的亮度、饱和度等)
  • 使用了BN(batch normalization)
  • 所有的权重全部是跟另外一个paper中的一样(作者自己的另外一篇文章)。注意写论文的时候,尽量能够让别人不要去查找别的文献就能够知道你所做的事情
  • 批量大小是56,学习率是0.1,然后每一次当错误率比较平的时候除以10
  • 模型训练了60*10^4个批量。建议最好不要写这种iteration,因为他跟批量大小是相关的,如果变了一个批量大小,他就会发生改变,所以现在一般会说迭代了多少遍数据,相对来说稳定一点
  • 这里没有使用dropout,因为没有全连接层,所以dropout没有太大作用
  • 在测试的时候使用了标准的10个crop testing(给定一张测试图片,会在里面随机的或者是按照一定规则的去采样10个图片出来,然后再每个子图上面做预测,最后将结果做平均)。这样的好处是因为训练的时候每次是随机把图片拿出来,测试的时候也大概进行模拟这个过程,另外做10次预测能够降低方差。
  • 采样的时候是在不同的分辨率上去做采样,这样在测试的时候做的工作量比较多,但是在实际过程中使用比较少。

三、实验

 3.1 不同配置的ResNet结构

  • 上表是整个ResNet不同架构之间的构成信息(5个版本)
  • 第一个7*7的卷积是一样的
  • 接下来的pooling层也是一样的
  • 最后的全连接层也是一样的(最后是一个全局的pooling然后再加一个1000的全连接层做输出)
  • 不同的架构之间,主要是中间部分不一样,也就是那些复制的卷积层是不同的
  • conv2.x:x表示里面有很多不同的层(块)
  • 【3*3,64】:46是通道数
  • 模型的结构为什么取成表中的结构,论文中并没有细讲,这些超参数是作者自己调出来的,实际上这些参数可以通过一些网络架构的自动选取
  • flops:整个网络要计算多少个浮点数运算。卷积层的浮点运算等价于输入的高乘以宽乘以通道数乘以输出通道数再乘以核的窗口的高和宽
3.2 残差结构效果对比

  • 上图中比较了18层和34层在有残差连接和没有残差连接的结果
  • 左图中,红色曲线表示34的验证精度(或者说是测试精度)
  • 左图中,粉色曲线表示的是34的训练精度
  • 一开始训练精度是要比测试精度高的,因为在一开始的时候使用了大量的数据增强,使得寻来你误差相对来说是比较大的,而在测试的时候没有做数据增强,噪音比较低,所以一开始的测试误差是比较低的
  • 图中曲线的数值部分是由于学习率的下降,每一次乘以0.1,对整个曲线来说下降就比较明显。为什么现在不使用乘0.1这种方法:在什么时候乘时机不好掌控,如果乘的太早,会后期收敛无力,晚一点乘的话,一开始找的方向更准一点,对后期来说是比较好的
  • 上图主要是想说明在有残差连接的时候,34比28要好;另外对于34来说,有残差连接会好很多;其次,有了残差连接以后,收敛速度会快很多,核心思想是说,在所有的超参数都一定的情况下,有残差的连接收敛会快,而且后期会好。
3.3 残差结构中,输入输出维度不一致如何处理

A. pad补0,使维度一致;
B. 维度不一致的时候,使其映射到统一维度,比如使用全连接或者是CNN中的1×1卷积(输出通道是输入的两倍)。
C. 不管输入输出维度是否一致,都进行投影映射。(就算输入输出的形状是一样的,一样可以在连接的时候做个1*1的卷积,但是输入和输出通道数是一样的,做一次投影)

从上述结果可以看到,B和C效果差不多,都比A好。但是做映射会增加很多复杂度,考虑到ResNet中大部分情况输入输出维度是一样的(也就是4个模块衔接时通道数会变),作者最后采用了方案B

3.4 深层ResNet引入瓶颈结构Bottleneck

在ResNet-50及以上的结构中,模型更深了,可以学习更多的参数,所以通道数也要变大。比如前面模型配置表中,ResNet-50/101/152的第一个残差模块输出都是256维,增加了4倍。

如果残差结构还是和之前一样,计算量就增加的太多了(增加16倍),划不来。所以重新设计了Bottleneck结构,将输入从256维降为64维,然后经过一个3×3卷积,再升维回256维。这样操作之后,复杂度和左侧图是差不多的。这也是为啥ResNet-50对比ResNet-34理论计算量变化不大的原因。(实际上1×1卷积计算效率不高,所以ResNet-50计算还是要贵一些)

3.5 代码实现

resnet中残差块有两种:(use_1x1conv=True/False)

  1. 步幅为2 ,高宽减半,通道数增加。所以shortcut连接部分会加一个1×1卷积层改变通道数
  2. 步幅为1,高宽不变

残差块代码实现

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):  #@savedef __init__(self, input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)#每个bn都有自己的参数要学习,所以需要定义两个def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

四、结论

本身的论文并没有给结论。但是在这里讨论一下为什么ResNet训练起来比较快?

  • 一方面是因为梯度上保持的比较好,新加一些层的话,加的层越多,梯度的乘法就越多,因为梯度比较小,一般是在0附近的高斯分布,所以就会导致在很深的时候就会比较小(梯度消失)。虽然batch normalization或者其他东西能够对这种状况进行改善,但是实际上相对来说还是比较小,但是如果加了一个ResNet的话,它的好处就是在原有的基础上加上了浅层网络的梯度,深层的网络梯度很小没有关系,浅层网络可以进行训练,变成了加法,一个小的数加上一个大的数,相对来说梯度还是会比较大的。也就是说,不管后面新加的层数有多少,前面浅层网络的梯度始终是有用的,这就是从误差反向传播的角度来解释为什么训练的比较快。

  • 在CIFAR上面加到了1000层以上,没有做任何特别的regularization,然后效果很好,overfitting有一点点但是不大。SGD收敛是没有意义的,SGD的收敛就是训练不动了,收敛是最好收敛在比较好的地方。做深的时候,用简单的机器训练根本就跑不动,根本就不会得到比较好的结果,所以只看收敛的话意义不大,但是在加了残差连接的情况下,因为梯度比较大,所以就没那么容易收敛,所以导致一直能够往前(SGD的精髓就是能够一直能跑的动,如果哪一天跑不动了,梯度没了就完了,就会卡在一个地方出不去了,所以它的精髓就在于需要梯度够大,要一直能够跑,因为有噪音的存在,所以慢慢的他总是会收敛的,所以只要保证梯度一直够大,其实到最后的结果就会比较好)

为什么ResNet在CIFAR-10那么小的数据集上他的过拟合不那么明显?

虽然模型很深,参数很多,但是因为模型是这么构造的,所以使得他内在的模型复杂度其实不是很高,也就是说,很有可能加了残差链接之后,使得模型的复杂度降低了,一旦模型的复杂度降低了,其实过拟合就没那么严重了

  • 所谓的模型复杂度降低了不是说不能够表示别的东西了,而是能够找到一个不那么复杂的模型去拟合数据,就如作者所说,不加残差连接的时候,理论上也能够学出一个有一个identity的东西(不要后面的东西),但是实际上做不到,因为没有引导整个网络这么走的话,其实理论上的结果它根本过不去,所以一定是得手动的把这个结果加进去,使得它更容易训练出一个简单的模型来拟合数据的情况下,等价于把模型的复杂度降低了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/718743.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣周赛387

第一题 代码 package Competition.The387Competitioin;public class Demo1 {public static void main(String[] args) {}public int[] resultArray(int[] nums) {int ans[]new int[nums.length];int arr1[]new int[nums.length];int arr2[]new int[nums.length];if(nums.leng…

Linux系统Docker部署RStudio Server

文章目录 前言1. 安装RStudio Server2. 本地访问3. Linux 安装cpolar4. 配置RStudio server公网访问地址5. 公网远程访问RStudio6. 固定RStudio公网地址 前言 RStudio Server 使你能够在 Linux 服务器上运行你所熟悉和喜爱的 RStudio IDE,并通过 Web 浏览器进行访问…

4. 编写app组件

1. 代码 main.ts // 引入createApp用于创建应用 import {createApp} from "vue"// 引入App根组件 import App from ./App.vue createApp(App).mount(#app) App.vue <!-- vue文件可以写三种标签1. template标签&#xff0c;写html结构2. script 脚本标签&…

风险评估是什么意思?与等保测评有什么区别?

最近看到不少小伙伴在问&#xff0c;风险评估是什么意思&#xff1f;与等保测评有什么区别&#xff1f;这里我们就来简单聊聊。 风险评估是什么意思&#xff1f; 风险评估是指对某个特定领域或项目进行全面分析和评估&#xff0c;以确定可能存在的潜在风险和危害&#xff0c;并…

2023全球软件开发大会-上海站:探索技术前沿,共筑未来软件生态(附大会核心PPT下载)

随着信息技术的迅猛发展&#xff0c;全球软件开发大会&#xff08;QCon&#xff09;已成为软件行业最具影响力的年度盛会之一。2023年&#xff0c;QCon再次来到上海&#xff0c;汇聚了众多业界精英、技术领袖和开发者&#xff0c;共同探讨软件开发的最新趋势和实践。 一、大会…

服务器感染了.ma1x0勒索病毒,如何确保数据文件完整恢复?

引言&#xff1a; 网络安全成为至关重要的议题。.ma1x0勒索病毒是当前网络威胁中的一种恶意软件&#xff0c;它的出现给用户带来了极大的困扰。然而&#xff0c;正如任何挑战一样&#xff0c;我们也有方法来面对并克服.ma1x0勒索病毒。本文将全面介绍这种病毒的特点&#xff0…

MB85RC铁电 FRAM驱动(全志平台linux)

测试几天发现一个bug&#xff0c;就是无法一次读取32个字节的数据&#xff0c;1-31,33,128,512都试过了&#xff0c;唯独无法读取32个字节&#xff0c;驱动未报错&#xff0c;但是读取的都是0&#xff0c;找不到原因&#xff0c;估计应该是全志iic驱动的问题&#xff0c;暂时没…

leetcode - 2095. Delete the Middle Node of a Linked List

Description You are given the head of a linked list. Delete the middle node, and return the head of the modified linked list. The middle node of a linked list of size n is the ⌊n / 2⌋th node from the start using 0-based indexing, where ⌊x⌋ denotes th…

python中的类与对象(3)

目录 一. 类的多继承 二. 类的封装 三. 类的多态 四. 类与对象综合练习&#xff1a;校园管理系统 一. 类的多继承 在&#xff08;2&#xff09;第四节中我们介绍了什么是类的继承&#xff0c;在子类的括号里面写入要继承的父类名。上一节我们只在括号内写了一个父类名&…

【机器人最短路径规划问题(栅格地图)】基于模拟退火算法求解

代码获取方式&#xff1a;QQ&#xff1a;491052175 或者 私聊博主获取 基于模拟退火算法求解机器人最短路径规划问题&#xff08;栅格地图&#xff09;的仿真结果 仿真结果&#xff1a; 初始解的路径规划图 收敛曲线&#xff1a; 模拟退火算法求解的路径规划图 结论&#xff…

Ubuntu20安装zabbix-agent2,对接zabbix 6.4

在Ubuntu 20.04 LTS上安装Zabbix Agent 2并与Zabbix Server 6.4对接&#xff0c;请按照以下步骤操作&#xff1a; 更新系统&#xff1a; sudo apt update sudo apt upgrade 添加Zabbix官方仓库&#xff1a; 首先&#xff0c;需要将Zabbix的官方存储库添加到你的系统中以获取Za…

【了解SpringCloud Gateway微服务网关】

曾梦想执剑走天涯&#xff0c;我是程序猿【AK】 目录 简述概要知识图谱什么是SpringCloudGateway功能特征应用场景核心概念配置文件工作原理路由谓词工厂&#xff08;内置的&#xff09;[After 路由谓词工厂](https://docs.spring.io/spring-cloud-gateway/docs/current/refere…

Mysql运维篇(七) 部署MHA--完结

一路走来&#xff0c;所有遇到的人&#xff0c;帮助过我的、伤害过我的都是朋友&#xff0c;没有一个是敌人。如有侵权&#xff0c;请留言&#xff0c;我及时删除&#xff01; 一、MHA软件构成 Manager工具包主要包括以下几个工具&#xff1a; masterha_manger 启…

【C++】多态深入分析

目录 一&#xff0c;多态的原理 1&#xff0c;虚函数表与虚函数表指针 2&#xff0c;原理调用 3&#xff0c;动态绑定与静态绑定 二&#xff0c;抽象类 三&#xff0c;单继承和多继承关系的虚函数表 1&#xff0c;单继承中的虚函数表 2&#xff0c;多继承中的虚函数表 …

内网搭建mysql8.0并搭建主从复制详细教程!!!

一、安装mysql 1.1 mysql下载链接&#xff1a; https://downloads.mysql.com/archives/community/ 1.2 解压包并创建相应的数据目录 tar -xvf mysql-8.2.0-linux-glibc2.28-x86_64.tar.xz -C /usr/local cd /usr/local/ mv mysql-8.2.0-linux-glibc2.28-x86_64/ mysql mkdir…

Python绘图-9饼图(上)

饼图&#xff08;Pie Chart&#xff09;是一种用于表示数据分类和相对大小的可视化图形。在饼图中&#xff0c;整个圆形代表数据的总和&#xff0c;而圆形内的各个扇形则代表不同的分类或类别&#xff0c;扇形的面积大小表示该类别在整体中所占的比例。饼图通常用于展示数据的分…

Window部署Jaeger

参考&#xff1a;windows安装使用jaeger链路追踪_windows安装jaeger-CSDN博客 下载&#xff1a;Releases jaegertracing/jaeger GitHub Jaeger – Download Jaeger 目录 1、安装nssm 2、安装运行 elasticsearch 3、安装运行 3.1部署JaegerAgent 3.2部署JaegerCollec…

【全志D1-H 哪吒开发板】Debian系统安装调教和点灯指南

全志D1-H开发板【哪吒】使用Deabian系统入门 特别说明&#xff1a; 因为涉及到操作较多&#xff0c;博文可能会导致格式丢失 其中内容&#xff0c;会根据后续使用做优化调整 目录&#xff1a; 参考资料固件烧录启动调教点灯问题 〇、参考资料 官方资料 开发板-D1开发板【…

C++:函数模板整理

函数模板: 找到函数相同的实现思路&#xff0c;区别于函数的参数类型。 使用函数模板使得函数可容纳不同类型的参数实现函数功能&#xff0c;而不是当类型不同时便编译大量类型不同的函数&#xff0c;产生大量重复代码和内存占用 函数模板格式&#xff1a; template<typ…

[Vulnhub]靶场 Red

kali:192.168.56.104 主机发现 arp-scan -l # arp-scan -l Interface: eth0, type: EN10MB, MAC: 00:0c:29:d2:e0:49, IPv4: 192.168.56.104 Starting arp-scan 1.10.0 with 256 hosts (https://github.com/royhills/arp-scan) 192.168.56.1 …