pytorch学习——正则化技术——丢弃法(dropout)

一、概念介绍

        在多层感知机(MLP)中,丢弃法(Dropout)是一种常用的正则化技术,旨在防止过拟合。(效果一般比前面的权重衰退好)

        在丢弃法中,随机选择一部分神经元并将其输出清零,被清零的神经元在该轮训练中不会被激活。这样,其他神经元就需要学习代替这些神经元的功能,从而促进了神经元之间的独立性和鲁棒性。

1.1思想原理

        丢弃法的基本思想是,在每一次训练中,随机选择一些神经元不参与训练,从而减少神经元之间的相互依赖关系,使得模型对于训练数据的过拟合程度降低。这样在测试时,所有神经元都参与,可以取得更好的泛化性能。

        丢弃法可以被应用到多层感知机的任意层中,包括输入层和输出层。在实际应用中,通常会在每一层都添加丢弃法,以充分发挥其正则化作用。

丢弃法特性:在层之间加入噪声,而不是在数据输入时加入。

 对Xi中的元素,以p概率变成0,1-p概率变大,最后期望值不变。

1.2应用场景

        通常将丢弃法作用在隐藏全连接层的输出上。

 如图所示,丢弃法可将一些中间结点丢弃,对剩余节点进行一定的增强。

 注:dropout是正则项,仅在训练中使用,不用于预测。

 二、示例演示

2.1实现dropout_layer 函数        

        该函数以dropout的概率丢弃张量输入X中的元素, 如上所述重新缩放剩余部分:将剩余部分除以1.0-dropout

import torch
from torch import nn
from d2l import torch as d2ldef dropout_layer(X, dropout):assert 0 <= dropout <= 1# 在本情况中,所有元素都被丢弃if dropout == 1:return torch.zeros_like(X)# 在本情况中,所有元素都被保留if dropout == 0:return Xmask = (torch.rand(X.shape) > dropout).float()return mask * X / (1.0 - dropout)

2.2测试dropout_layer函数

X=torch.arange(16,dtype=torch.float32).reshape((2,8))
print(X)
#暂退概率是0,0.5,1
print(dropout_layer(X,0.))
print(dropout_layer(X,0.5))
print(dropout_layer(X,1.))
#结果
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  0.,  4.,  6.,  0.,  0.,  0., 14.],[16.,  0., 20., 22., 24.,  0., 28.,  0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]])

2.3定义模型 

        引入Fashion-MNIST数据集。 我们定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率: 常见的技巧是在靠近输入层的地方设置较低的暂退概率。 下面的模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5, 并且暂退法只在训练期间有效。

num_inputs,num_outputs,num_hiddens1,num_hiddens2=784,10,256,256
#定义两个隐藏层,每个隐藏层有256个单元
dropout1, dropout2 = 0.2, 0.5  # 为每个隐藏层设置一个 dropout 概率class Net(nn.Module):def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2, is_training=True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)self.lin3 = nn.Linear(num_hiddens2, num_outputs)self.relu = nn.ReLU()def forward(self, X):# 应用第一个全连接层和 ReLU 激活函数H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 如果处于训练模式,对第一个隐藏层应用 dropout 操作if self.training == True:H1 = dropout_layer(H1, dropout1)# 应用第二个全连接层和 ReLU 激活函数H2 = self.relu(self.lin2(H1))# 如果处于训练模式,对第二个隐藏层应用 dropout 操作if self.training == True:H2 = dropout_layer(H2, dropout2)# 应用第三个全连接层,得到输出张量out = self.lin3(H2)return out# 创建一个神经网络模型实例
net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

2.4训练和测试

# 设置训练的轮数、学习率和批次大小
num_epochs, lr, batch_size = 10, 0.5, 256# 定义损失函数为交叉熵损失,并设置reduction='none'以便获得单个样本的损失值
loss = nn.CrossEntropyLoss(reduction='none')# 加载Fashion-MNIST数据集,并设置批次大小
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)# 定义优化器为随机梯度下降(SGD),并设置学习率
trainer = torch.optim.SGD(net.parameters(), lr=lr)# 使用d2l.train_ch3函数进行模型训练,其中包括训练数据迭代器、测试数据迭代器、损失函数、训练轮数和优化器等参数
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

2.5结果

 三 、简洁实现

from torch import nn
import torch
from d2l import  torch as d2l
dropout1, dropout2 = 0.2, 0.5  # 为每个隐藏层设置一个 dropout 概率
num_epochs, lr, batch_size = 10, 0.5, 256
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/18391.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue 常用指令 v-model 双向数据绑定

之前的指令&#xff0c;无论使用哪一种&#xff0c;都是在代码当中定义的内容。在web开发当中经常要去获取用户的输入&#xff0c;v-model可以十分方便的将表单的值和实例当中的数据关联起来。 这样就可以十分便捷的获取和设置表单元素的值了。&#xff08;注意是表单元素&…

SpringBoot第29讲:SpringBoot集成MySQL - MyBatis-Plus代码自动生成

SpringBoot第29讲&#xff1a;SpringBoot集成MySQL - MyBatis-Plus代码自动生成 本文是SpringBoot第29讲&#xff0c;主要介绍 MyBatis-Plus代码自动生成&#xff0c;以及产生此类代码生成工具的背景和此类工具的基本实现原理。 文章目录 SpringBoot第29讲&#xff1a;SpringBo…

【Linux】Centos7 的 Systemctl 与 创建系统服务 (shell脚本)

Systemctl systemctl 命令 # 启动 systemctl start NAME.service # 停止 systemctl stop NAME.service # 重启 systemctl restart NAME.service # 查看状态 systemctl status NAME.service # 查看所有激活系统服务 systemctl list-units -t service # 查看所有系统服务 syste…

PHP高级检索功能的实现以及动态拼接sql

我们学习了解了这么多关于PHP的知识&#xff0c;不知道你们对PHP高级检索功能的实现以及动态拼接sql是否已经完全掌握了呢&#xff0c;如果没有&#xff0c;那就跟随本篇文章一起继续学习吧! PHP高级检索功能的实现以及动态拼接sql。完成的功能有&#xff1a;可以单独根据一个…

华为云hcip核心知识笔记(数据库服务规划)

华为云hcip核心知识笔记&#xff08;数据库服务规划&#xff09; 1.云数据接库优势 1.1云数据库优点有&#xff1a; 易用性强&#xff1a;能欧快速部署和运行 高扩展&#xff1a;开放式架构和云计算存储分离 低成本&#xff1a;按需使用&#xff0c;成本更加低廉 2.云数据库r…

微软开测“Moment4”启动包:Win11 23H2要来了

近日&#xff0c; 有用户在Win11最新的7月累积更新中发现&#xff0c;更新文件中已经开始出现了对“Moment4”的引用。 具体来说&#xff0c;在7月累积更新中&#xff0c;微软加入了“Microsoft-Windows-UpdateTargeting-ClientOS-SV2Moment4-EKB”“Microsoft-Windows-23H2Ena…

2023年【零声教育】13代C/C++Linux服务器开发高级架构师课程体系分析

对于零声教育的C/CLinux服务器高级架构师的课程到2022目前已经迭代到13代了&#xff0c;像之前小编也总结过&#xff0c;但是课程每期都有做一定的更新&#xff0c;也是为了更好的完善课程跟上目前互联网大厂的岗位技术需求&#xff0c;之前课程里面也包含了一些小的分支&#…

使用Vue.js和Rust构建高性能的物联网应用

物联网(IoT)应用是现代技术的重要组成部分&#xff0c;它们可以在各种场景中&#xff08;例如智能家居&#xff0c;工业自动化等&#xff09;提供无缝的自动化解决方案。在这篇文章中&#xff0c;我们将探讨如何使用Vue.js和Rust构建高性能的物联网应用。 1. 为什么选择Vue.js…

音频客观感知MOS对比,对ViSQOL、PESQ、MosNet(神经网络MOS分)和polqa一致性对比和可信度雁阵

原创&#xff1a;转载需附链接&#xff1a; 音频客观感知MOS对比&#xff0c;对ViSQOL、PESQ、MosNet&#xff08;神经网络MOS分&#xff09;和polqa一致性对比和可信度雁阵_machine-lv的博客-CSDN博客谢谢&#xff01; 本文章以标准polqa的mos分为可信前提&#xff0c;验证vis…

MPAndroidChart学习及问题处理

1.添加依赖 项目目录->app->build.gradle dependencies {implementation com.github.PhilJay:MPAndroidChart:v3.0.3 }项目目录->app->setting.gradle dependencyResolutionManagement {repositories {maven { url https://jitpack.io }} }高版本的gradle添加依…

Ceph错误汇总

title: “Ceph错误汇总” date: “2020-05-14” categories: - “技术” tags: - “Ceph” - “错误汇总” toc: false original: true draft: true Ceph错误汇总 1、执行ceph-deploy报错 1.1、错误信息 ➜ ceph-deploy Traceback (most recent call last):File "/us…

2023年第四届“华数杯”数学建模思路 - 案例:感知机原理剖析及实现

# 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 一、感知机的直观理解 感知机应该属于机器学习算法中最简单的一种算法&#xff0c;其原理可以看下图&#xff1a; 比如说我们有一个坐标轴&#xff08;图中的…

关于视频汇聚融合EasyCVR平台多视频播放协议的概述

视频监控综合管理平台EasyCVR具备视频融合能力&#xff0c;平台基于云边端一体化架构&#xff0c;具有强大的数据接入、处理及分发能力&#xff0c;平台既具备传统安防视频监控的能力与服务&#xff0c;也支持AI智能检测技术的接入&#xff0c;可应用在多行业领域的智能化监管场…

Python简单应用II

#第一题&#xff1a; 将字符串joy存放于列表l1中&#xff0c;按要求完成如下操作。元素添加操作&#xff1a; 1&#xff09;在列表l1的尾部添加空白字符&#xff0c; 2&#xff09;在列表l1的尾部添加字符串singing&#xff0c; 3&#xff09;在列表l1的首部添加字符串I e&…

直线模组如何进行精度校准?

直线模组是一种高精度的传动元件&#xff0c;而精度是直线模组的重要指标&#xff0c;在直线模组的使用中&#xff0c;我们应该尽可能的避免直线模组的精度受损&#xff0c;这样才能够有真正的发挥出直线模组的稳定性。 直线模组的精度一般是指重复定位精度和导向精度&#xff…

React常见面试题

React常见面试题 一、React中的样式管理有哪些方法 内联样式&#xff1a;对象&#xff0c;作用于当前组件普通样式表&#xff1a; 作用于全局&#xff0c;文件名是&#xff1a;xxx.scssCSS模块&#xff1a;类似Vue的scoped&#xff0c; 文件名需是&#xff1a;xxx.module.scs…

代客泊车对HUT功能交互规范

目录 1. 版本记录... 7 2. 文档范围和控制... 8 2.1 目的/范围... 8 2.2 文档冲突... 8 2.3 文档授权... 8 2.4 文档更改控制... 8 3. 系统组成... 9 3.1 IPAS系统&#xff08;环视和超声波雷达&#xff09;...…

Springboot简单利用@RestControllerAdvice优雅的捕获异常

1.注解 ExceptionHandler&#xff1a;用于指定异常处理方法。当与RestControllerAdvice配合使用时&#xff0c;用于全局处理控制器里的异常。 2.配置类 RestControllerAdvice Slf4j public class GlobalExceptionHandler {ExceptionHandler(Exception.class)public Result h…

LeetCode 39. 组合总和(回溯+剪枝)

题目&#xff1a; 链接&#xff1a;LeetCode 39. 组合总和 难度&#xff1a;中等 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target &#xff0c;找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 &#xff0c;并以列表形式返回。你可以按 …

Java - sh 脚本启动 jar 包等服务 - sh 脚本模板 - 适用于任何类似的服务启动

sh 脚本模板 该模板&#xff0c;每次运行一次都会 kill 掉原来的服务&#xff0c;然后重新启动 jar 包服务 #!/bin/bash# 定义Java进程的名称 APP_NAMEyour-app-name.jar# 定义Java进程的日志文件路径 LOG_PATH/var/log/your-app-name.log# 定义备份日志文件的目录 BACKUP_DI…