《动手学深度学习(PyTorch版)》笔记7.6

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.6 Residual Networks(ResNet)

随着我们设计越来越深的网络,深刻理解“新添加的层如何提升神经网络的性能”变得至关重要。

7.6.1 Function Class

首先,假设有一类特定的神经网络架构 F \mathcal{F} F,它包括学习速率和其他超参数设置。对于所有 f ∈ F f \in \mathcal{F} fF,存在一些参数集(例如权重和偏置),这些参数可以通过在合适的数据集上进行训练而获得。现在假设 f ∗ f^* f是我们真正想要找到的函数,如果是 f ∗ ∈ F f^* \in \mathcal{F} fF,那我们可以轻而易举的训练得到它,但通常我们不会那么幸运。我们将尝试找到一个函数 f F ∗ f^*_\mathcal{F} fF,这是我们在 F \mathcal{F} F中的最佳选择。例如,给定一个具有 X \mathbf{X} X特性和 y \mathbf{y} y标签的数据集,我们可以尝试通过解决以下优化问题来找到它:

f F ∗ : = a r g m i n f L ( X , y , f ) ,  f ∈ F . f^*_\mathcal{F} := \mathop{\mathrm{argmin}}_f L(\mathbf{X}, \mathbf{y}, f) \text{ , } f \in \mathcal{F}. fF:=argminfL(X,y,f) , fF.

为了得到更近似真正 f ∗ f^* f的函数,唯一合理的可能性是设计一个更强大的架构 F ′ \mathcal{F}' F。换句话说,我们预计 f F ′ ∗ f^*_{\mathcal{F}'} fF f F ∗ f^*_{\mathcal{F}} fF“更近似”。然而,如果 F ⊈ F ′ \mathcal{F} \not\subseteq \mathcal{F}' FF,则无法保证新的体系“更近似”。事实上, f F ′ ∗ f^*_{\mathcal{F}'} fF可能更糟:如下图所示,对于非嵌套函数(non-nested function)类,较复杂的函数类并不总是向“真”函数 f ∗ f^* f靠拢(复杂度由 F 1 \mathcal{F}_1 F1 F 6 \mathcal{F}_6 F6递增)。在下图的左边,虽然 F 3 \mathcal{F}_3 F3 F 1 \mathcal{F}_1 F1更接近 f ∗ f^* f,但 F 6 \mathcal{F}_6 F6却离的更远了。相反,对于下图右边的嵌套函数(nested function)类 F 1 ⊆ … ⊆ F 6 \mathcal{F}_1 \subseteq \ldots \subseteq \mathcal{F}_6 F1F6,我们可以避免上述问题。
在这里插入图片描述

因此,只有当较复杂的函数类包含较小的函数类时,我们才能确保提高它们的性能。对于深度神经网络,如果我们能将新添加的层训练成恒等映射(identity function) f ( x ) = x f(\mathbf{x}) = \mathbf{x} f(x)=x,新模型和原模型将同样有效。同时,由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。针对这一问题,何恺明等人提出了残差网络(ResNet)。其核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。于是,残差块(residual blocks)便诞生了,这个设计对如何建立深层神经网络产生了深远的影响。

7.6.2 Residual Blocks

在这里插入图片描述

如上图所示,假设我们的原始输入为 x x x,而希望学出的理想映射为 f ( x ) f(\mathbf{x}) f(x)。上图左边是一个正常块,虚线框中的部分需要直接拟合出该映射 f ( x ) f(\mathbf{x}) f(x),而右边是ResNet的基础架构–残差块(residual block),虚线框中的部分则需要拟合出残差映射 f ( x ) − x f(\mathbf{x}) - \mathbf{x} f(x)x。残差映射在现实中往往更容易优化。以恒等映射作为理想映射 f ( x ) f(\mathbf{x}) f(x),只需将上图右边虚线框内上方的加权运算(如仿射)的权重和偏置参数设成0,那么 f ( x ) f(\mathbf{x}) f(x)即为恒等映射。实际上,当理想映射 f ( x ) f(\mathbf{x}) f(x)极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。在残差块中,输入可通过跨层数据线路更快地向前传播,且可以避免某些梯度消失或梯度爆炸的问题。

在这里插入图片描述

ResNet沿用了VGG完整的 3 × 3 3\times 3 3×3卷积层设计。残差块里首先有2个有相同输出通道数的 3 × 3 3\times 3 3×3卷积层,每个卷积层后接一个批量规范化层和ReLU激活函数,然后我们通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。如果想改变通道数,就需要引入一个额外的 1 × 1 1\times 1 1×1卷积层来将输入变换成需要的形状后再做相加运算。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import matplotlib.pyplot as pltclass Residual(nn.Module):  #@savedef __init__(self, input_channels,num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

如下图所示,此代码生成两种类型的网络:当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出;当use_1x1conv=True时,使用 1 × 1 1 \times 1 1×1卷积调整通道和分辨率。

在这里插入图片描述

blk = Residual(3,3)#输入和输出形状一致
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
print(Y.shape)blk = Residual(3,6, use_1x1conv=True, strides=2)#增加输出通道数的同时,减半输出的高和宽
print(blk(X).shape)#定义ResNet的模块
#b2-b5各有4个卷积层(不包括恒等映射的1x1卷积层),加上第一个7x7卷积层和最后一个全连接层,共有18层,因此这种模型通常被称为ResNet-18
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))def resnet_block(input_channels, num_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkb2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/680531.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在CentOS7上使用防火墙保护Docker容器的端口

防火墙设置对Docker容器内开放的端口无效? 在主机或虚机上运行Docker容器时,即便主机启用了firewalld服务,仍然存在一些安全隐患,尤其是当Docker容器内打开端口并监听0.0.0.0时,存在即使通过firewall-cmd配置了阻止某…

前端主流框架:项目运行命令 npm 详解

作为一位资深前端开发,我对npm(Node Package Manager)的使用有着深入的了解。npm是Node.js的包管理器,用于安装、管理和删除各种前端库和工具。现在,让我们深入了解npm在Vue、React、Angular和Vue 3项目中的一些基本使…

【软件工程导论】实验二——编制数据字典(数字化校园系统案例分析)

数字化校园系统案例分析 问题定义实验内容编制内容1数据项数据流处理逻辑数据存储 2外部实体 问题定义 数字化校园系统期望以数字化信息和网络为基础,在计算机和网络技术上建立起对教学、科研、管理、技术服务、生活服务等校园信息的收集、处理、整合、存储、传输和…

嵌入式Qt 第一个Qt项目

一.创建Qt项目 打开Qt Creator 界面选择 New Project或者选择菜单栏 【文件】-【新建文件或项目】菜单项 弹出New Project对话框,选择Qt Widgets Application 选择【Choose】按钮,弹出如下对话框 设置项目名称和路径,按照向导进行下一步 选…

EMC学习笔记(二十六)降低EMI的PCB设计指南(六)

降低EMI的PCB设计指南(六) 1.PCB布局1.1 带键盘和显示器的前置面板PCB在汽车和消费类应用中的应用1.2 敏感元器件的布局1.3 自动布线器 2.屏蔽2.1 工作原理2.2 屏蔽接地2.3 电缆屏蔽至旁路2.4 缝隙天线:冷却槽和缝隙 tips:资料主要…

MySQL篇----第二十一篇

系列文章目录 文章目录 系列文章目录前言一、什么是乐观锁二、什么是悲观锁三、什么是时间戳四、什么是行级锁前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 一、…

无人机概述及系统组成,无人机系统的构成

无人机的定义 无人驾驶航空器,是一架由遥控站管理(包括远程操纵或自主飞行)的航空器,也称遥控驾驶航空器,以下简称无人机。 无人机系统的定义 无人机系统,也称无人驾驶航空器系统,是指一架无人…

屏幕字体种类介绍

[ Script and font support in Windows ] [Windows 中的脚本和字体支持] 在Windows 2000 以前,Windows 的每个主要版本都会添加对新脚本的文本显示支持。本文介绍了每个主要版本中的更改。 Since before Windows 2000, text-display support for new scr…

树莓派4B(Raspberry Pi 4B)使用docker搭建springBoot/springCloud服务

树莓派4B(Raspberry Pi 4B)使用docker搭建springBoot/springCloud服务 前提:本文基于Ubuntu,Java8,SpringBoot 2.6.13讲解 准备工作 准备SpringBoot/SpringCloud项目jar包 用 maven 打包springBoot/springCloud项目&…

idea中vue文件如何快捷打出html标签结构,不写<

例如写一个<button></button>标签&#xff1a;快捷键如下 先写一个button&#xff0c;然后再按tab键即可自动生成一对标签。 演示&#xff1a; 步骤一&#xff1a; 步骤二&#xff1a;

PHP中读取(截取substr)字符串前N个字符或者从第几个字符开始取几个字符

html <?php $str "123456789";echo substr($str , 0 , 3);//从左边第一位字符起截取3位字符&#xff1a;结果&#xff1a;123echo substr($str , 3 , 3);//从左边第3位字符起截取3位字符&#xff1a;结果&#xff1a;456?> html <?php$rest substr(&…

【华为OD机试】 最长子字符串的长度(一)【2024 C卷|100分】

【华为OD机试】-真题 !!点这里!! 【华为OD机试】真题考点分类 !!点这里 !! 题目描述 给你一个字符串 s,首尾相连成一个环形,请你在环中找出 o 字符出现了偶数次最长子字符串的长度。 输入描述 输入是一个小写字母组成的字符串 输出描述 输出是一个整数 备注 1 ≤ s.length…

STM32 STD/HAL库驱动W25Q64模块读写字库数据+OLED0.96显示例程

STM32 STD/HAL库驱动W25Q64 模块读写字库数据OLED0.96显示例程 &#x1f3ac;原创作者对W25Q64保存汉字字库演示&#xff1a; W25Q64保存汉字字库 &#x1f39e;测试字体显示效果&#xff1a; &#x1f4d1;功能实现说明 利用W25Q64保存汉字字库&#xff0c;OLED显示汉字的时…

Rust基础拾遗--辅助功能

Rust基础拾遗 前言1.错误处理1.1 panic为什么是 Result 2. create与模块3. 宏4. 不安全代码5. 外部函数 前言 通过Rust程序设计-第二版笔记的形式对Rust相关重点知识进行汇总&#xff0c;读者通读此系列文章就可以轻松的把该语言基础捡起来。 1.错误处理 Rust 中的两类错误处理…

寒假作业——2/13

作业1 作业2 cp cp 当前的文件位置 复制到哪个位置 格式 : cp 路径/文件 路径/目录名/重新命名的目录名 mv mv 当前的文件位置 复制到哪个位置 格式 : mv 路径/文件 路径/目录名/重新命名的目录名 也可进行重命名操作 find 查找文件 find 目标路径 -name 文件名 后续…

vue监视和深度监视

vue监视 监视属性watch 1.监视的属性变化时&#xff0c;回调函数自动调用&#xff0c;自动操作 2.监视的属性一定要存在&#xff0c;才可以进行监视 3.监视的写法 1.new vue的时候传入watch配置 2.通过vm.$watch监视 vue监视深度 深度监视 1.vue中的watch默认不检测对象内部…

leetcode(数组)128.最长连续序列(c++详细解释)DAY8

文章目录 1.题目示例提示 2.解答思路3.实现代码结果 4.总结 1.题目 给定一个未排序的整数数组 nums &#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的长度。 请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 示例 1&a…

政安晨:在Jupyter中【示例演绎】Matplotlib的官方指南(二){Image tutorial}·{Python语言}

咱们接着上一篇&#xff0c;这次咱们讲使用Matplotlib绘制图像的简短尝试。 我的这个系列的上一篇文章在这里&#xff1a; 政安晨&#xff1a;在Jupyter中【示例演绎】Matplotlib的官方指南&#xff08;一&#xff09;{Pyplot tutorial}https://blog.csdn.net/snowdenkeke/ar…

IM聊天系统为什么需要做消息幂等?如何使用Redis以及Lua脚本做消息幂等【第12期】

0前言 消息收发模型 第一张图是一个时序图&#xff0c;第二张图是一个标清楚步骤的流程图&#xff0c;更加清晰。消息的插入环节主要在2步。save部分。主要也是对这个部分就行消息幂等的操作。 前情提要&#xff1a;使用Redis发布 token 以及lua脚本来共同完成消息的幂等 目…

119.乐理基础-五线谱-五线谱的标记

内容参考于&#xff1a;三分钟音乐社 上一个内容&#xff1a;音值组合法&#xff08;二&#xff09; 力度记号&#xff1a;简谱里什么意思&#xff0c;五线谱也完全是什么意思&#xff0c;p越多就越弱&#xff0c;f越多就越强&#xff0c;然后这些渐强、渐弱、sf、fp这些标记…