pytorch学习——正则化技术——丢弃法(dropout)

一、概念介绍

        在多层感知机(MLP)中,丢弃法(Dropout)是一种常用的正则化技术,旨在防止过拟合。(效果一般比前面的权重衰退好)

        在丢弃法中,随机选择一部分神经元并将其输出清零,被清零的神经元在该轮训练中不会被激活。这样,其他神经元就需要学习代替这些神经元的功能,从而促进了神经元之间的独立性和鲁棒性。

1.1思想原理

        丢弃法的基本思想是,在每一次训练中,随机选择一些神经元不参与训练,从而减少神经元之间的相互依赖关系,使得模型对于训练数据的过拟合程度降低。这样在测试时,所有神经元都参与,可以取得更好的泛化性能。

        丢弃法可以被应用到多层感知机的任意层中,包括输入层和输出层。在实际应用中,通常会在每一层都添加丢弃法,以充分发挥其正则化作用。

丢弃法特性:在层之间加入噪声,而不是在数据输入时加入。

 对Xi中的元素,以p概率变成0,1-p概率变大,最后期望值不变。

1.2应用场景

        通常将丢弃法作用在隐藏全连接层的输出上。

 如图所示,丢弃法可将一些中间结点丢弃,对剩余节点进行一定的增强。

 注:dropout是正则项,仅在训练中使用,不用于预测。

 二、示例演示

2.1实现dropout_layer 函数        

        该函数以dropout的概率丢弃张量输入X中的元素, 如上所述重新缩放剩余部分:将剩余部分除以1.0-dropout

import torch
from torch import nn
from d2l import torch as d2ldef dropout_layer(X, dropout):assert 0 <= dropout <= 1# 在本情况中,所有元素都被丢弃if dropout == 1:return torch.zeros_like(X)# 在本情况中,所有元素都被保留if dropout == 0:return Xmask = (torch.rand(X.shape) > dropout).float()return mask * X / (1.0 - dropout)

2.2测试dropout_layer函数

X=torch.arange(16,dtype=torch.float32).reshape((2,8))
print(X)
#暂退概率是0,0.5,1
print(dropout_layer(X,0.))
print(dropout_layer(X,0.5))
print(dropout_layer(X,1.))
#结果
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  0.,  4.,  6.,  0.,  0.,  0., 14.],[16.,  0., 20., 22., 24.,  0., 28.,  0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]])

2.3定义模型 

        引入Fashion-MNIST数据集。 我们定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率: 常见的技巧是在靠近输入层的地方设置较低的暂退概率。 下面的模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5, 并且暂退法只在训练期间有效。

num_inputs,num_outputs,num_hiddens1,num_hiddens2=784,10,256,256
#定义两个隐藏层,每个隐藏层有256个单元
dropout1, dropout2 = 0.2, 0.5  # 为每个隐藏层设置一个 dropout 概率class Net(nn.Module):def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2, is_training=True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)self.lin3 = nn.Linear(num_hiddens2, num_outputs)self.relu = nn.ReLU()def forward(self, X):# 应用第一个全连接层和 ReLU 激活函数H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 如果处于训练模式,对第一个隐藏层应用 dropout 操作if self.training == True:H1 = dropout_layer(H1, dropout1)# 应用第二个全连接层和 ReLU 激活函数H2 = self.relu(self.lin2(H1))# 如果处于训练模式,对第二个隐藏层应用 dropout 操作if self.training == True:H2 = dropout_layer(H2, dropout2)# 应用第三个全连接层,得到输出张量out = self.lin3(H2)return out# 创建一个神经网络模型实例
net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

2.4训练和测试

# 设置训练的轮数、学习率和批次大小
num_epochs, lr, batch_size = 10, 0.5, 256# 定义损失函数为交叉熵损失,并设置reduction='none'以便获得单个样本的损失值
loss = nn.CrossEntropyLoss(reduction='none')# 加载Fashion-MNIST数据集,并设置批次大小
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)# 定义优化器为随机梯度下降(SGD),并设置学习率
trainer = torch.optim.SGD(net.parameters(), lr=lr)# 使用d2l.train_ch3函数进行模型训练,其中包括训练数据迭代器、测试数据迭代器、损失函数、训练轮数和优化器等参数
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

2.5结果

 三 、简洁实现

from torch import nn
import torch
from d2l import  torch as d2l
dropout1, dropout2 = 0.2, 0.5  # 为每个隐藏层设置一个 dropout 概率
num_epochs, lr, batch_size = 10, 0.5, 256
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/18391.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue 常用指令 v-model 双向数据绑定

之前的指令&#xff0c;无论使用哪一种&#xff0c;都是在代码当中定义的内容。在web开发当中经常要去获取用户的输入&#xff0c;v-model可以十分方便的将表单的值和实例当中的数据关联起来。 这样就可以十分便捷的获取和设置表单元素的值了。&#xff08;注意是表单元素&…

SpringBoot第29讲:SpringBoot集成MySQL - MyBatis-Plus代码自动生成

SpringBoot第29讲&#xff1a;SpringBoot集成MySQL - MyBatis-Plus代码自动生成 本文是SpringBoot第29讲&#xff0c;主要介绍 MyBatis-Plus代码自动生成&#xff0c;以及产生此类代码生成工具的背景和此类工具的基本实现原理。 文章目录 SpringBoot第29讲&#xff1a;SpringBo…

【Linux】Centos7 的 Systemctl 与 创建系统服务 (shell脚本)

Systemctl systemctl 命令 # 启动 systemctl start NAME.service # 停止 systemctl stop NAME.service # 重启 systemctl restart NAME.service # 查看状态 systemctl status NAME.service # 查看所有激活系统服务 systemctl list-units -t service # 查看所有系统服务 syste…

PHP高级检索功能的实现以及动态拼接sql

我们学习了解了这么多关于PHP的知识&#xff0c;不知道你们对PHP高级检索功能的实现以及动态拼接sql是否已经完全掌握了呢&#xff0c;如果没有&#xff0c;那就跟随本篇文章一起继续学习吧! PHP高级检索功能的实现以及动态拼接sql。完成的功能有&#xff1a;可以单独根据一个…

华为云hcip核心知识笔记(数据库服务规划)

华为云hcip核心知识笔记&#xff08;数据库服务规划&#xff09; 1.云数据接库优势 1.1云数据库优点有&#xff1a; 易用性强&#xff1a;能欧快速部署和运行 高扩展&#xff1a;开放式架构和云计算存储分离 低成本&#xff1a;按需使用&#xff0c;成本更加低廉 2.云数据库r…

微软开测“Moment4”启动包:Win11 23H2要来了

近日&#xff0c; 有用户在Win11最新的7月累积更新中发现&#xff0c;更新文件中已经开始出现了对“Moment4”的引用。 具体来说&#xff0c;在7月累积更新中&#xff0c;微软加入了“Microsoft-Windows-UpdateTargeting-ClientOS-SV2Moment4-EKB”“Microsoft-Windows-23H2Ena…

2023年【零声教育】13代C/C++Linux服务器开发高级架构师课程体系分析

对于零声教育的C/CLinux服务器高级架构师的课程到2022目前已经迭代到13代了&#xff0c;像之前小编也总结过&#xff0c;但是课程每期都有做一定的更新&#xff0c;也是为了更好的完善课程跟上目前互联网大厂的岗位技术需求&#xff0c;之前课程里面也包含了一些小的分支&#…

音频客观感知MOS对比,对ViSQOL、PESQ、MosNet(神经网络MOS分)和polqa一致性对比和可信度雁阵

原创&#xff1a;转载需附链接&#xff1a; 音频客观感知MOS对比&#xff0c;对ViSQOL、PESQ、MosNet&#xff08;神经网络MOS分&#xff09;和polqa一致性对比和可信度雁阵_machine-lv的博客-CSDN博客谢谢&#xff01; 本文章以标准polqa的mos分为可信前提&#xff0c;验证vis…

MPAndroidChart学习及问题处理

1.添加依赖 项目目录->app->build.gradle dependencies {implementation com.github.PhilJay:MPAndroidChart:v3.0.3 }项目目录->app->setting.gradle dependencyResolutionManagement {repositories {maven { url https://jitpack.io }} }高版本的gradle添加依…

2023年第四届“华数杯”数学建模思路 - 案例:感知机原理剖析及实现

# 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 一、感知机的直观理解 感知机应该属于机器学习算法中最简单的一种算法&#xff0c;其原理可以看下图&#xff1a; 比如说我们有一个坐标轴&#xff08;图中的…

关于视频汇聚融合EasyCVR平台多视频播放协议的概述

视频监控综合管理平台EasyCVR具备视频融合能力&#xff0c;平台基于云边端一体化架构&#xff0c;具有强大的数据接入、处理及分发能力&#xff0c;平台既具备传统安防视频监控的能力与服务&#xff0c;也支持AI智能检测技术的接入&#xff0c;可应用在多行业领域的智能化监管场…

直线模组如何进行精度校准?

直线模组是一种高精度的传动元件&#xff0c;而精度是直线模组的重要指标&#xff0c;在直线模组的使用中&#xff0c;我们应该尽可能的避免直线模组的精度受损&#xff0c;这样才能够有真正的发挥出直线模组的稳定性。 直线模组的精度一般是指重复定位精度和导向精度&#xff…

PyTorch(安装及卸载)

目录 1. 安装 2. 卸载 参考文献 为什么用PyTorch&#xff1a;简单来说&#xff0c;19年之前tensorflow是大哥&#xff0c;19年tensorflow和PyTorch双龙并行&#xff0c;20年之后PyTorch一往无前。宗旨&#xff0c;哪个用的人多用哪个。 1. 安装 1. 先打开Anaconda Prompt&…

uniapp自定义消息语音

需求是后端推送的消息APP要响自定义语音&#xff0c;利用官方插件&#xff0c;总结下整体流程 uniapp后台配置 因为2.0只支持uniapp自己的后台发送消息&#xff0c;所以要自己的后台发送消息只能用1.0 插件地址和代码 插件地址: link let isIos (plus.os.name "iOS&qu…

C++内存管理

目录 一.C中内存区域划分 一.C中内存区域划分 1.栈又叫堆栈--非静态局部变量/函数参数/返回值等等&#xff0c;栈是向下增长的。 2.内存映射段是高效的I/O映射方式&#xff0c;用于装载一个共享的动态内存库。用户可使用系统接口创建共享共享内存&#xff0c;做进程间通信。 …

手撕SpringBoot的自定义启动器

一. 前言 哈喽&#xff0c;大家好&#xff0c;最近金九银十&#xff0c;又有不少小伙伴私信辉哥&#xff0c;说自己在面试时被问到SpringBoot如何自定义启动器&#xff0c;结果自己不知道该怎么回答。那么今天就手把手地带着大家&#xff0c;去看看在SpringBoot中到底该怎么实…

亚马逊买家账号ip关联怎么处理

对于亚马逊买家账号&#xff0c;同样需要注意IP关联问题。在亚马逊的眼中&#xff0c;如果多个买家账号共享相同的IP地址&#xff0c;可能会被视为潜在的操纵、违规或滥用行为。这种情况可能导致账号受到限制或处罚。 处理亚马逊买家账号IP关联问题&#xff0c;建议采取以下步骤…

生化危机5找不到xlive.dll,要如何修复xlive.dll缺失

有朋友反映说他在玩生化危机5的时候&#xff0c;突然电脑就弹出一个找不到xlive.dll&#xff0c;然后游戏就打不开了&#xff0c;一直都很懵逼&#xff0c;不知道怎么处理这个问题&#xff0c;今天小编就来给大家详细的讲讲&#xff0c;找不到xlive.dll要怎么去修复&#xff01…

危化品行业防雷检测综合解决方案

危化品是指具有毒害、腐蚀、爆炸、燃烧、助燃等性质&#xff0c;能够对人体、设施或者环境造成危害的化学品。危化品的生产、储存、运输、使用等过程中&#xff0c;都存在着遭受雷击引发火灾或者爆炸事故的风险。因此&#xff0c;对危化品场所进行防雷检测&#xff0c;是保障危…

IDEA中修改类头的文档注释信息

IDEA中修改类头的文档注释信息 选择File--Settings--Editor--File and Code Templates--Includes&#xff0c;可以把文档注释写成这种的 /**author: Arbicoralcreate: ${YEAR}-${MONTH}-${DAY} ${TIME}Description: */这样回看就可以很清楚的看到自己创建脚本的时间&#xff…