暂退法(丢弃法)

       在深度学习中,丢弃法(Dropout)是一种常用的正则化技术,旨在减少模型的过拟合现象,可能会比之前的权重衰减(Weight Decay)效果更好。通过在训练过程中随机丢弃一部分神经元,可以有效地减少神经网络中的参数依赖性,增强模型的泛化能力。

一、丢弃法原理介绍

1、动机

       一个好的模型需要对输入数据的扰动鲁棒,也就是说,不管图片加入多少噪音,我也是能看清楚的。使用有噪音的数据等价于Tikhonov正则,正则使得权重值范围不会太大,避免一定的过拟合。与之前加入的噪音不一样,之前是固定噪音,丢弃法是随机噪音,丢弃法不是在输入加噪音,而是在层之间加入噪音,所以丢弃法也算是一个正则。

2、无偏差的加入噪音

       假如$x$是上一层到下一层的某一个输出(上一层输出向量的某一个元素)的话,对$x$加入噪音得到$x'$,我们希望加入噪音后不改变期望,即:

$ E\left[ x' \right] =x $

       丢弃法对上一层输出向量的每一个元素做如下扰动:

       此时这个元素的期望是不变的:

$ E\left[ x' \right] =p\cdot 0+\left( 1-p \right) \frac{x}{1-p}=x $

3、丢弃法的使用

       通常将丢弃法作用在隐藏全连接层的输出上。如图"MLP with one hidden layer"带有1个隐藏层和5个隐藏单元的多层感知机。当我们将暂退法应用到隐藏层,以$p$的概率将隐藏单元置为零时,结果可以看作一个只包含原始神经元子集的网络。比如在图"Hidden layer after dropout"中,删除了$h_2$$h_5$,因此输出的计算不再依赖于$h_2$$h_5$,并且它们各自的梯度在执行反向传播时也会消失。这样,输出层的计算不能过度依赖于$h_1, \ldots, h_5$的任何一个元素。

4、推理中的丢弃法

5、总结

  • 丢弃法将一些输出项随机置0来控制模型复杂度
  • 常作用在多层感知机的隐藏层输出上
  • 丢弃概率是控制模型复杂度的超参数

二、暂退法从零开始实现

1、定义dropout函数

       要实现单层的暂退法函数,我们从均匀分布$U[0, 1]$中抽取样本,样本数与这层神经网络的维度一致。然后我们保留那些对应样本大于$p$的节点,把剩下的丢弃。

       在下面的代码中,我们实现 `dropout_layer` 函数,该函数以`dropout`的概率丢弃张量输入`X`中的元素,如上所述重新缩放剩余部分:将剩余部分除以`1.0-dropout`。

import torch
from torch import nn
from d2l import torch as d2ldef dropout_layer(X, dropout):assert 0 <= dropout <= 1# 在本情况中,所有元素都被丢弃if dropout == 1:return torch.zeros_like(X)# 在本情况中,所有元素都被保留if dropout == 0:return X# torch.rand(X.shape)生成了一个与输入张量X相同形状的随机数张量,其中的元素值在[0, 1)的区间内均匀分布。# (torch.rand(X.shape) > dropout)执行了一个逻辑判断,将随机数张量中大于dropout的元素置为True,小于等于dropout的元素置为False。# .float()将布尔型张量转换为浮点型张量,将True转换为1.0,将False转换为0.0。mask = (torch.rand(X.shape) > dropout).float()return mask * X / (1.0 - dropout)

       我们可以通过下面几个例子来测试`dropout_layer`函数。我们将输入`X`通过暂退法操作,暂退概率分别为0、0.5和1。

X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  2.,  0.,  6.,  0.,  0.,  0., 14.],[16., 18.,  0., 22.,  0., 26., 28., 30.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]])

2、定义模型参数

       同样,我们使用Softmax回归中引入的Fashion-MNIST数据集(不懂的可以看链接里面的文章)。我们定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

3、定义模型

       我们可以将暂退法应用于每个隐藏层的输出(在激活函数之后),并且可以为每一层分别设置暂退概率:常见的技巧是在靠近输入层的地方设置较低的暂退概率。下面的模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5,并且暂退法只在训练期间有效。

dropout1, dropout2 = 0.2, 0.5    # 在靠近输入层的地方设置较低的暂退概率,因此dropout1设为0.2class Net(nn.Module):def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training = True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)self.lin3 = nn.Linear(num_hiddens2, num_outputs)self.relu = nn.ReLU()def forward(self, X):H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 只有在训练模型时才使用dropoutif self.training == True:# 在第一个全连接层之后添加一个dropout层H1 = dropout_layer(H1, dropout1)H2 = self.relu(self.lin2(H1))if self.training == True:# 在第二个全连接层之后添加一个dropout层H2 = dropout_layer(H2, dropout2)out = self.lin3(H2)return outnet = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

4、训练和测试

       这类似于前面描述的多层感知机训练和测试。

num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

三、暂退法简洁实现

1、定义模型

       对于深度学习框架的高级API,我们只需在每个全连接层之后添加一个`Dropout`层,将暂退概率作为唯一的参数传递给它的构造函数。在训练时,`Dropout`层将根据指定的暂退概率随机丢弃上一层的输出(相当于下一层的输入)。在测试时,`Dropout`层仅传递数据。

net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

2、训练和测试

       接下来,我们对模型进行训练和测试。

trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

四、总结

  • 暂退法在前向传播过程中,计算每一内部层的同时丢弃一些神经元。
  • 暂退法可以避免过拟合,它通常与控制权重向量的维数和大小结合使用的。
  • 暂退法将活性值$h$替换为具有期望值$h$的随机变量。
  • 暂退法仅在训练期间使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/228226.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python实验项目9 :网络爬虫与自动化

实验 1&#xff1a;爬取网页中的数据。 要求&#xff1a;使用 urllib 库和 requests 库分别爬取 http://www.sohu.com 首页的前 360 个字节的数据。 # 要求&#xff1a;使用 urllib 库和 requests 库分别爬取 http://www.sohu.com 首页的前 360 个字节的数据。 import urllib.r…

微服务最佳实践:构建可扩展且高效的系统

微服务架构彻底改变了现代软件开发&#xff0c;提供了无与伦比的敏捷性、可扩展性和可维护性。然而&#xff0c;有效实施微服务需要深入了解最佳实践&#xff0c;以充分发挥微服务的潜力&#xff0c;同时避免常见的陷阱。在这份综合指南中&#xff0c;我们将深入研究微服务的关…

跟iPhone类似,不同品牌的手机、电脑随时使用“隔空投送”功能!如何开启?

iPhone的隔空投送是一个很受欢迎的功能。打开一个 App&#xff0c;然后轻点“共享”或“共享”按钮&#xff0c;再点击隔空投送&#xff0c;就可以分享图片、视频、文件出去。 然而&#xff0c;如果你用的不是苹果的产品&#xff0c;iPhone的隔空投送功能就有了“隔阂”。 不过…

CountDownLatch实战应用——实现异步多线程业务处理,异常情况回滚全部子线程

&#x1f60a; 作者&#xff1a; 一恍过去 &#x1f496; 主页&#xff1a; https://blog.csdn.net/zhuocailing3390 &#x1f38a; 社区&#xff1a; Java技术栈交流 &#x1f389; 主题&#xff1a; CountDownLatch实战应用——实现异步多线程业务处理&#xff0c;异常情…

H266/VVC编码标准介绍

视频编码标准 多样的视频应用催生了多种的视频编码方法。为了使编码后的码流能够在大范围内通用和规范&#xff0c;从20世纪80年代开始&#xff0c;国际组织就开始对视频编码建立国际标准。 什么是视频编码标准&#xff1a; 视频编码标准只规定了码流的语法语义和解码器&#…

Appium —— 初识移动APP自动化测试框架Appium

说到移动APP自动化测试&#xff0c;代表性的测试框架非Appium莫属&#xff0c;从今天开始我们将从APP结构解析、Appium框架学习、安卓/iOS自动化测试实战、自动遍历回归测试、自动化测试平台及持续集成&#xff0c;多个维度一起由浅入深的学废Appium 今天我们先来初步认识Appi…

【消息中间件】Rabbitmq的基本要素、生产和消费、发布和订阅

原文作者&#xff1a;我辈李想 版权声明&#xff1a;文章原创&#xff0c;转载时请务必加上原文超链接、作者信息和本声明。 文章目录 前言一、消息队列的基本要素1.队列:queue2.交换机:exchange3.事件:routing_key4.任务:task 二、生产消费模式1.安装pika2.模拟生产者进程3.模…

【HR培训】行为反馈复盘,走出舒适区--20231217

行为反馈复盘&#xff0c;走出舒适区–鱼缸会议 要点&#xff1a;在于建立平等、透明、敢说的反馈环境&#xff0c;不打断、不争论 鱼缸会议流程 导入——入缸——反馈——承诺——关闭 步骤1&#xff1a;导入 目的&#xff1a;平等、透明、敢说的反馈 人员&#xff1a;主…

maui中实现加载更多 RefreshView跟ListView(1)

效果如图&#xff1a; MainPage.xaml.cs: using System; using System.Collections.ObjectModel; using System.Threading.Tasks; using Microsoft.Maui.Controls; using Microsoft.Maui.Controls.Xaml; using System.ComponentModel; using System.Runtime.CompilerServices…

计算机网络基础——网线认识与制作,线缆类型、线序、端接标准及注意事项

一、引言 网线制作是网络基础知识中不可或缺的。网络传输过程中&#xff0c;网线的质量和制作方法都会直接影响传输的速度和稳定性。本文将详细介绍网线制作的基础知识、线缆类型、线序、端接标准及注意事项。希望通过本文&#xff0c;读者能够更好地了解和掌握网线制作的方法…

AMD 自适应和嵌入式产品技术日

概要 时间&#xff1a;2023年11月28日 地点&#xff1a;北京朝阳新云南皇冠假日酒店 主题内容&#xff1a;AMD自适应和嵌入式产品的更新&#xff0c;跨越 云、边、端的AI解决方案&#xff0c;赋能智能制造的机器视觉与机器人等热门话题。 注&#xff1a;本文重点关注FPGA&a…

ASP.NET MVC实战之权限拦截Authorize使用

1&#xff0c;具体的实现方法代码如下 public class CustomAuthorizeAttribute : FilterAttribute, IAuthorizationFilter{/// <summary>/// 如果需要验证权限的时候&#xff0c;就执行进来/// </summary>/// <param name"filterContext"></par…

Ubuntu系统入门指南:基础操作和使用

Ubuntu系统的基础操作和使用 一、引言二、安装Ubuntu系统三、Ubuntu系统的基础操作3.1、界面介绍3.2、应用程序的安装和卸载3.3、文件管理3.4、系统设置 四、Ubuntu系统的日常使用4.1、使用软件中心4.2、浏览器的使用和网络连接设置4.3、邮件客户端的配置和使用4.4、文件备份和…

HTML5+CSS3小实例:3D发光切换按钮效果

目录 一、运行效果 图片效果 二、项目概述 三、开发环境 四、实现步骤及代码 1.创建空文件夹 2.完成页面内容 3.完成css样式 五、项目总结 六、源码获取 一、运行效果 图片效果 二、项目概述 这个项目是一个演示3D发光切换按钮效果的网页。按钮由一个开关和一个指…

Linux之进程(四)(进程地址空间)

目录 一、程序地址空间 二、进程地址空间 1、概念 2、写时拷贝 3、为什么要有进程地址空间 四、总结 一、程序地址空间 我们先来看看下面这张图。这张图是我们在学习语言时就见到过的内存区域划分图。 下面我们在Linux下看一看内存区域是不是也是这么划分的。 可见在Li…

圣诞树绘制合集-python绘制

使用Python绘制迷人的圣诞树 引言 随着圣诞节的临近&#xff0c;我们都希望以各种方式庆祝这个欢乐的节日。作为一名编程爱好者&#xff0c;你有没有想过用Python来创造节日的气氛呢&#xff1f;在这篇文章中&#xff0c;我将向你展示如何用Python绘制几种不同风格的圣诞树&a…

索尼(ILCE-7M3)MP4文件只能播放前两分钟修复案例

索尼的ILCE-7M3是一款经典设备&#xff0c;其HEVC编码效果是比较不错的&#xff0c;因此受到很多专业人士的青睐。之前我们说过很多索尼摄像机断电生成RSV文件修复的案例&#xff0c;今天来讲一个特殊的&#xff0c;文件已经正常封装但仅能播放前两分钟多一点的画面。 故障文件…

详细教程 - 从零开发 鸿蒙harmonyOS应用 第四节 (鸿蒙Stage模型 登录页面 ArkTS版 推荐使用)

在鸿蒙OS中&#xff0c;Ability是应用程序提供的抽象功能&#xff0c;可以理解为一种功能。在应用程序中&#xff0c;一个页面即一种能力&#xff0c;如登录页面&#xff0c;即具有登录功能的能力。以下是对鸿蒙新建项目的登录代码功能的详细解读和工作流程的描述&#xff1a; …

C++入门篇

呀哈喽&#xff0c;我是结衣。 了解完C的发展历程&#xff0c;我们当然也要会用C啊。今天这篇博客就是来帮助我们来入门C的&#xff0c;当然要入门C当然也要先学会C语言啦。在我学习C的过程中我会一直把C博客更新下去的。 C关键字 我们都知道C语言是有32个关键字的&#xff0…

json JSON.parse()与JSON.stringify()

JSON.parse() 属于解析 JSON.parse()方法解析一个JSON字符串为ECMAScript值&#xff0c;返回解析后的值&#xff0c; JSON.parse({}); // -> {}JSON.parse([]); // -> []JSON.parse(1); // -> {}注意&#xff1a;JSON.parse()解析的JSON字符串不允许以逗…