【深度学习】Dropout、DropPath

一、Dropout

1. 概念

Dropout 在训练阶段会让当前层每个神经元以drop_prob( 0 ≤ drop_prob ≤ 1 0\leq\text{drop\_prob}\leq1 0drop_prob1)的概率失活并停止工作,效果如下图。

测试阶段不会进行Dropout。由于不同批次、不同样本的神经元失活情况不同,测试时枚举所有情况进行推理是不现实的,所以原文使用一种均值近似的方法进行逼近。详情如下图:

如图, w \bold{w} w为一个神经元后的权重。假设该神经元的输出均值为 μ \mu μ,若训练阶段该神经元的存活概率为 p p p,则Dropout使其输出均值变为 p × μ p\times\mu p×μ,为使测试时该神经元输出逼近训练输出,测试阶段该神经元输出会被乘上 p p p以使测试与训练输出均值相同。

简单来说,训练时Dropout按照概率drop_prob使神经元停止工作,测试时所有神经元正常工作,但其输出值要乘上1-drop_prob( p = 1 − drop_prob p=1-\text{drop\_prob} p=1drop_prob)。

不过,我们希望测试代码执行效率尽可能高,即便仅增加一个概率计算也不是我们希望的。所以实际计算时,会在训练阶段给神经元乘上一个缩放因子 1 p \frac{1}{p} p1。这样,训练输出的均值仍为 μ \mu μ,测试则不进行Dropout也不再乘上 p p p而是原样输出。

2. 功能

优势:
Dropout能够提高网络的泛化能力,防止过拟合。解释如下:
(1) 训练阶段每个神经元是相互独立的,仅drop_prob相同,即使是同一批次不同样本失活的神经元也是不同的。所以原文作者将Dropout的操作视为多种模型结构下结果的集成,由于集成方法能够避免过拟合,因此Dropout也能达到同样的效果。
(2) 减少神经元之间的协同性。有些神经元可能会建立与其它节点的固定联系,通过Dropout强迫神经元和随机挑选出来的其它神经元共同工作,减弱了神经元节点间的联合适应性,增强了泛化能力。
劣势:
(1) Dropout减缓了收敛的速度。训练时需要通过伯努利分布生成是否drop每一个神经元的情况,额外的乘法和缩放运算也会增加时间。
(2) Dropout一般用于全连接层,卷积层一般使用BatchNorm来防止过拟合。Dropout与BatchNorm不易兼容,Dropout导致训练过程中每一层输出的方差发生偏移,使得BatchNorm层统计的方差不准确,影响BatchNorm的正常使用。

3. 实现

import torch.nn as nn
import torchclass dropout(nn.Module):def __init__(self, drop_prob):super(dropout, self).__init__()assert 0 <= drop_prob <= 1, 'drop_prob should be [0, 1]'self.drop_prob = drop_probdef forward(self, x):if self.training:keep_prob = 1 - self.drop_probmask = keep_prob + torch.rand(x.shape)mask.floor_()return x.div(keep_prob) * maskelse:return xif __name__ == '__main__':x = torch.randn((8, 768))  # [batch_size, feat_dim],dropout常在全连接层之后,所以我们以一维数据为例drop = dropout(0.1)my_o = drop(x)

二、DropPath

1. 概念

DropPath 在训练阶段将深度学习网络中的多分支结构随机删除,效果如下图:

上图是ViT中的一个模块,多分支体现在ResNet结构的引入。可以看出,DropPath在多分支中起作用对位置有明确的要求,需要放在分支合并之前。此外,DropPath也需要对训练输出进行缩放(乘 1 1 − drop_prob \frac{1}{1-\text{drop\_prob}} 1drop_prob1)以确保测试输出结果的有效性和计算的高效性,这样在测试阶段就不会进行DropPath。

事实上,DropPath功能的实现是按照drop_prob概率将该分支的当前输出全部置0。具体来说,对于某个含有DropPath的分支,该分支输出的一个批次的每个样本都独立的按照drop_prob概率被完全置0或完整保留。

2. 功能

一般可以作为正则化手段加入网络防止过拟合,但会增加网络训练的难度。如果设置的drop_prob过高,模型甚至有可能不收敛。

3. 实现

import torch
import torch.nn as nnclass DropPath(nn.Module):"""随机丢弃该分支上的每个样本"""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):if self.drop_prob == 0. or not self.training:return xkeep_prob = 1 - self.drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1)  # (batch_size, 1, 1, 1)维数与输入保持一致,仅需要batch_size个值mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)mask.floor_()  # 二值化,向下取整用于确定保存哪些样本output = x.div(keep_prob) * maskreturn outputif __name__ == "__main__":x = torch.randn((8, 197, 768))  # [batch_size, num_token, token_dim]drop_path = DropPath(drop_prob=0.5)my_o = drop_path(x)

致谢:

本博客仅做记录使用,无任何商业用途,参考内容如下:
【个人理解向】Dropout和Droppath原理及源码讲解
nn.Dropout、DropPath的理解与pytorch代码
Drop系列正则化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/1530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库管理-第171期 Oracle是用这种方式确保读一致的(20240418)

数据库管理171期 2024-04-18 数据库管理-第171期 Oracle是用这种方式确保读一致的&#xff08;20240418&#xff09;1 基本概念2 用处3 注意事项总结 数据库管理-第171期 Oracle是用这种方式确保读一致的&#xff08;20240418&#xff09; 作者&#xff1a;胖头鱼的鱼缸&#x…

MySQL中explain的用法

执行结果各字段的含义 EXPLAIN SQL语句 如&#xff1a; EXPLAIN SELECT * FROM test 执行结果&#xff1a; 列名描述id在一个大的查询语句中每个SELECT关键字都对应一个 唯一的idselect_typeSELECT关键字对应的那个查询的类型table表名partitions匹配的分区信息type针对单表…

P2P面试题

1&#xff09;描述一下你的项目流程以及你在项目中的职责&#xff1f; 一个借款产品的发布&#xff0c;投资人购买&#xff0c;借款人还款的一个业务流程&#xff0c;我主要负责测注册&#xff0c;登录&#xff0c;投资理财这三个模块 2&#xff09;你是怎么测试投资模块的&am…

HttpServlet,ServletContext,Listener它仨的故事

1.HttpServlet。 听起来是不是感觉像是个上古词汇&#xff0c;是不是没有阅读下去的兴趣了&#xff1f;Tomcat知道吧&#xff0c;它就是一个servlet容器&#xff0c;当用户向服务器发送一个HTTP请求时&#xff0c;Servlet容器&#xff08;如Tomcat&#xff09;会根据其配置找到…

overflow(溢出)4个属性值,水平/垂直溢出,文字超出显示省略号的详解

你好&#xff0c;我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生&#xff0c;一枚程序媛&#xff0c;感谢关注。回复 “前端基础题”&#xff0c;可免费获得前端基础 100 题汇总&#xff0c;回复 “前端工具”&#xff0c;可获取 Web 开发工具合…

解析 IP(IPv4)地址

IPv 4 地址 一、组成二、IPv4 的分类三、子网掩码四、特殊的地址五、私有 IP 地址六、全局 IP 地址七、私有 IP 地址和全局 IP 地址的关系八、广播地址九、网络地址十、IP 地址个数计算十一、查看电脑的 IP 地址&#xff08;window&#xff09;十二、手动设置电脑的 IP 地址 为…

华为Pura 70系列,一种关于世界之美的可能

1874年&#xff0c;莫奈创作了《印象日出》的油画&#xff0c;在艺术界掀起了一场革命。当时的主流艺术&#xff0c;是追求细节写实&#xff0c;追求场面宏大的学院派。他们称莫奈等人是“印象派”&#xff0c;认为莫奈的画追求光影表达&#xff0c;追求描绘抽象的意境&#xf…

DRF: 序列化器、View、APIView、GenericAPIView、Mixin、ViewSet、ModelViewSet的源码解析

前言&#xff1a;还没有整理&#xff0c;后续有时间再整理&#xff0c;目前只是个人思路&#xff0c;文章较乱。 注意路径匹配的“/” 我们的url里面加了“/”&#xff0c;但是用apifox等非浏览器的工具发起请求时没有加“/”&#xff0c;而且还不是get请求&#xff0c;那么这…

天才简史——Sylvain Calinon

一、研究方向 learning from demonstration&#xff08;LfD&#xff09;领域的专家&#xff0c;机器人红宝书&#xff08;Springer handbook of robotics&#xff09;Robot programming by demonstration章节的合作者。主要研究兴趣包括&#xff1a; 机器人学习、最优控制、几…

[数据结构]——排序——插入排序

目录 ​编辑 1 .插入排序 1.基本思想&#xff1a; 2.直接插入排序&#xff1a; ​编辑 1.代码实现 2.直接插入排序的特性总结&#xff1a; 3.希尔排序( 缩小增量排序 ) 1.预排序 2.预排序代码 3.希尔排序代码 4.希尔排序的特性总结&#xff1a; 1 .插入排序 1.基本思…

从头开始构建自己的 GPT 大型语言模型

图片来源&#xff1a; Tatev Aslanyan 一、说明 我们将使用 PyTorch 从头开始构建生成式 AI、大型语言模型——包括嵌入、位置编码、多头自注意、残差连接、层归一化&#xff0c;Baby GPT 是一个探索性项目&#xff0c;旨在逐步构建类似 GPT 的语言模型。在这个项目中&#xff…

Linux 文件描述符

1、文件描述符 程序和进程的区别&#xff1a; 1、test.c&#xff1a;是一个程序&#xff0c;只占用磁盘空间&#xff0c;不占用内存空间 2、可执行文件 test&#xff1a;是一个程序&#xff0c;只占用磁盘空间&#xff0c;不占用内存空间 3、启动 可执行文件 test&#xff…

强固型工业电脑在码头智能化,龙门吊/流机车载电脑的行业应用

码头智能化行业应用 对码头运营来说&#xff0c;如何优化集装箱从船上到码头堆场到出厂区的各个流程以及达到提高效率。 降低成本的目的&#xff0c;是码头营运获利最重要的议题。为了让集装箱码头客户能够安心使用TOS系统来调度指挥码头上各种吊车、叉车、拖车和人员&#xf…

第一届 _帕鲁杯_ - CTF挑战赛

Mis 签到 题目附件&#xff1a; 27880 30693 25915 21892 38450 23454 39564 23460 21457 36865 112 108 98 99 116 102 33719 21462 21069 27573 102 108 97 103 20851 27880 79 110 101 45 70 111 120 23433 20840 22242 38431 22238 22797 112 108 98 99 116 102 33719 2…

matplotlib从起点出发(15)_Tutorial_15_blitting

0 位图传输技术与快速渲染 Blitting&#xff0c;即位图传输、块传输技术是栅格图形化中的标准技术。在Matplotlib的上下文中&#xff0c;该技术可用于&#xff08;大幅度&#xff09;提高交互式图形的性能。例如&#xff0c;动画和小部件模块在内部使用位图传输。在这里&#…

揭开ChatGPT面纱(3):使用OpenAI进行文本情感分析(embeddings接口)

文章目录 一、embeddings接口解析二、代码实现1.数据集dataset.csv2.代码3.运行结果 openai版本1.6.1 本系列博客源码仓库&#xff1a;gitlab&#xff0c;本博客对应文件夹03 在这一篇博客中我将使用OpenAI的embeddings接口判断21条服装评价是否是好评。 首先来看实现思路&am…

Llama3新一代 Llama模型

最近&#xff0c;Meta 发布了 Llama3 模型&#xff0c;从发布的数据来看&#xff0c;性能已经超越了 Gemini 1.5 和 Claud 3。 Llama 官网说&#xff0c;他们未来是要支持多语言和多模态的&#xff0c;希望那天赶紧到来。 未来 Llama3还将推出一个 400B大模型&#xff0c;目前…

计算机网络——数据链路层(介质访问控制)

计算机网络——数据链路层&#xff08;介质访问控制&#xff09; 介质访问控制静态划分信道动态划分信道ALOHA协议纯ALOHA&#xff08;Pure ALOHA&#xff09;原理特点 分槽ALOHA&#xff08;Slotted ALOHA&#xff09;原理特点 CSMA协议工作流程特点 CSMA-CD 协议工作原理主要…

JVM虚拟机(十二)ParallelGC、CMS、G1垃圾收集器的 GC 日志解析

目录 一、如何开启 GC 日志&#xff1f;二、GC 日志分析2.1 PSPO 日志分析2.2 ParNewCMS 日志分析2.3 G1 日志分析 三、GC 发生的原因3.1 Allocation Failure&#xff1a;新生代空间不足&#xff0c;触发 Minor GC3.2 Metadata GC Threshold&#xff1a;元数据&#xff08;方法…

【数据结构|C语言版】算法效率和复杂度分析

前言1. 算法效率2. 大O的渐进表示法3. 时间复杂度3.1 时间复杂度概念3.2 时间复杂度计算举例 4. 空间复杂度4.1 空间复杂度的概念4.2 空间复杂度计算举例 5. 常见复杂度对比结语 ↓ 个人主页&#xff1a;C_GUIQU 个人专栏&#xff1a;【数据结构&#xff08;C语言版&#xff09…