深度学习中的残差网络、加权残差连接(WRC)与跨阶段部分连接(CSP)详解

随着深度学习技术的不断发展,神经网络架构变得越来越复杂,而这些复杂网络在训练时常常遇到梯度消失、梯度爆炸以及计算效率低等问题。为了克服这些问题,研究者们提出了多种网络架构,包括 残差网络(ResNet)加权残差连接(WRC)跨阶段部分连接(CSP)

本文将详细介绍这三种网络架构的基本概念、工作原理以及如何在 PyTorch 中实现它们。我们会通过代码示例来展示每个技术的实现方式,并重点讲解其中的核心部分。

目录

一、残差网络(ResNet)

1.1 残差网络的背景与原理

1.2 残差块的实现

重点

二、加权残差连接(WRC)

2.1 WRC的提出背景

2.2 WRC的实现

重点

三、跨阶段部分连接(CSP)

3.1 CSP的提出背景

3.2 CSP的实现

重点

四、总结


一、残差网络(ResNet)

1.1 残差网络的背景与原理

有关于残差网络,详情可以查阅以下博客,更为详细与新手向:

YOLO系列基础(三)从ResNet残差网络到C3层-CSDN博客

深层神经网络的训练常常遭遇梯度消失或梯度爆炸的问题,导致训练效果不好。为了解决这一问题,微软的何凯明等人提出了 残差网络(ResNet),引入了“跳跃连接(skip connections)”的概念,使得信息可以直接绕过某些层传播,从而避免了深度网络训练中的问题。

在传统的神经网络中,每一层都试图学习输入到输出的映射。但在 ResNet 中,网络不再直接学习从输入到输出的映射,而是学习输入与输出之间的“残差”,即

H(x) = F(x) + x

其中 F(x) 是网络学到的残差部分,x 是输入。

这种方式显著提升了网络的训练效果,并且让深层网络的训练变得更加稳定。

1.2 残差块的实现

下面是一个简单的残差块实现,它包括了两层卷积和一个跳跃连接。跳跃连接帮助保持梯度的流动,避免深层网络中的梯度消失问题。

图例如下:

代码示例如下:

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义残差块
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels)# 如果输入和输出的通道数不同,则使用1x1卷积调整尺寸if in_channels != out_channels:self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)else:self.shortcut = nn.Identity()def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))  # 第一层卷积后激活out = self.bn2(self.conv2(out))        # 第二层卷积out += self.shortcut(x)                # 残差连接return F.relu(out)                     # ReLU激活# 构建ResNet
class ResNet(nn.Module):def __init__(self, num_classes=10):super(ResNet, self).__init__()self.layer1 = ResidualBlock(3, 64)self.layer2 = ResidualBlock(64, 128)self.layer3 = ResidualBlock(128, 256)self.fc = nn.Linear(256, num_classes)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = F.adaptive_avg_pool2d(x, (1, 1))  # 全局平均池化x = torch.flatten(x, 1)                # 展平x = self.fc(x)                         # 全连接层return x# 示例:构建一个简单的 ResNet
model = ResNet(num_classes=10)
print(model)
重点
  1. 残差连接的实现:在 ResidualBlock 类中,out += self.shortcut(x) 实现了输入与输出的加法操作,这是残差学习的核心。
  2. 处理输入和输出通道数不一致的情况:如果输入和输出的通道数不同,通过使用 1x1 卷积调整输入的维度,确保加法操作能够进行。

二、加权残差连接(WRC)

2.1 WRC的提出背景

传统的残差网络通过简单的跳跃连接将输入和输出相加,但在某些情况下,不同层的输出对最终结果的贡献是不同的。为了让网络更灵活地调整各层贡献,加权残差连接(WRC) 引入了可学习的权重。公式如下

H(x) =\alpha F(x) + \beta x

其中 F(x) 是网络学到的残差部分,x 是输入,\alpha 和 \beta是权重。

WRC通过为每个残差连接引入可学习的权重 \alpha\beta,使得网络能够根据任务需求自适应地调整每个连接的贡献。

2.2 WRC的实现

以下是 WRC 的实现代码,我们为每个残差连接引入了权重参数 alphabeta,这些参数通过训练进行优化。

图例如下:

可以看到,加权残差快其实就是给残差网络的两条分支加个权而已 

代码示例如下: 

class WeightedResidualBlock(nn.Module):def __init__(self, in_channels, out_channels):super(WeightedResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels)# 权重初始化self.alpha = nn.Parameter(torch.ones(1))  # 可学习的权重self.beta = nn.Parameter(torch.ones(1))   # 可学习的权重# 如果输入和输出的通道数不同,则使用1x1卷积调整尺寸if in_channels != out_channels:self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)else:self.shortcut = nn.Identity()def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))# 加权残差连接:使用可学习的权重 alpha 和 betaout = self.alpha * out + self.beta * self.shortcut(x)return F.relu(out)# 示例:构建一个加权残差块
model_wrc = WeightedResidualBlock(3, 64)
print(model_wrc)
重点
  1. 可学习的权重 alphabeta:我们为残差块中的两个加法项(即残差部分和输入部分)引入了可学习的权重。通过训练,这些权重可以自动调整,使网络能够根据任务需求更好地融合输入和输出。

  2. 加权残差连接的实现:在 forward 方法中,out = self.alpha * out + self.beta * self.shortcut(x) 表示加权残差连接,其中 alphabeta 是可学习的参数。

三、跨阶段部分连接(CSP)

3.1 CSP的提出背景

虽然 ResNet 和 WRC 提供了有效的残差学习和信息融合机制,但在一些更复杂的网络中,信息的传递依然面临冗余和计算开销较大的问题。为了解决这一问题,跨阶段部分连接(CSP) 提出了更加高效的信息传递方式。CSP通过选择性地传递部分信息而不是所有信息,减少了计算量并保持了模型的表达能力。

3.2 CSP的实现

CSP通过分割输入特征,并在不同阶段进行不同的处理,从而减少冗余的信息传递。下面是 CSP 的实现代码。

CSP思想图例如下:

特征分割(Feature Splitting):CSP通过分割输入特征图,并将分割后的特征图分别送入不同的子网络进行处理。一般来说,一条分支的子网络会比较简单,一条分支的自网络则是原来主干网络的一部分。

重点
  1. 部分特征选择性连接:将输入特征分为两部分。每部分特征单独经过卷积处理后,通过 torch.cat() 进行拼接,形成最终的输出。
  2. 跨阶段部分连接:CSP块通过分割输入特征并在不同阶段处理,有效地减少了计算开销,并且保持了网络的表达能力。

四、总结

本文介绍了 残差网络(ResNet)加权残差连接(WRC)跨阶段部分连接(CSP) 这三种网络架构。

finally,求赞求赞求赞~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch | 从零构建EfficientNet对CIFAR10进行分类

Pytorch | 从零构建EfficientNet对CIFAR10进行分类 CIFAR10数据集EfficientNet设计理念网络结构性能特点应用领域发展和改进 EfficientNet结构代码详解结构代码代码详解MBConv 类初始化方法前向传播 forward 方法 EfficientNet 类初始化方法前向传播 forward 方法 训练过程和测…

音视频入门基础:MPEG2-TS专题(20)——ES流简介

《T-REC-H.222.0-202106-S!!PDF-E.pdf》第27页对ES进行了定义。ES流是PES packets(PES包)中编码的视频、编码的音频或其他编码的比特流。一个ES流(elementary stream)在具有且只有一个stream_id的PES packets序列中携带&#xff1…

天水月亮圈圈:舌尖上的历史与传承

在天水甘谷县,有一种美食如同夜空中的明月,散发着独特的魅力,它就是有着百年历史的月亮圈圈。月亮圈圈原名甘谷酥圈圈,据传,由大像山镇蒋家庄一姓李的厨师创制而成,后经王明玖等厨师的光大传承,…

YOLOv11融合[CVPR2023]FFTformer中的FSAS模块

YOLOv11v10v8使用教程: YOLOv11入门到入土使用教程 YOLOv11改进汇总贴:YOLOv11及自研模型更新汇总 《Efficient Frequency Domain-based Transformers for High-Quality Image Deblurring》 一、 模块介绍 论文链接:https://arxiv.org/abs…

java如何使用poi-tl在word模板里渲染多张图片

1、poi-tl官网地址 http://deepoove.com/poi-tl/ 2、引入poi-tl的依赖 <dependency><groupId>com.deepoove</groupId><artifactId>poi-tl</artifactId><version>1.12.1</version></dependency>3、定义word模板 释义&#xf…

《信管通低代码信息管理系统开发平台》Windows环境安装说明

1 简介 《信管通低代码信息管理系统应用平台》提供多环境软件产品开发服务&#xff0c;包括单机、局域网和互联网。我们专注于适用国产硬件和操作系统应用软件开发应用。为事业单位和企业提供行业软件定制开发&#xff0c;满足其独特需求。无论是简单的应用还是复杂的系统&…

8K+Red+Raw+ProRes422分享5个影视级视频素材网站

Hello&#xff0c;大家好&#xff0c;我是后期圈&#xff01; 在视频创作中&#xff0c;电影级的视频素材能够为作品增添专业质感&#xff0c;让画面更具冲击力。无论是广告、电影短片&#xff0c;还是品牌宣传&#xff0c;高质量的视频素材都是不可或缺的资源。然而&#xff…

Git远程仓库的使用

一.远程仓库注册 1.github&#xff1a;GitHub Build and ship software on a single, collaborative platform GitHub 2.gitee&#xff1a;GitHub Build and ship software on a single, collaborative platform GitHub github需要使用魔法&#xff0c;而gitee是国内的仓…

Echarts连接数据库,实时绘制图表详解

文章目录 Echarts连接数据库&#xff0c;实时绘制图表详解一、引言二、步骤一&#xff1a;环境准备与数据库连接1、环境搭建2、数据库连接 三、步骤二&#xff1a;数据获取与处理1、查询数据库2、数据处理 四、步骤三&#xff1a;ECharts图表配置与渲染1、配置ECharts选项2、动…

【Java基础面试题038】栈和队列在Java中的区别是什么?

回答重点 栈&#xff08;Stack&#xff09;&#xff1a;遵循后进先出&#xff08;LIFO&#xff0c;Last In&#xff0c;First Out&#xff09;原则。即&#xff0c;最后插入的元素最先被移除。主要操作包括push&#xff08;入栈&#xff09;和pop&#xff08;出栈&#xff09;…

idea2024创建JavaWeb项目以及配置Tomcat详解

今天呢&#xff0c;博主的学习进度也是步入了JavaWeb&#xff0c;目前正在逐步杨帆旗航&#xff0c;迎接全新的狂潮海浪。 那么接下来就给大家出一期有关JavaWeb的配置教学&#xff0c;希望能对大家有所帮助&#xff0c;也特别欢迎大家指点不足之处&#xff0c;小生很乐意接受正…

由于这些关键原因,我总是手边有一台虚拟机

概括 虚拟机提供了一个安全的环境来测试有风险的设置或软件,而不会影响您的主系统。设置和保存虚拟机非常简单,无需更改主要设备即可方便地访问多个操作系统。运行虚拟机可能会占用大量资源,但现代 PC 可以很好地处理它,为实验和工作流程优化提供无限的可能性。如果您喜欢使…

【FPGA】ISE13.4操作手册,新建工程示例

关注作者了解更多 我的其他CSDN专栏 求职面试 大学英语 过程控制系统 工程测试技术 虚拟仪器技术 可编程控制器 工业现场总线 数字图像处理 智能控制 传感器技术 嵌入式系统 复变函数与积分变换 单片机原理 线性代数 大学物理 热工与工程流体力学 数字信号处…

python环境中阻止相关库的自动更新

找到conda中的Python虚拟环境位置 这里以conda中的pytorch虚拟环境为例&#xff08;Python环境位置&#xff09;&#xff0c;在.conda下的envs中进入pytorch下的conda-meta路径下 新建一个空白的pinned文档 右键点击桌面或文件资源管理器中的空白处&#xff0c;选择“新建” …

重温设计模式--外观模式

文章目录 外观模式&#xff08;Facade Pattern&#xff09;概述定义 外观模式UML图作用 外观模式的结构C 代码示例1C代码示例2总结 外观模式&#xff08;Facade Pattern&#xff09;概述 定义 外观模式是一种结构型设计模式&#xff0c;它为子系统中的一组接口提供了一个统一…

uniapp 微信小程序 页面部分截图实现

uniapp 微信小程序 页面部分截图实现 ​ 原理都是将页面元素画成canvas 然后将canvas转化为图片&#xff0c;问题是我页面里边本来就有一个canvas&#xff0c;ucharts图画的canvas我无法画出这块。 ​ 想了一晚上&#xff0c;既然canvas最后能转化为图片&#xff0c;那我直接…

Flutter 基础知识总结

1、Flutter 介绍与环境安装 为什么选择 Dart&#xff1a; 基于 JIT 快速开发周期&#xff1a;Flutter 在开发阶段采用 JIT 模式&#xff0c;避免每次改动都进行编译&#xff0c;极大的节省了开发时间基于 AOT 发布包&#xff1a;Flutter 在发布时可以通过 AOT 生成高效的 ARM…

Jenkins 持续集成部署

Jenkins的安装与部署 前言 当我们在实施一个项目时&#xff0c;从新代码中获得反馈的速度越快&#xff0c;问题越早得到解决&#xff0c;获得反馈的一种常见方法是在新代码之后运行测试&#xff0c;但这就导致了当代码正在编译并且正在运行测试时&#xff0c;开发人员无法在测…

Pytorch | 利用BIM/I-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用BIM/I-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集BIM介绍基本原理算法流程 BIM代码实现BIM算法实现攻击效果 代码汇总bim.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器&#xff1a; Pytorch | 从零构建AlexNet对CIFAR10进行分类 Py…

OpenGL —— 2.6.1、绘制一个正方体并贴图渲染颜色(附源码,glfw+glad)

源码效果 C++源码 纹理图片 需下载stb_image.h这个解码图片的库,该库只有一个头文件。 具体代码: vertexShader.glsl #version