【机器学习】037_暂退法

一、实现原理

具有输入噪音的训练,等价于Tikhonov正则化

核心方法:在前向传播的过程中,计算每一内部层的同时注入噪声

· 从作用上来看,表面上来说是在训练过程中丢弃一些神经元

· 假设x是某一层神经网络层的输出,是下一层的输入,我们希望对x加入一些噪音,使得:

E[x^`]=x

  ※x`的期望为x,也就是说平均上来说输出值还是x

· 暂退法对每个元素进行了如下扰动:

        有p的概率下取值:x^`_i=0

        其它情况(1-p概率):x^`_i = \frac{x_i}{1-p}

实践中使用暂退法:

· 通常将暂退法作用在全连接隐藏层的输出上

如图所示,在第一个隐藏层的输出上,有些神经元有p的概率使输出值置零。

非置零的输出值,即有1-p的概率被施加了一个较小的扰动值使其略微增大。

※暂退法只在训练中使用,dropout是正则项,在推理过程中不会使用,这样也会保证输出值确定

※每次执行暂退法的时候,实际上是每次随机采样了一些子神经网络

总结:

①暂退法将一些输出项随机置零来控制模型的复杂度

②暂退法的作用效果和正则化等价

③常应用在多层感知机的隐藏层输出上

④丢弃概率p是控制模型复杂度的超参数

二、代码实现

从零实现代码:

import torch
from torch import nn
from d2l import torch as d2ldef dropout_layer(X, dropout):# assert用于选择dropout符合范围的情况,不符合则报错assert 0 <= dropout <= 1, "不符合范围!"# 在本情况中,所有元素都被丢弃if dropout == 1:return torch.zeros_like(X)# 在本情况中,所有元素都被保留if dropout == 0:return X# 在这一步操作中,首先定义一个和X张量形状相同但元素值均为随机数的张量# 将这个张量里每个元素与dropout比较,如果大于就置为True,小于等于就置为False# 再调用float将True和False转化为1和0# 这样,mask就是一个仅含1与0的张量了# 最后将mask里的每个元素与X里的每个元素做数乘mask = (torch.rand(X.shape) > dropout).float()return mask * X / (1.0 - dropout)# 生成X来测试暂退法
X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))# 定义模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 定义模型
dropout1, dropout2 = 0.2, 0.5
# is_training用来表示当前是在测试还是在训练
class Net(nn.Module):def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training = True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)self.lin3 = nn.Linear(num_hiddens2, num_outputs)self.relu = nn.ReLU()def forward(self, X):H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 只有在训练模型时才使用dropoutif self.training == True:# 在第一个全连接层之后添加一个dropout层H1 = dropout_layer(H1, dropout1)H2 = self.relu(self.lin2(H1))if self.training == True:# 在第二个全连接层之后添加一个dropout层H2 = dropout_layer(H2, dropout2)# 输出不需要dropout作用out = self.lin3(H2)return outnet = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

简洁实现代码:

import torch
from torch import nn
from d2l import torch as d2l# 定义概率参数
dropout1, dropout2 = 0.2, 0.5net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152598.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

redhat下使用CentOS yum源,并安装docker

一、安装yum源 1.卸载yum # 查看系统自身安装的yum软件包 rpm -qa | grep yum # 卸载软yum件包 rpm -e 软件包名称 --nodeps #可以使用简称如 rpm -e yum-* --nodeps2. 安装yum [rootbogon ~]# rpm -ivh --nodeps https://mirrors.aliyun.com/centos/8/BaseOS/x86_64/os/Pa…

海康Visionmaster-环境配置:运行出现 Vm.Core.Solu tion 报错的解决方法

&#xff08;1&#xff09;检查加密狗有没有插好&#xff1f; 是否以管理员权限启动程序&#xff1f;首选 32 位是否取消勾选&#xff1f; &#xff08;2&#xff09;查看 VM4.0 的版本信息是否为最新版本&#xff1f;版本信息为 20220415 以上&#xff0c;版本越新问题就会越少…

【机器学习】036_权重衰退

一、范数 定义&#xff1a;向量的范数表示一个向量有多大&#xff08;分量的大小&#xff09; L1范数&#xff1a; 即向量元素绝对值之和&#xff0c;用符号 ‖ v ‖ 1 表示。 公式&#xff1a; L2范数&#xff1a; 即向量的模&#xff0c;向量各元素绝对值的平方之和再…

低代码平台技术分享官 | 漫话iGIX前端设计模式

设计模式是一个程序员进阶高级的必备技巧&#xff0c;也是评判一个工程师工作经验和能力的试金石。设计模式是程序员多年工作经验的凝练和总结&#xff0c;能够更大限度的优化代码以及对已有代码进行合理重构。但如果你还不知道如何使用设计模式提升前端开发质量&#xff0c;那…

MQTT协议详解

前言 MQTT是一个即时通讯协议&#xff0c;它工作在TCP/IP协议族上&#xff0c;是为硬件性能低下的远程设备以及网络状况糟糕的情况下而设计的发布/订阅型消息协议。它使用发布/订阅消息模式&#xff0c;提供一对多的消息发布&#xff0c;解除应用程序耦合。MQTT是轻量、简单、…

适合您的智能手机的 7 款优秀手机数据恢复软件分享

如今&#xff0c;我们做什么都用手机&#xff1b;从拍照到录音&#xff0c;甚至作为 MP3 播放器&#xff0c;我们已经对手机变得非常依恋。这导致我们在手机上留下了很多珍贵的回忆。 不幸的是&#xff0c;我们有可能会丢失手机上的部分甚至全部数据。幸运的是&#xff0c;这不…

1. hadoop环境准备

环境准备 准备三台虚拟机&#xff0c;配置最好是 2C 4G 以上 本文准备三台机器的内网ip分别为 172.17.0.10 172.17.0.11 172.17.0.12本机配置/etc/hosts cat >> /etc/hosts<<EOF 172.17.0.10 hadoop01 172.17.0.11 hadoop02 172.17.0.12 hadoop03 EOF本机设置与…

队列的实现和OJ练习

目录 概念 队列的实现 利用结构体存放队列结构 为什么单链表不使用这种方法&#xff1f; 初始化队列 小提示&#xff1a; 队尾入队列 队头出队列 获取队头元素 获取队尾元素 获取队列中有效元素个数 检测队列是否为空 销毁队列 最终代码 循环队列 队列的OJ题 …

C++-特殊类和单例模式

1.请设计一个类&#xff0c;不能被拷贝 拷贝构造函数以及赋值运算符重载&#xff0c;因此想要让一个类禁止拷贝&#xff0c;只需让该类不能调用拷贝构造函数以及赋值运算符重载即可。 //该类不能发生拷贝class NonCopy{public:NonCopy(const NonCopy& Nc) delete;NonCopy&…

MobaXterm如何连接CentOS7的Linux虚拟机?Redis可视化客户端工具如何连接Linux版Redis?

一、打开Lunix虚拟机,进入虚拟机中,在终端中输入ifconfig,得到以下信息&#xff0c;红框中为ip地址 二、打开MobaXterm&#xff0c;点击session 选择SSH&#xff0c;在Remote host中输入linux得到的IP地址&#xff0c;Specify username中可起一个任意的连接名称。 输入密码 四、…

AM@傅里叶级数@周期为2l的一般情形

文章目录 abstract周期为 2 l 2l 2l的Fourier展开推导例 三角函数和(-1)的幂转换关系(-1)的幂与级数的奇偶项级数通项变形例例 abstract 从特殊到一般,从对周期为 2 π 2\pi 2π的函数到周期为 2 l 2l 2l的函数 推导周期为 2 l 2l 2l情况下的公式又可以借助于周期为 2 π 2\pi…

【洛谷 P3743】kotori的设备 题解(二分答案+递归)

kotori的设备 题目背景 kotori 有 n n n 个可同时使用的设备。 题目描述 第 i i i 个设备每秒消耗 a i a_i ai​ 个单位能量。能量的使用是连续的&#xff0c;也就是说能量不是某时刻突然消耗的&#xff0c;而是匀速消耗。也就是说&#xff0c;对于任意实数&#xff0c;…

60 权限提升-MYMSORA等SQL数据库提权

目录 数据库应用提权在权限提升中的意义WEB或本地环境如何探针数据库应用数据库提权权限用户密码收集等方法目前数据库提权对应的技术及方法等 演示案例Mysql数据库提权演示-脚本&MSF1.UDF提权知识点: (基于MYSQL调用命令执行函数&#xff09;读取数据库存储或备份文件 (了…

记录 ubuntu 硬盘分区跟格式化(fdisk命令)

1、sudo su 2、fdisk -l 查找需要挂载的硬盘是哪一个 3、fdisk /dev/sda 可以输入m&#xff0c;查看帮助信息&#xff0c;按p&#xff0c;查看磁盘分区信息 输入n&#xff0c;选择分两个区&#xff08;原来已经有一个&#xff0c;再加1个就是2两个&#xff09;&#xff0c…

GaussDB新特性Ustore存储引擎介绍

1、 Ustore和Astore存储引擎介绍 Ustore存储引擎&#xff0c;又名In-place Update存储引擎&#xff08;原地更新&#xff09;&#xff0c;是openGauss 内核新增的一种存储模式。此前的版本使用的行存储引擎是Append Update&#xff08;追加更新&#xff09;模式。相比于Append…

在网络攻击之前、期间和之后应采取的步骤

在当今复杂的威胁形势下&#xff0c;网络攻击是不可避免的。 恶意行为者变得越来越复杂&#xff0c;出于经济动机的攻击变得越来越普遍&#xff0c;并且每天都会发现新的恶意软件系列。 这使得对于各种规模和跨行业的组织来说&#xff0c;制定适当的攻击计划变得更加重要。 …

【Linux】进程间通信 -- 管道

对于进程间通信的理解 首先&#xff0c;进程间通信的本质是&#xff0c;让不同的进程看到同一份资源&#xff08;这份资源不能隶属于任何一个进程&#xff0c;即应该是共享的&#xff09;。而进程间通信的目的是为了实现多进程之间的协同。 但由于进程运行具有独立性&#xff…

密码加密解密之路

1.背景 做数据采集&#xff0c;客户需要把他们那边的数据库连接信息存到我们系统里&#xff0c;那我们系统就要尽可能的保证这部分数据安全&#xff0c;不被盗。 2.我的思路 1.需要加密的地方有两处&#xff0c;一个是新增的时候前端传给后端的时候&#xff0c;一个是存到数…

异步爬取+多线程+redis构建一个运转丝滑且免费http-ip代理池 (三)

内容提要: 如果说,爬取网页数据的时候,我们使用了异步,那么将数据放入redis里面,其实也需要进行异步;当然,如果使用多线程或者redis线程池技术也是可以的,但那会造成冗余; 因此,在测试完多线程redis搭配异步爬虫的时候,我发现效率直接在redis这里被无限拉低下来! 因此: 最终的r…

从0开始学习JavaScript--JavaScript中的集合类

JavaScript中的集合类是处理数据的关键&#xff0c;涵盖了数组、Set、Map等多种数据结构。本文将深入研究这些集合类的创建、操作&#xff0c;以及实际应用场景&#xff0c;并通过丰富的示例代码&#xff0c;帮助大家更全面地了解和应用这些概念。 数组&#xff08;Array&…