抑制过拟合——Dropout原理

抑制过拟合——Dropout原理

  • Dropout的工作原理
  • 实验观察

  在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。

  过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。

  在这种背景下,Dropout 作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。

Dropout的工作原理

  在 PyTorch 中,Dropout 层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:

torch.nn.Dropout(p=0.5, inplace=False)

  p:这是一个关键参数,代表着每个神经元被丢弃的概率。

  在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。

  在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1p1​(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj​ 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1p r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1p)

为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1p1​,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。

  在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。

在这里插入图片描述

实验观察

  为了更深入地理解 Dropout 的影响,我们可以通过一个实验来观察不同的 Dropout 设置对训练过程的影响。比如,可以比较 Dropout = 0.1Dropout = 0 在训练过程中的表现差异,相关代码实现如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timeclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20),nn.Linear(20, 20),nn.Dropout(0.1),nn.Linear(20, 20),nn.Linear(20, 20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.01
iteration = 1000x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()start_time = time.time()
writer = SummaryWriter(comment='_随机失活')for iter in range(iteration):y_pred = model(x)loss = loss_function(y, y_pred.squeeze())loss.backward()for name, layer in model.named_parameters():writer.add_histogram(name + '_grad', layer.grad, iter)writer.add_histogram(name + '_data', layer, iter)writer.add_scalar('loss', loss, iter)optimizer.step()optimizer.zero_grad()if iter % 50 == 0:print("iter: ", iter)print("Time: ", time.time() - start_time)

这里我们使用 TensorBoardX 进行结果的可视化展示。

  通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout 后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。

在这里插入图片描述

  同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout 通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/185906.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

系统部署安装-Centos7-Cassandra

文章目录 介绍安装在线下载安装启动普通启动注册服务 介绍 Apache Cassandra是一个高度可扩展的高性能分布式数据库,旨在处理许多商用服务器上的大量数据,提供高可用性而没有单点故障。 安装 在线下载 (1)使用weget下载最新的…

mabatis基于xml方式和注解方式实现多表查询

前面步骤 http://t.csdnimg.cn/IPXMY 1、解释 在数据库中,单表的操作是最简单的,但是在实际业务中最少也有十几张表,并且表与表之间常常相互间联系; 一对一、一对多、多对多是表与表之间的常见的关系。 一对一:一张…

cesium不同版本对3dtiles的渲染效果不同,固定光照的优化方案

cesium不同版本对3dtiles的渲染效果不同,固定光照的优化方案,避免map.fixedLight true,导致的光照效果太强,模型太亮的问题。 问题来源: 1.Cesium1.47版本加载tileset.json文件跟Mars3d最新版加载文件存在差异效果 Cesium1.47…

基于springboot的课程作业管理系统

摘 要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,课程作业管理系统当然也不能排除在外。课程作业管理系统是以实际运用为开发背景,运用软件工程原理和开发方法…

WPF绘制进度条(弧形,圆形,异形)

前言 WPF里面圆形进度条实现还比较麻烦,主要涉及到的就是动态绘制进度条的进度需要用到简单的数学算法。其实原理比较简单,我们需要的是话两条重叠的弧线,里面的弧线要比里面的弧线要宽,这样简单的雏形就出来了。 基础写法 我们可以用Path来绘制弧线,代码如下: <Gr…

Android Studio Giraffe版本遇到的问题

背景 上周固态硬盘挂了&#xff0c;恢复数据之后&#xff0c;重新换了新的固态安装了Win11系统&#xff0c;之前安装的是Android Studio 4.x的版本&#xff0c;这次也是趁着新的系统安装新的Android开发工具。 版本如下&#xff1a; 但是打开以前的Android旧项目时&#xff…

Vue3-Eslint配置代码风格

prettier风格配置 官网&#xff1a;https://prettier.io Eslint&#xff1a;代码纠错&#xff0c;关注于规范 prettier&#xff1a;专注于代码格式化的插件&#xff0c;让代码更加美观 两者各有所长&#xff0c;配合使用优化代码 生效前提&#xff1a; 1&#xff09;禁用…

jenkins-cicd基础操作

1.先决条件 1.首先我个人势在k8s集群中创建的jenkins,部署方法搭建 k8s部署jenkins-CSDN博客 2.安装指定插件. 1.Gitlab plugin 用于调用gitlab-api的插件 2.Kubernetes plugin jenkins与k8s进行交互的插件,可以用来自动化的构建和部署 3.Build Authorizatio…

java操作windows系统功能案例(一)

下面是一个Java操作Windows系统功能的简单案例&#xff1a; 获取系统信息&#xff1a; import java.util.Properties;public class SystemInfo {public static void main(String[] args) {Properties properties System.getProperties();properties.list(System.out);} }该程…

Python with提前退出:坑与解决方案

Python with提前退出&#xff1a;坑与解决方案 问题的起源 早些时候使用with实现了一版全局进程锁&#xff0c;希望实现以下效果&#xff1a; Python with提前退出&#xff1a;坑与解决方案 全局进程锁本身不用多说&#xff0c;大部分都依靠外部的缓存来实现的&#xff0c;r…

【模电】基本共射放大电路的组成及各元件的作用

基本共射放大电路的组成及各元件的作用 下图所示为基本共射放大电路&#xff0c;晶体管是起放大作用的核心元件。输入信号 U ˙ i \.{U}\tiny i U˙i为正弦波电压。 当 u i 0 {u\tiny i}0 ui0时&#xff0c;称放大电路处于静态。在输入回路中&#xff0c;基极电源 V B B V\tin…

Re8 Generative Modeling by Estimating Gradients of the Data Distribution

宋扬博士的作品&#xff0c;和DDPM同属扩散模型开创工作&#xff0c;但二者的技术路线不同 Introduction 当前生成模型主要分成两类 基于似然模型 通过近似最大似然直接学习分布的概率密度&#xff0c;如VAE 隐式生成模型 概率分布由其抽样过程的模型隐式表示&#xff0c…

vue3+ts 实现时间间隔选择器

需求背景解决效果视频效果balancedTimeElement.vue 需求背景 实现一个分片的时间间隔选择器&#xff0c;需要把显示时间段显示成图表&#xff0c;涉及一下集中数据转换 [“02:30-05:30”,“07:30-10:30”,“14:30-17:30”]‘[(2,5),(7,10),(14,17)]’[4, 5, 6, 7, 8, 9, 10, …

掌握Python BentoML:构建、部署和管理机器学习模型

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com BentoML是一个开源的Python框架&#xff0c;旨在简化机器学习模型的打包、部署和管理。本文将深入介绍BentoML的功能和用法&#xff0c;提供详细的示例代码和解释&#xff0c;帮助你更好地理解和应用这个强大的工…

volatile-之小总结

凭什么我们Java写了一个volatile关键字&#xff0c;系统底层加入内存屏障&#xff1f;两者的关系如何勾搭&#xff1f; 内存屏障是什么&#xff1f; 是一种屏障指令&#xff0c;它使得CPU或编译器对屏障指令的前和后所发出的内存操作执行一个排序的约 束。也称为内存栅栏或栅…

低价商品采购API接口

采购商品地址http://sly.yizhaosulianyun.com/More/Push/888889?type3 低价商品采购API接口 1) 请求地址 http://sly.yizhaosulianyun.com/jd/keyWords 2) 调用方式&#xff1a;HTTP post 3) 接口描述&#xff1a; 低价商品采购接口 4) 请求参数: POST参数: 字段名称字段…

《Python机器学习原理与算法实现》学习笔记--一文掌握机器学习与Python的基础概念

机器学习常见的基础概念 根据输入数据是否具有“响应变量”信息&#xff0c;机器学习被分为“监督式学习”和“非监督式学习”。“监督式学习”即输入数据中即有X变量&#xff0c;也有y变量&#xff0c;特色在于使用“特征&#xff08;X变量&#xff09;”来预测“响应变量&am…

会泽一村民上山放羊吸烟引发森林火灾,AI科技急需关注

2023年4月&#xff0c;会泽县古城街道厂沟村委会望香台山林中发生了一场由疏忽引发的森林火灾。张某某在放羊时未完全熄灭烟头&#xff0c;导致7.33公顷的林地和草地被焚毁&#xff0c;直接经济损失高达29.097万元。这一事件再次凸显了日常生活中的安全隐患。 在这一背景下&…

GeoServer改造Springboot源码四(图层管理设计)

一、界面设计 图 1图层管理列表 图 2选择图层数据源 图 3添加图层 图 4编辑图层

如何决定产品功能的优先顺序:从 Scrum 过渡到 Shape Up

领导者应该决定要解决的问题的“内容”和“时间”&#xff08;而不是要实施的解决方案&#xff09;。产品团队成员应该可以自由地通过他们只能根据自己的专业知识和知识构思和执行的解决方案来定义“如何”。本文将指导我们从 Scrum 转向Shape Up&#xff0c;立即开始按时交货&…