【机器学习】037_暂退法

一、实现原理

具有输入噪音的训练,等价于Tikhonov正则化

核心方法:在前向传播的过程中,计算每一内部层的同时注入噪声

· 从作用上来看,表面上来说是在训练过程中丢弃一些神经元

· 假设x是某一层神经网络层的输出,是下一层的输入,我们希望对x加入一些噪音,使得:

E[x^`]=x

  ※x`的期望为x,也就是说平均上来说输出值还是x

· 暂退法对每个元素进行了如下扰动:

        有p的概率下取值:x^`_i=0

        其它情况(1-p概率):x^`_i = \frac{x_i}{1-p}

实践中使用暂退法:

· 通常将暂退法作用在全连接隐藏层的输出上

如图所示,在第一个隐藏层的输出上,有些神经元有p的概率使输出值置零。

非置零的输出值,即有1-p的概率被施加了一个较小的扰动值使其略微增大。

※暂退法只在训练中使用,dropout是正则项,在推理过程中不会使用,这样也会保证输出值确定

※每次执行暂退法的时候,实际上是每次随机采样了一些子神经网络

总结:

①暂退法将一些输出项随机置零来控制模型的复杂度

②暂退法的作用效果和正则化等价

③常应用在多层感知机的隐藏层输出上

④丢弃概率p是控制模型复杂度的超参数

二、代码实现

从零实现代码:

import torch
from torch import nn
from d2l import torch as d2ldef dropout_layer(X, dropout):# assert用于选择dropout符合范围的情况,不符合则报错assert 0 <= dropout <= 1, "不符合范围!"# 在本情况中,所有元素都被丢弃if dropout == 1:return torch.zeros_like(X)# 在本情况中,所有元素都被保留if dropout == 0:return X# 在这一步操作中,首先定义一个和X张量形状相同但元素值均为随机数的张量# 将这个张量里每个元素与dropout比较,如果大于就置为True,小于等于就置为False# 再调用float将True和False转化为1和0# 这样,mask就是一个仅含1与0的张量了# 最后将mask里的每个元素与X里的每个元素做数乘mask = (torch.rand(X.shape) > dropout).float()return mask * X / (1.0 - dropout)# 生成X来测试暂退法
X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))# 定义模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 定义模型
dropout1, dropout2 = 0.2, 0.5
# is_training用来表示当前是在测试还是在训练
class Net(nn.Module):def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training = True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)self.lin3 = nn.Linear(num_hiddens2, num_outputs)self.relu = nn.ReLU()def forward(self, X):H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 只有在训练模型时才使用dropoutif self.training == True:# 在第一个全连接层之后添加一个dropout层H1 = dropout_layer(H1, dropout1)H2 = self.relu(self.lin2(H1))if self.training == True:# 在第二个全连接层之后添加一个dropout层H2 = dropout_layer(H2, dropout2)# 输出不需要dropout作用out = self.lin3(H2)return outnet = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

简洁实现代码:

import torch
from torch import nn
from d2l import torch as d2l# 定义概率参数
dropout1, dropout2 = 0.2, 0.5net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152598.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】036_权重衰退

一、范数 定义&#xff1a;向量的范数表示一个向量有多大&#xff08;分量的大小&#xff09; L1范数&#xff1a; 即向量元素绝对值之和&#xff0c;用符号 ‖ v ‖ 1 表示。 公式&#xff1a; L2范数&#xff1a; 即向量的模&#xff0c;向量各元素绝对值的平方之和再…

适合您的智能手机的 7 款优秀手机数据恢复软件分享

如今&#xff0c;我们做什么都用手机&#xff1b;从拍照到录音&#xff0c;甚至作为 MP3 播放器&#xff0c;我们已经对手机变得非常依恋。这导致我们在手机上留下了很多珍贵的回忆。 不幸的是&#xff0c;我们有可能会丢失手机上的部分甚至全部数据。幸运的是&#xff0c;这不…

1. hadoop环境准备

环境准备 准备三台虚拟机&#xff0c;配置最好是 2C 4G 以上 本文准备三台机器的内网ip分别为 172.17.0.10 172.17.0.11 172.17.0.12本机配置/etc/hosts cat >> /etc/hosts<<EOF 172.17.0.10 hadoop01 172.17.0.11 hadoop02 172.17.0.12 hadoop03 EOF本机设置与…

队列的实现和OJ练习

目录 概念 队列的实现 利用结构体存放队列结构 为什么单链表不使用这种方法&#xff1f; 初始化队列 小提示&#xff1a; 队尾入队列 队头出队列 获取队头元素 获取队尾元素 获取队列中有效元素个数 检测队列是否为空 销毁队列 最终代码 循环队列 队列的OJ题 …

MobaXterm如何连接CentOS7的Linux虚拟机?Redis可视化客户端工具如何连接Linux版Redis?

一、打开Lunix虚拟机,进入虚拟机中,在终端中输入ifconfig,得到以下信息&#xff0c;红框中为ip地址 二、打开MobaXterm&#xff0c;点击session 选择SSH&#xff0c;在Remote host中输入linux得到的IP地址&#xff0c;Specify username中可起一个任意的连接名称。 输入密码 四、…

【洛谷 P3743】kotori的设备 题解(二分答案+递归)

kotori的设备 题目背景 kotori 有 n n n 个可同时使用的设备。 题目描述 第 i i i 个设备每秒消耗 a i a_i ai​ 个单位能量。能量的使用是连续的&#xff0c;也就是说能量不是某时刻突然消耗的&#xff0c;而是匀速消耗。也就是说&#xff0c;对于任意实数&#xff0c;…

60 权限提升-MYMSORA等SQL数据库提权

目录 数据库应用提权在权限提升中的意义WEB或本地环境如何探针数据库应用数据库提权权限用户密码收集等方法目前数据库提权对应的技术及方法等 演示案例Mysql数据库提权演示-脚本&MSF1.UDF提权知识点: (基于MYSQL调用命令执行函数&#xff09;读取数据库存储或备份文件 (了…

GaussDB新特性Ustore存储引擎介绍

1、 Ustore和Astore存储引擎介绍 Ustore存储引擎&#xff0c;又名In-place Update存储引擎&#xff08;原地更新&#xff09;&#xff0c;是openGauss 内核新增的一种存储模式。此前的版本使用的行存储引擎是Append Update&#xff08;追加更新&#xff09;模式。相比于Append…

在网络攻击之前、期间和之后应采取的步骤

在当今复杂的威胁形势下&#xff0c;网络攻击是不可避免的。 恶意行为者变得越来越复杂&#xff0c;出于经济动机的攻击变得越来越普遍&#xff0c;并且每天都会发现新的恶意软件系列。 这使得对于各种规模和跨行业的组织来说&#xff0c;制定适当的攻击计划变得更加重要。 …

【Linux】进程间通信 -- 管道

对于进程间通信的理解 首先&#xff0c;进程间通信的本质是&#xff0c;让不同的进程看到同一份资源&#xff08;这份资源不能隶属于任何一个进程&#xff0c;即应该是共享的&#xff09;。而进程间通信的目的是为了实现多进程之间的协同。 但由于进程运行具有独立性&#xff…

密码加密解密之路

1.背景 做数据采集&#xff0c;客户需要把他们那边的数据库连接信息存到我们系统里&#xff0c;那我们系统就要尽可能的保证这部分数据安全&#xff0c;不被盗。 2.我的思路 1.需要加密的地方有两处&#xff0c;一个是新增的时候前端传给后端的时候&#xff0c;一个是存到数…

异步爬取+多线程+redis构建一个运转丝滑且免费http-ip代理池 (三)

内容提要: 如果说,爬取网页数据的时候,我们使用了异步,那么将数据放入redis里面,其实也需要进行异步;当然,如果使用多线程或者redis线程池技术也是可以的,但那会造成冗余; 因此,在测试完多线程redis搭配异步爬虫的时候,我发现效率直接在redis这里被无限拉低下来! 因此: 最终的r…

从0开始学习JavaScript--JavaScript中的集合类

JavaScript中的集合类是处理数据的关键&#xff0c;涵盖了数组、Set、Map等多种数据结构。本文将深入研究这些集合类的创建、操作&#xff0c;以及实际应用场景&#xff0c;并通过丰富的示例代码&#xff0c;帮助大家更全面地了解和应用这些概念。 数组&#xff08;Array&…

SystemVerilog学习 (11)——覆盖率

目录 一、概述 二、覆盖率的种类 1、概述 2、分类 三、代码覆盖率 四、功能覆盖率 五、从功能描述到覆盖率 一、概述 “验证如果没有量化&#xff0c;那么就意味着没有尽头。” 伴随着复杂SoC系统的验证难度系数成倍增加&#xff0c;无论是定向测试还是随机测试&#xff…

安全框架springSecurity+Jwt+Vue-1(vue环境搭建、动态路由、动态标签页)

一、安装vue环境&#xff0c;并新建Vue项目 ①&#xff1a;安装node.js 官网(https://nodejs.org/zh-cn/) 2.安装完成之后检查下版本信息&#xff1a; ②&#xff1a;创建vue项目 1.接下来&#xff0c;我们安装vue的环境 # 安装淘宝npm npm install -g cnpm --registryhttps:/…

软件测试/测试开发/人工智能丨基于Spark的分布式造数工具:加速大规模测试数据构建

随着软件开发规模的扩大&#xff0c;测试数据的构建变得越来越复杂&#xff0c;传统的造数方法难以应对大规模数据需求。本文将介绍如何使用Apache Spark构建分布式造数工具&#xff0c;以提升测试数据构建的效率和规模。 为什么选择Spark&#xff1f; 分布式计算&#xff1a;…

easyExcel注解详情

前言11个注解字段注解 类注解基础综合示例补充颜色总结 11个注解 ExcelProperty ColumnWith 列宽 ContentFontStyle 文本字体样式 ContentLoopMerge 文本合并 ContentRowHeight 文本行高度 ContentStyle 文本样式 HeadFontStyle 标题字体样式 HeadRowHeight 标题高度 HeadStyle…

Python将原始数据集和标注文件进行数据增强(随机仿射变换),并生成随机仿射变换的数据集和标注文件

Python将原始数据集和标注文件进行数据增强&#xff08;随机仿射变换&#xff09;&#xff0c;并生成随机仿射变换的数据集和标注文件 前言前提条件相关介绍实验环境生成随机仿射变换的数据集和标注文件代码实现输出结果 前言 由于本人水平有限&#xff0c;难免出现错漏&#x…

OpenCV快速入门:图像滤波与边缘检测

文章目录 前言一、噪声种类与生成1.1 椒盐噪声1.2 高斯噪声1.3 彩色噪声 二、卷积操作2.1 卷积基本原理2.2 卷积操作代码实现 三、线性滤波3.1 均值滤波均值滤波原理均值滤波公式均值滤波代码实现 3.2 方框滤波方框滤波原理方框滤波公式方框滤波代码实现 3.3 高斯滤波高斯滤波原…

redis非关系型数据库(缓存型数据库)——中间件

【重点】redis为什么这么快&#xff1f;&#xff08;应届&#xff09; ①redis是纯内存结构&#xff0c;避免磁盘I/O的耗时 ②redis核心模块是一个单进程&#xff0c;减少线程切换和回收线程资源时间 ③redis采用的是I/O的多路复用机制&#xff08;每一个执行线路可以同时完…