Stochastic Depth 原理与代码解析

paper:Deep Networks with Stochastic Depth

official implementation:https://github.com/yueatsprograms/Stochastic_Depth

third-party implementation:https://github.com/open-mmlab/mmcv/blob/main/mmcv/cnn/bricks/drop.py

存在的问题

网络深度是模型表达能力的决定性因素,但是非常深的网络也带来了新的挑战:反向传播中的梯度消失、前向传播中特征重用变少、更长的训练时间。 

本文的创新点

本文提出了一种新的训练深度网络的方法,随机深度stochastic depth,在训练阶段随机删除某些层使得网络的总层数变少,既缓解了梯度消失和特征重用减少的问题,又缩短了训练时间。此外和Dropout类似,stochastic depth还起到了正则化的作用,即使在有BN的情况下。用随机深度训练的网络还可以看作不同深度网络的隐式集和ensemble。

方法介绍

Stochastic depth的目的是在训练过程中减小网络的深度,同时在测试过程中保持其不变。\(b_{\ell}\in\{0,1\}\) 是一个伯努利随机变量,表示第 \(\ell^{th}\) ResBlock是active(\(b_{\ell}=1\))还是inactive(\(b_{\ell}=0\)),更进一步,将ResBlock \(\ell\) 的“存活”概率表示为 \(p_{\ell}=Pr(b_{\ell}=1)\)。\(p_{\ell}\) 是唯一的超参,关于 \(p_{\ell}\) 的设置有两种方式,一种是对所有层 \(\ell\) 设置同一 \(p_{\ell}\)。另一种是采用线性递减的方式,对于输入 \(p_{0}=1\) 线性递减到最后一个ResBlock的 \(p_{L}\),如下

作者通过实验得出结论第二种设置方式效果更好。

代码解析

这里的代码来自MMCV,和论文中的实现有出入,其中shape只保留batch_size的值,其它所有维度的值都为1。torch.rand从均匀分布[0, 1]中采样,与keep_prob相加得到random_tensor后最后再向下取整.floor(),假设drop_prob=0.2即有0.2的概率丢弃该层,则有keep_prob=1-drop_prob=0.8的概率保留该层,与[0, 1]均匀分布相加后有0.8的概率大于1向下取整后为1,0.2的概率小于1向下取整后为0。x.div(keep_prob)和dropout中的操作一样,因为训练时随机丢弃部分连接或层,推理时不丢弃,除以keep_prob是为了保持总体期望不变。

def drop_path(x: torch.Tensor,drop_prob: float = 0.,training: bool = False) -> torch.Tensor:"""Drop paths (Stochastic Depth) per sample (when applied in main path ofresidual blocks).We follow the implementationhttps://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501"""if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_prob# handle tensors with different dimensions, not just 4D tensors.shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)output = x.div(keep_prob) * random_tensor.floor()return output

实验结果

如图所示,在CIFAR-10和CIFAR-100上,使用stochastic depth在测试机上的误差都更小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/690408.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【day01】每天三道 java后端面试题:JDK、JRE和JVM | 字节码 | ACID

文章目录 1. JDK, JRE, JVM分别是什么?有什么区别?2. 什么是字节码?采用字节码的最大好处是什么?3. 什么是数据库事务?讲一下事务的ACID特性。 1. JDK, JRE, JVM分别是什么?有什么区别? 答题思路…

深度解析 Transformer 模型:原理、应用与实践指南【收藏版】

深度解析 Transformer 模型:原理、应用与实践指南 1. Transformer 模型的背景与引言2. Transformer 模型的原理解析2.1 自注意力机制(Self-Attention)自注意力机制原理 2.2 多头注意力机制(Multi-Head Attention)多头注…

java+vue_springboot企业设备安全信息系统14jbc

企业防爆安全信息系统采用B/S架构,数据库是MySQL。网站的搭建与开发采用了先进的java进行编写,使用了vue框架。该系统从三个对象:由管理员、人员和企业来对系统进行设计构建。主要功能包括:个人信息修改,对人员管理&am…

C++ 浮点数二分 数的三次方根

给定一个浮点数 n ,求它的三次方根。 输入格式 共一行,包含一个浮点数 n 。 输出格式 共一行,包含一个浮点数,表示问题的解。 注意,结果保留 6 位小数。 数据范围 −10000≤n≤10000 输入样例: 1000.00…

树与二叉树

树与二叉树 文章目录 树与二叉树一、树的概念及结构1.、树的概念2、树的相关概念1.3 树的表示 二、二叉树1.概念2、特殊的二叉树3、二叉树的性质4、二叉树的存储结构 三、二叉树的顺序结构及实现1、二叉树的顺序结构2、堆的概念及结构3、堆的实现 四、二叉树链式结构的实现1、遍…

python统计分析——一元线性回归分析

参考资料:用python动手学统计学 1、导入库 # 导入库 # 用于数值计算的库 import numpy as np import pandas as pd import scipy as sp from scipy import stats # 用于绘图的库 import matplotlib.pyplot as plt import seaborn as sns sns.set() # 用于估计统计…

JAVA常见IO模型 BIO、NIO、AIO总结

BIO Blocking IO 同步阻塞型IO。当系统进行IO读写的时候,会阻塞,直到IO读写完毕。比如调用系统Read后,需要将内核空间的数据读取到用户空间。需要等待内核空间 数据准备,数据就绪,拷贝数据,线程一直处于阻…

IO进程:fread\fwrite图像拷贝,read\write文件拷贝,时间函数

1.使用fread、fwrite实现图片拷贝 程序代码&#xff1a; 1 #include<myhead.h>2 int main(int argc, const char *argv[])3 {4 //判断传入文件个数5 if(argc!3)6 {7 printf("input file error\n");8 printf("usage:./a.out …

【QCA6174】SDX12+QCA6174驱动屏蔽120/124/128信道修改方案

SDX12基线版本 SDX12.LE.1.0-00215-NBOOT.NEFS.PROD-1.39743.1 问题描述 对于欧洲国家来说,默认支持DFS信道,但是有三个信道比较特殊,是天气雷达信道,如下图所示120、124、128,天气雷达信道有个特点就是在信号可以发射之前需要检测静默15min,如果信道自动选择到了天气雷达…

BI 数据分析,数据库,Office,可视化,数据仓库

AIGC ChatGPT 职场案例 AI 绘画 与 短视频制作 PowerBI 商业智能 68集 Mysql 8.0 54集 Oracle 21C 142集 Office 2021实战应用 Python 数据分析实战&#xff0c; ETL Informatica 数据仓库案例实战 51集 Excel 2021实操 100集&#xff0c; Excel 2021函数大全 80集 Excel 2021…

阿里云服务器镜像选择方法详解,这么选就对了!

阿里云服务器镜像怎么选择&#xff1f;云服务器操作系统镜像分为Linux和Windows两大类&#xff0c;Linux可以选择Alibaba Cloud Linux&#xff0c;Windows可以选择Windows Server 2022数据中心版64位中文版&#xff0c;阿里云服务器网aliyunfuwuqi.com来详细说下阿里云服务器操…

HarmonyOS 鸿蒙应用开发(十一、面向鸿蒙开发的JavaScript基础)

ArkTS 是HarmonyOS&#xff08;鸿蒙操作系统&#xff09;原生应用开发的首选语言。它是用于构建用户界面的一种TypeScript方言&#xff0c;扩展了TypeScript以适应HarmonyOS生态系统的UI开发需求。ArkTS 融合了TypeScript的静态类型系统和现代UI框架的设计理念&#xff0c;为开…

练习题解(关于最短路径)

目录 1.租用游艇 2.邮递员送信 3.【模板】单源最短路径&#xff08;标准版&#xff09; 1.租用游艇 P1359 租用游艇 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 输入数据&#xff1a; 3 5 15 7 因为这道题数据不大&#xff0c;所有我们直接使用Floyd 算法。 这道题大…

Linux系统——防火墙Firewalld

目录 一、firewalld介绍 1.归入zone顺序 2.firewalld zone分类 3.预定义服务 二、图形化操作 1.打开firewalld图形化界面 2.以http服务为例&#xff0c;打开httpd服务 3.修改端口号 三、命令行配置 1.基础配置 2.查看现有firewalld设置 3.设置查看默认区 4.添加源…

类(接口)图几种箭头含义

导语 在平时的开发中&#xff0c;难免会遇到画UML图的时候&#xff0c;也就是我们所说的类图&#xff0c;但是UML图中的箭头多种多样&#xff0c;所代表的含义也是各不相同&#xff0c;今天我们就来说说这几种箭头所代表的含义。 1 泛化 概念&#xff1a;泛化表示一个更泛化的元…

CTFshow web(php命令执行 55-59)

web55 <?php /* # -*- coding: utf-8 -*- # Author: Lazzaro # Date: 2020-09-05 20:49:30 # Last Modified by: h1xa # Last Modified time: 2020-09-07 20:03:51 # email: h1xactfer.com # link: https://ctfer.com */ // 你们在炫技吗&#xff1f; if(isset($_GET[…

Jetpack 之Glance+Compose实现一个小组件

Glance&#xff0c;官方对其解释是使用 Jetpack Compose 样式的 API 构建远程 Surface 的布局&#xff0c;通俗的讲就是使用Compose风格的API来搭建小插件布局&#xff0c;其最新版本是2022年2月23日更新的1.0.0-alpha03。众所周知&#xff0c;Compose样式的API与原生差别不小&…

Android逆向学习(七)绕过root检测与smali修改学习

Android逆向学习&#xff08;七&#xff09;绕过root检测与smali修改学习 一、写在前面 这是吾爱破解正己大大教程的第五个作业&#xff0c;然后我的系统还是ubuntu&#xff0c; 这个是剩下作业的完成步骤。 二、任务目标 现在我们已经解决了一些问题&#xff0c;现在剩下的…

matplotlib图例使用案例1.1:在不同行或列的图例上添加title

我们将图例进行行显示或者列显示后&#xff0c;只能想继续赋予不同行或者列不同的title来进行分类。比较简单的方式&#xff0c;就是通过ax.annotate方法添加标签&#xff0c;这样方法复用率比较低&#xff0c;每次使用都要微调ax.annotate的显示位置。比较方便的方法是在案例1…

基于jieba、TfidfVectorizer、LogisticRegression的垃圾邮件分类,模型平均得分为0.98左右(附代码和数据集)

基于jieba、TfidfVectorizer、LogisticRegression的垃圾邮件分类,模型平均得分为0.98左右(附代码和数据集)。 垃圾邮件分类识别是一种常见的文本分类任务,旨在将收件箱中的邮件分为垃圾邮件和非垃圾邮件。以下是一些常用的技术和方法用于垃圾邮件分类识别: 基于规则的过…