Stochastic Depth 原理与代码解析

paper:Deep Networks with Stochastic Depth

official implementation:https://github.com/yueatsprograms/Stochastic_Depth

third-party implementation:https://github.com/open-mmlab/mmcv/blob/main/mmcv/cnn/bricks/drop.py

存在的问题

网络深度是模型表达能力的决定性因素,但是非常深的网络也带来了新的挑战:反向传播中的梯度消失、前向传播中特征重用变少、更长的训练时间。 

本文的创新点

本文提出了一种新的训练深度网络的方法,随机深度stochastic depth,在训练阶段随机删除某些层使得网络的总层数变少,既缓解了梯度消失和特征重用减少的问题,又缩短了训练时间。此外和Dropout类似,stochastic depth还起到了正则化的作用,即使在有BN的情况下。用随机深度训练的网络还可以看作不同深度网络的隐式集和ensemble。

方法介绍

Stochastic depth的目的是在训练过程中减小网络的深度,同时在测试过程中保持其不变。\(b_{\ell}\in\{0,1\}\) 是一个伯努利随机变量,表示第 \(\ell^{th}\) ResBlock是active(\(b_{\ell}=1\))还是inactive(\(b_{\ell}=0\)),更进一步,将ResBlock \(\ell\) 的“存活”概率表示为 \(p_{\ell}=Pr(b_{\ell}=1)\)。\(p_{\ell}\) 是唯一的超参,关于 \(p_{\ell}\) 的设置有两种方式,一种是对所有层 \(\ell\) 设置同一 \(p_{\ell}\)。另一种是采用线性递减的方式,对于输入 \(p_{0}=1\) 线性递减到最后一个ResBlock的 \(p_{L}\),如下

作者通过实验得出结论第二种设置方式效果更好。

代码解析

这里的代码来自MMCV,和论文中的实现有出入,其中shape只保留batch_size的值,其它所有维度的值都为1。torch.rand从均匀分布[0, 1]中采样,与keep_prob相加得到random_tensor后最后再向下取整.floor(),假设drop_prob=0.2即有0.2的概率丢弃该层,则有keep_prob=1-drop_prob=0.8的概率保留该层,与[0, 1]均匀分布相加后有0.8的概率大于1向下取整后为1,0.2的概率小于1向下取整后为0。x.div(keep_prob)和dropout中的操作一样,因为训练时随机丢弃部分连接或层,推理时不丢弃,除以keep_prob是为了保持总体期望不变。

def drop_path(x: torch.Tensor,drop_prob: float = 0.,training: bool = False) -> torch.Tensor:"""Drop paths (Stochastic Depth) per sample (when applied in main path ofresidual blocks).We follow the implementationhttps://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501"""if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_prob# handle tensors with different dimensions, not just 4D tensors.shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)output = x.div(keep_prob) * random_tensor.floor()return output

实验结果

如图所示,在CIFAR-10和CIFAR-100上,使用stochastic depth在测试机上的误差都更小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/690408.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【day01】每天三道 java后端面试题:JDK、JRE和JVM | 字节码 | ACID

文章目录 1. JDK, JRE, JVM分别是什么?有什么区别?2. 什么是字节码?采用字节码的最大好处是什么?3. 什么是数据库事务?讲一下事务的ACID特性。 1. JDK, JRE, JVM分别是什么?有什么区别? 答题思路…

深度解析 Transformer 模型:原理、应用与实践指南【收藏版】

深度解析 Transformer 模型:原理、应用与实践指南 1. Transformer 模型的背景与引言2. Transformer 模型的原理解析2.1 自注意力机制(Self-Attention)自注意力机制原理 2.2 多头注意力机制(Multi-Head Attention)多头注…

探索单片机应用领域:从智能家居到工业自动化

单片机作为一种微型计算机芯片,在智能家居和工业自动化领域有着广泛的应用。以下将从智能家居和工业自动化两个方面分点论述单片机的应用。 智能家居领域: 1. 智能灯光控制: 单片机可以用于控制智能灯光系统,实现灯光的远程控制…

找出盗窃者

找出盗窃者 题目描述: 某地发生了⼀件盗窃案,警察通过排查确定盗窃者必为4个嫌疑⼈的⼀个。 以下为4个嫌疑人的供词: A说:不是我。 B说:是C。 C说:是D。 D说:C在胡说 已知3个人说了真话,1个人…

java+vue_springboot企业设备安全信息系统14jbc

企业防爆安全信息系统采用B/S架构,数据库是MySQL。网站的搭建与开发采用了先进的java进行编写,使用了vue框架。该系统从三个对象:由管理员、人员和企业来对系统进行设计构建。主要功能包括:个人信息修改,对人员管理&am…

C++ 浮点数二分 数的三次方根

给定一个浮点数 n ,求它的三次方根。 输入格式 共一行,包含一个浮点数 n 。 输出格式 共一行,包含一个浮点数,表示问题的解。 注意,结果保留 6 位小数。 数据范围 −10000≤n≤10000 输入样例: 1000.00…

树与二叉树

树与二叉树 文章目录 树与二叉树一、树的概念及结构1.、树的概念2、树的相关概念1.3 树的表示 二、二叉树1.概念2、特殊的二叉树3、二叉树的性质4、二叉树的存储结构 三、二叉树的顺序结构及实现1、二叉树的顺序结构2、堆的概念及结构3、堆的实现 四、二叉树链式结构的实现1、遍…

Jtti:PHP怎么实现Memcached主从复制自动切换

在 PHP 应用中实现 Memcached 主从复制自动切换通常需要结合一些额外的工具和技术来实现。下面是一种可能的方案: 1. 使用 Memcached 主从复制: 首先,您需要设置 Memcached 主从复制,确保主服务器和从服务器之间同步数据。这可以通…

python统计分析——一元线性回归分析

参考资料:用python动手学统计学 1、导入库 # 导入库 # 用于数值计算的库 import numpy as np import pandas as pd import scipy as sp from scipy import stats # 用于绘图的库 import matplotlib.pyplot as plt import seaborn as sns sns.set() # 用于估计统计…

dayjs实现前端消息通知日期格式显示——仿微信消息时间

背景:在做一个消息通知类的需求,在PC端实现消息接收界面,日期显示参考微信聊天界面消息时间提示。具体规则如下: 当天:显示时分 昨天:显示‘昨天时分’ 本周:显示“周几时分” 本周之前&#xf…

LeetCode 36天 | 435.无重叠区域 763.划分字母区间 56.合并区间

435. 无重叠区间 左边排序&#xff0c;右边裁剪为当前最小的 class Solution { public:// 按照左边界排序static bool cmp(vector<int> a, vector<int> b) {return a[0] < b[0];}int eraseOverlapIntervals(vector<vector<int>>& intervals) {…

JAVA常见IO模型 BIO、NIO、AIO总结

BIO Blocking IO 同步阻塞型IO。当系统进行IO读写的时候&#xff0c;会阻塞&#xff0c;直到IO读写完毕。比如调用系统Read后&#xff0c;需要将内核空间的数据读取到用户空间。需要等待内核空间 数据准备&#xff0c;数据就绪&#xff0c;拷贝数据&#xff0c;线程一直处于阻…

CSS之重绘与回流

重绘&#xff08;Repaint&#xff09; 当页面中元素样式的改变并不影响它在文档流中的位置时&#xff08;例如改变颜色、阴影等&#xff09;&#xff0c;浏览器会进行重绘&#xff0c;即重新绘制元素的外观。 回流&#xff08;Reflow&#xff09; 当元素的大小、位置、隐藏等…

IO进程:fread\fwrite图像拷贝,read\write文件拷贝,时间函数

1.使用fread、fwrite实现图片拷贝 程序代码&#xff1a; 1 #include<myhead.h>2 int main(int argc, const char *argv[])3 {4 //判断传入文件个数5 if(argc!3)6 {7 printf("input file error\n");8 printf("usage:./a.out …

【QCA6174】SDX12+QCA6174驱动屏蔽120/124/128信道修改方案

SDX12基线版本 SDX12.LE.1.0-00215-NBOOT.NEFS.PROD-1.39743.1 问题描述 对于欧洲国家来说,默认支持DFS信道,但是有三个信道比较特殊,是天气雷达信道,如下图所示120、124、128,天气雷达信道有个特点就是在信号可以发射之前需要检测静默15min,如果信道自动选择到了天气雷达…

情感分析入门:使用Python和TextBlob进行情感分析

简介 情感分析是自然语言处理领域的一个重要任务&#xff0c;它涉及分析文本中的情感和情绪&#xff0c;如积极、消极、中性等。TextBlob是一个简单易用的自然语言处理工具库&#xff0c;其中包含了情感分析功能。本文将介绍如何使用Python编程语言和TextBlob库进行情感分析&a…

洪泛法:计算机网络中的信息洪流——原理、优化与应用全景解析

洪泛法 - 概述 洪泛法&#xff08;Flooding&#xff09;是计算机网络中一种简单直接的数据传输技术。它不依赖于网络中的路由表或者路径选择算法。在洪泛法中&#xff0c;每个接收到消息的节点将消息复制并发送给除了消息来源外的所有其他节点。这个过程一直重复&#xff0c;直…

GB/T 29418-2023 塑木复合材料挤出型材性能检测

塑木复合材料是指由木质或其他纤维素基材料和热塑性塑料经配混成型加工制成的复合材料&#xff0c;又称为木塑复合材料&#xff0c;塑木复合材料多用于木塑地板&#xff0c;围栏等产品&#xff0c;用于户外花园&#xff0c;公园等场所。 GB/T 29418-2023 塑木复合材料挤出型材…

微信多开(无需关闭软件)优化

C实现微信多开 原理 解除mutex独占 同时改用新的API&#xff0c;不再使用废弃的windows API 源码 #include <aclapi.h> #include <shlwapi.h> #include <windows.h> #include <iostream> #pragma comment(lib, "Shlwapi.lib")static bo…

BI 数据分析,数据库,Office,可视化,数据仓库

AIGC ChatGPT 职场案例 AI 绘画 与 短视频制作 PowerBI 商业智能 68集 Mysql 8.0 54集 Oracle 21C 142集 Office 2021实战应用 Python 数据分析实战&#xff0c; ETL Informatica 数据仓库案例实战 51集 Excel 2021实操 100集&#xff0c; Excel 2021函数大全 80集 Excel 2021…