【深度学习笔记】5_11 残差网络ResNet

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

5.11 残差网络(ResNet)

让我们先思考一个问题:对神经网络模型添加新的层,充分训练后的模型是否只可能更有效地降低训练误差?理论上,原模型解的空间只是新模型解的空间的子空间。也就是说,如果我们能将新添加的层训练成恒等映射 f ( x ) = x f(x) = x f(x)=x,新模型和原模型将同样有效。由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。然而在实践中,添加过多的层后训练误差往往不降反升。即使利用批量归一化带来的数值稳定性使训练深层模型更加容易,该问题仍然存在。针对这一问题,何恺明等人提出了残差网络(ResNet) [1]。它在2015年的ImageNet图像识别挑战赛夺魁,并深刻影响了后来的深度神经网络的设计。

5.11.2 残差块

让我们聚焦于神经网络局部。如图5.9所示,设输入为 x \boldsymbol{x} x。假设我们希望学出的理想映射为 f ( x ) f(\boldsymbol{x}) f(x),从而作为图5.9上方激活函数的输入。左图虚线框中的部分需要直接拟合出该映射 f ( x ) f(\boldsymbol{x}) f(x),而右图虚线框中的部分则需要拟合出有关恒等映射的残差映射 f ( x ) − x f(\boldsymbol{x})-\boldsymbol{x} f(x)x。残差映射在实际中往往更容易优化。以本节开头提到的恒等映射作为我们希望学出的理想映射 f ( x ) f(\boldsymbol{x}) f(x)。我们只需将图5.9中右图虚线框内上方的加权运算(如仿射)的权重和偏差参数学成0,那么 f ( x ) f(\boldsymbol{x}) f(x)即为恒等映射。实际中,当理想映射 f ( x ) f(\boldsymbol{x}) f(x)极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。图5.9右图也是ResNet的基础块,即残差块(residual block)。在残差块中,输入可通过跨层的数据线路更快地向前传播。

在这里插入图片描述

图5.9 普通的网络结构(左)与加入残差连接的网络结构(右)

ResNet沿用了VGG全 3 × 3 3\times 3 3×3卷积层的设计。残差块里首先有2个有相同输出通道数的 3 × 3 3\times 3 3×3卷积层。每个卷积层后接一个批量归一化层和ReLU激活函数。然后我们将输入跳过这两个卷积运算后直接加在最后的ReLU激活函数前。这样的设计要求两个卷积层的输出与输入形状一样,从而可以相加。如果想改变通道数,就需要引入一个额外的 1 × 1 1\times 1 1×1卷积层来将输入变换成需要的形状后再做相加运算。

残差块的实现如下。它可以设定输出通道数、是否使用额外的 1 × 1 1\times 1 1×1卷积层来修改通道数以及卷积层的步幅。

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')class Residual(nn.Module):  # 本类已保存在d2lzh_pytorch包中方便以后使用def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):super(Residual, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)return F.relu(Y + X)

下面我们来查看输入和输出形状一致的情况。

blk = Residual(3, 3)
X = torch.rand((4, 3, 6, 6))
blk(X).shape # torch.Size([4, 3, 6, 6])

我们也可以在增加输出通道数的同时减半输出的高和宽。

blk = Residual(3, 6, use_1x1conv=True, stride=2)
blk(X).shape # torch.Size([4, 6, 3, 3])

5.11.2 ResNet模型

ResNet的前两层跟之前介绍的GoogLeNet中的一样:在输出通道数为64、步幅为2的 7 × 7 7\times 7 7×7卷积层后接步幅为2的 3 × 3 3\times 3 3×3的最大池化层。不同之处在于ResNet每个卷积层后增加的批量归一化层。

net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

GoogLeNet在后面接了4个由Inception块组成的模块。ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。第一个模块的通道数同输入通道数一致。由于之前已经使用了步幅为2的最大池化层,所以无须减小高和宽。之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

下面我们来实现这个模块。注意,这里对第一个模块做了特别处理。

def resnet_block(in_channels, out_channels, num_residuals, first_block=False):if first_block:assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))else:blk.append(Residual(out_channels, out_channels))return nn.Sequential(*blk)

接着我们为ResNet加入所有残差块。这里每个模块使用两个残差块。

net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
net.add_module("resnet_block2", resnet_block(64, 128, 2))
net.add_module("resnet_block3", resnet_block(128, 256, 2))
net.add_module("resnet_block4", resnet_block(256, 512, 2))

最后,与GoogLeNet一样,加入全局平均池化层后接上全连接层输出。

net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10))) 

这里每个模块里有4个卷积层(不计算 1 × 1 1\times 1 1×1卷积层),加上最开始的卷积层和最后的全连接层,共计18层。这个模型通常也被称为ResNet-18。通过配置不同的通道数和模块里的残差块数可以得到不同的ResNet模型,例如更深的含152层的ResNet-152。虽然ResNet的主体架构跟GoogLeNet的类似,但ResNet结构更简单,修改也更方便。这些因素都导致了ResNet迅速被广泛使用。

在训练ResNet之前,我们来观察一下输入形状在ResNet不同模块之间的变化。

X = torch.rand((1, 1, 224, 224))
for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)

输出:

0  output shape:	 torch.Size([1, 64, 112, 112])
1  output shape:	 torch.Size([1, 64, 112, 112])
2  output shape:	 torch.Size([1, 64, 112, 112])
3  output shape:	 torch.Size([1, 64, 56, 56])
resnet_block1  output shape:	 torch.Size([1, 64, 56, 56])
resnet_block2  output shape:	 torch.Size([1, 128, 28, 28])
resnet_block3  output shape:	 torch.Size([1, 256, 14, 14])
resnet_block4  output shape:	 torch.Size([1, 512, 7, 7])
global_avg_pool  output shape:	 torch.Size([1, 512, 1, 1])
fc  output shape:	 torch.Size([1, 10])

5.11.3 获取数据和训练模型

下面我们在Fashion-MNIST数据集上训练ResNet。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0015, train acc 0.853, test acc 0.885, time 31.0 sec
epoch 2, loss 0.0010, train acc 0.910, test acc 0.899, time 31.8 sec
epoch 3, loss 0.0008, train acc 0.926, test acc 0.911, time 31.6 sec
epoch 4, loss 0.0007, train acc 0.936, test acc 0.916, time 31.8 sec
epoch 5, loss 0.0006, train acc 0.944, test acc 0.926, time 31.5 sec

小结

  • 残差块通过跨层的数据通道从而能够训练出有效的深度神经网络。
  • ResNet深刻影响了后来的深度神经网络的设计。

参考文献

[1] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).

[2] He, K., Zhang, X., Ren, S., & Sun, J. (2016, October). Identity mappings in deep residual networks. In European Conference on Computer Vision (pp. 630-645). Springer, Cham.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/729710.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Express框架的产生

Express框架的产生,解决的痛点是什么? 1.优化Node.js在Web的开发 Express框架是一个基于Node.js的Web应用程序开发框架,它的产生主要是为了解决Node.js在Web开发中的一些痛点。 在Node.js出现之前,Web开发主要是基于传统的后端…

springboot项目集成,项目流程概述

一、项目介绍 二、项目设计原则 2.1整体原则 2.2持久层 2.3业务逻辑层 具体分析 三、实战 3.1项目搭建 <dependency><groupId>org.springframework.security</groupId><artifactId>spring-security-crypto</artifactId></dependency>&l…

双链表的实现(数据结构)

链表总体可以分为三大类 一、无头和有头 二、单向和双向 三、循环和不循环 从上面分类得知可以组合成8种不同类型链表&#xff0c;其中单链表最为简单&#xff0c;双联表最为复杂&#xff0c;两种链表都实现后其余链表都不成问题。 我们前期博客已将完成了单向无头不循环链表…

基于PHP的景点数据分析系统设计与实现

目 录 摘 要 I Abstract II 引 言 1 1 关键理论与技术 3 1.1 框架技术 3 1.1.1 QueryList 3 1.1.2 ThinkPHP 3 1.1.3 Amaze UI 3 1.2 数据可视化技术 4 1.3 数据库技术 4 1.4 本章小结 4 2 需求分析 5 2.1 业务流程分析 5 2.2 功能需求分析 5 2.3 用例分析 7 2.4 非功能性需求…

it-tools工具箱

it-tools 是一个在线工具集合&#xff0c;包含各种实用的开发工具、网络工具、图片视频工具、数学工具等 github地址&#xff1a;https://github.com/CorentinTh/it-tools 部署 docker run -d --name it-tools --restart unless-stopped -p 8080:80 corentinth/it-tools:lat…

gradle 相关

aar 不加 aar 以及 transitive true library可以通过多种格式上传到远程仓库&#xff0c;比如大部分情况下用到的.jar或.aar。当没有指定后缀的话&#xff0c;依赖的时候将会下载它的默认格式&#xff08;由上传方定义&#xff0c;如果没有定义则默认为.jar&#xff09;的Lib…

yolov8多batch推理,nms后处理

0. 背景 在高速公路监控视频场景下&#xff0c;图像分辨率大都是1920 * 1080或者2560 * 1440&#xff0c;远处的物体&#xff08;车辆和行人等&#xff09;都比较小。考虑需要对图像进行拆分&#xff0c;然后把拆分后的数据统一送入模型中&#xff0c;推理的结果然后再做nms&am…

redis centos7 单点搭建

redis centos 安装步骤 下载源文件编译Redis拷贝编译后文件修改配置文件启动redis 下载源文件 wget https://download.redis.io/redis-stable.tar.gz编译Redis tar -xzvf redis-stable.tar.gz cd redis-stable make如果编译成功&#xff0c;你会在src目录中找到几个 Redis 二…

Android Selinux详解[一]---整体介绍

Android 使用安全增强型 Linux (SELinux) 对所有进程强制执行强制访问控制 (MAC)&#xff0c;甚至包括以 Root/超级用户权限运行的进程&#xff08;Linux 功能&#xff09;。 借助 SELinux&#xff0c;Android 可以更好地保护和限制系统服务、控制对应用数据和系统日志的访问、…

【微前端乾坤】 vue2主应用、vue2+webpack子应用,vue3+webpack子应用、vue3+vite子应用的配置

因公司需求 需要将原本vue2iframe 形式的项目改成微前端乾坤的方式。 之前iframe都是直接嵌套到vue2项目的二级目录或者三级目录下的(反正就是要随处可嵌) 用乾坤的原因&#xff1a; 1、iframe嵌套的方式存在安全隐患&#xff1b; 2、项目是联合开发的&#xff0c; 乾坤的方便…

Doris画像存储实践系列二

上一篇: Doris画像存储系列一(https://editor.csdn.net/md/?articleId120416295) 六、画像宽表bitmap倒排表 重复一下bitmap倒排表的优点和缺点 标签类型标签值user_ids性别男1,2性别发3 优点: doris bitmap聚合表在对做用户画像群体计算时很友好,交集/并集/差集因为数据…

SQL23 统计每个学校各难度的用户平均刷题数

题解 | #统计每个学校各难度的用户平均刷题数# 题意明确&#xff1a; 计算每个学校用户不同难度下的用户平均答题题目数 问题分解&#xff1a; 限定条件&#xff1a;无&#xff1b;每个学校&#xff1a;按学校分组group by university不同难度&#xff1a;按难度分组group b…

Hack The Box-Crafty

目录 信息收集 rustscan whatweb WEB 漏洞利用 漏洞说明 漏洞验证 提权 get user.txt get Administrator 总结 信息收集 rustscan ┌──(root㉿ru)-[~/kali/hackthebox] └─# rustscan -a 10.10.11.249 --range0-65535 --ulimit5000 -- -A -sC [~] Automatically…

NLP:自定义模型训练

书接上文&#xff0c;为了完成指定的任务&#xff0c;我们需要额外训练一个特定场景的模型 这里主要参考了这篇博客&#xff1a;大佬的博客 我这里就主要讲一下我根据这位大佬的博客一步一步写下时&#xff0c;遇到的问题&#xff1a; 文中的cfg在哪里下载&#xff1f; 要不…

Fastjson漏洞利用合集

0x01 Fastjson 概述 1.应用场景 接口返回数据 Ajax异步访问数据RPC远程调用前后端分离后端返回的数据开放API(一些公司开放接口的时候&#xff0c;我们点击请求&#xff0c;返回的数据是JSON格式的)企业间合作接口(数据对接的时候定义的一种规范&#xff0c;确定入参&#x…

BUUCTF-MISC-[HDCTF2019]信号分析1

题目链接&#xff1a;BUUCTF在线评测 (buuoj.cn) 下载附件是一个WAV的文件&#xff0c;题目又叫做信号分析&#xff0c;用Adobe Audition 打开分析了 发现有很多长短不一样的信号&#xff0c;只需要分析一段 猜测长的是一短的为0 最后得到0101010101010101000000110 百度得知…

vscode如何远程到linux python venv虚拟环境开发?(python虚拟环境、vscode远程开发、vscode远程连接)

文章目录 1. 安装VSCode2. 安装扩展插件3. 配置SSH连接4. 输入用户名和密码5. 打开远程文件夹6. 创建/选择Python虚拟环境7. 安装Python插件 Visual Studio Code (VSCode) 提供了一种称为 Remote Development 的功能&#xff0c;允许用户在远程系统、容器或甚至 Windows 子系统…

【致逝去的青春】《龙珠》作者鸟山明逝世,享年68岁

鸟山明工作室&#xff08;BIRD STUDIO&#xff09;于3月8日发布讣告&#xff1a;鸟山明已于2024年3月1日因急性硬膜下血肿逝世&#xff0c;享年68岁。 《龙珠》从 1984 年开始于《周刊少年Jump》连载&#xff0c;过后曾改编曾多部动画、剧场版、游戏&#xff0c;相关周边商品也…

opengl 学习(二)-----你好,三角形

你好&#xff0c;三角形 分类demo效果解析 分类 opengl c demo #include "glad/glad.h" #include "glfw3.h" #include <iostream> #include <cmath> #include <vector>using namespace std;/** * 在学习此节之前&#xff0c;建议将这…

Alveo 概念拓扑结构

在 Alveo 加速卡中,涉及到的概念拓扑结构主要包括 Alveo 卡上的各个关键组件以及与主机系统之间的通信结构。以下是对这些概念拓扑结构的简要介绍: 1.DDR 即双数据率内存(Double Data Rate memory),是一种常见的计算机内存类型,用于存储和提供处理器所需的数据和指令。…