抑制过拟合——Dropout原理

抑制过拟合——Dropout原理

  • Dropout的工作原理
  • 实验观察

  在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。

  过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。

  在这种背景下,Dropout 作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。

Dropout的工作原理

  在 PyTorch 中,Dropout 层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:

torch.nn.Dropout(p=0.5, inplace=False)

  p:这是一个关键参数,代表着每个神经元被丢弃的概率。

  在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。

  在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1p1​(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj​ 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1p r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1p)

为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1p1​,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。

  在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。

在这里插入图片描述

实验观察

  为了更深入地理解 Dropout 的影响,我们可以通过一个实验来观察不同的 Dropout 设置对训练过程的影响。比如,可以比较 Dropout = 0.1Dropout = 0 在训练过程中的表现差异,相关代码实现如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timeclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20),nn.Linear(20, 20),nn.Dropout(0.1),nn.Linear(20, 20),nn.Linear(20, 20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.01
iteration = 1000x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()start_time = time.time()
writer = SummaryWriter(comment='_随机失活')for iter in range(iteration):y_pred = model(x)loss = loss_function(y, y_pred.squeeze())loss.backward()for name, layer in model.named_parameters():writer.add_histogram(name + '_grad', layer.grad, iter)writer.add_histogram(name + '_data', layer, iter)writer.add_scalar('loss', loss, iter)optimizer.step()optimizer.zero_grad()if iter % 50 == 0:print("iter: ", iter)print("Time: ", time.time() - start_time)

这里我们使用 TensorBoardX 进行结果的可视化展示。

  通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout 后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。

在这里插入图片描述

  同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout 通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/185906.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何提高谷歌上架成功率?代码混淆加固、AB面、账号关联

众所周知,Google play商店在全球是极具影响力的应用市场之一,随着市场和科学技术的发展,竞争越来越激烈,谷歌的政策也越来越严格。 为了维持良好的竞争环境和用户体验,谷歌不断更新政策和规则,同时加强对部…

系统部署安装-Centos7-Cassandra

文章目录 介绍安装在线下载安装启动普通启动注册服务 介绍 Apache Cassandra是一个高度可扩展的高性能分布式数据库,旨在处理许多商用服务器上的大量数据,提供高可用性而没有单点故障。 安装 在线下载 (1)使用weget下载最新的…

mabatis基于xml方式和注解方式实现多表查询

前面步骤 http://t.csdnimg.cn/IPXMY 1、解释 在数据库中,单表的操作是最简单的,但是在实际业务中最少也有十几张表,并且表与表之间常常相互间联系; 一对一、一对多、多对多是表与表之间的常见的关系。 一对一:一张…

cesium不同版本对3dtiles的渲染效果不同,固定光照的优化方案

cesium不同版本对3dtiles的渲染效果不同,固定光照的优化方案,避免map.fixedLight true,导致的光照效果太强,模型太亮的问题。 问题来源: 1.Cesium1.47版本加载tileset.json文件跟Mars3d最新版加载文件存在差异效果 Cesium1.47…

基于springboot的课程作业管理系统

摘 要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,课程作业管理系统当然也不能排除在外。课程作业管理系统是以实际运用为开发背景,运用软件工程原理和开发方法…

4.1字符类型统计器(C语言实现)

【题目描述】请编写一个C程序&#xff0c;在终端用键盘输入字符串&#xff0c;以CtrlZ组合键表示输入完毕&#xff0c;统计输入的字符串中空格符、制表符、换行符的个数&#xff0c;并显示统计的结果。 【代码实现】 # include <stdio.h> int main() {int space 0, tab…

1022. 宠物小精灵之收服,二维花费的背包

1022. 宠物小精灵之收服 - AcWing题库 宠物小精灵是一部讲述小智和他的搭档皮卡丘一起冒险的故事。 一天&#xff0c;小智和皮卡丘来到了小精灵狩猎场&#xff0c;里面有很多珍贵的野生宠物小精灵。 小智也想收服其中的一些小精灵。 然而&#xff0c;野生的小精灵并不那么容…

WPF绘制进度条(弧形,圆形,异形)

前言 WPF里面圆形进度条实现还比较麻烦,主要涉及到的就是动态绘制进度条的进度需要用到简单的数学算法。其实原理比较简单,我们需要的是话两条重叠的弧线,里面的弧线要比里面的弧线要宽,这样简单的雏形就出来了。 基础写法 我们可以用Path来绘制弧线,代码如下: <Gr…

Android Studio Giraffe版本遇到的问题

背景 上周固态硬盘挂了&#xff0c;恢复数据之后&#xff0c;重新换了新的固态安装了Win11系统&#xff0c;之前安装的是Android Studio 4.x的版本&#xff0c;这次也是趁着新的系统安装新的Android开发工具。 版本如下&#xff1a; 但是打开以前的Android旧项目时&#xff…

Vue3-Eslint配置代码风格

prettier风格配置 官网&#xff1a;https://prettier.io Eslint&#xff1a;代码纠错&#xff0c;关注于规范 prettier&#xff1a;专注于代码格式化的插件&#xff0c;让代码更加美观 两者各有所长&#xff0c;配合使用优化代码 生效前提&#xff1a; 1&#xff09;禁用…

jenkins-cicd基础操作

1.先决条件 1.首先我个人势在k8s集群中创建的jenkins,部署方法搭建 k8s部署jenkins-CSDN博客 2.安装指定插件. 1.Gitlab plugin 用于调用gitlab-api的插件 2.Kubernetes plugin jenkins与k8s进行交互的插件,可以用来自动化的构建和部署 3.Build Authorizatio…

查看linux处理器架构(uname命令 使用指南)

一、查看linux处理器架构 在linux系统终端下输入uname -m(在windows下可通过git Bash输入uname -m命令) 可得输出结果与架构对应表 架构 输出结果 i386 i386, i686 amd64 x86_64 arm arm, armv7l arm64 aarch64, armv8l mips mips mips64 mips64 等等等 alpha, …

java操作windows系统功能案例(一)

下面是一个Java操作Windows系统功能的简单案例&#xff1a; 获取系统信息&#xff1a; import java.util.Properties;public class SystemInfo {public static void main(String[] args) {Properties properties System.getProperties();properties.list(System.out);} }该程…

Python with提前退出:坑与解决方案

Python with提前退出&#xff1a;坑与解决方案 问题的起源 早些时候使用with实现了一版全局进程锁&#xff0c;希望实现以下效果&#xff1a; Python with提前退出&#xff1a;坑与解决方案 全局进程锁本身不用多说&#xff0c;大部分都依靠外部的缓存来实现的&#xff0c;r…

C++ 类和对象-封装的意义

前沿 C面向对象的三大特性为&#xff1a;封装、继承、多态。 封装的意义&#xff1a; 将属性和行为作为一个整体&#xff0c;表现生活中的事物将属性和行为加以权限控制 封装意义一&#xff1a; 在设计类的时候&#xff0c;属性和行为写在一起&#xff0c;表现事物。 语法&…

python 生成器的作用

1. 生成器 参考&#xff1a; https://www.cainiaojc.com/python/python-generator.html 1.1. 什么是生成器&#xff1f; 在 python 中&#xff0c;一边循环一边计算的机制&#xff0c;称为生成器&#xff1a;generator. 1.2. 生成器有什么优点&#xff1f; 1、节约内存。p…

PC端企业微信hook协议开发,获取要群发的客户群id

产品说明 一、 hook版本&#xff1a;企业微信hook接口是指将企业微信的功能封装成dll&#xff0c;并提供简易的接口给程序调用。通过hook技术&#xff0c;可以在不修改企业微信客户端源代码的情况下&#xff0c;实现对企业微信客户端的功能进行扩展和定制化。企业微信hook接口…

【模电】基本共射放大电路的组成及各元件的作用

基本共射放大电路的组成及各元件的作用 下图所示为基本共射放大电路&#xff0c;晶体管是起放大作用的核心元件。输入信号 U ˙ i \.{U}\tiny i U˙i为正弦波电压。 当 u i 0 {u\tiny i}0 ui0时&#xff0c;称放大电路处于静态。在输入回路中&#xff0c;基极电源 V B B V\tin…

Re8 Generative Modeling by Estimating Gradients of the Data Distribution

宋扬博士的作品&#xff0c;和DDPM同属扩散模型开创工作&#xff0c;但二者的技术路线不同 Introduction 当前生成模型主要分成两类 基于似然模型 通过近似最大似然直接学习分布的概率密度&#xff0c;如VAE 隐式生成模型 概率分布由其抽样过程的模型隐式表示&#xff0c…

vue3+ts 实现时间间隔选择器

需求背景解决效果视频效果balancedTimeElement.vue 需求背景 实现一个分片的时间间隔选择器&#xff0c;需要把显示时间段显示成图表&#xff0c;涉及一下集中数据转换 [“02:30-05:30”,“07:30-10:30”,“14:30-17:30”]‘[(2,5),(7,10),(14,17)]’[4, 5, 6, 7, 8, 9, 10, …