【PyTorch】 暂退法(dropout)

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 主要代码
      • 2.2.2. 完整代码
      • 2.2.3. 输出结果

1. 理论介绍

  • 线性模型泛化的可靠性是有代价的,因为线性模型没有考虑到特征之间的交互作用,由此模型灵活性受限。
  • 泛化性和灵活性之间的基本权衡被描述为偏差-方差权衡
    • 线性模型有很高的偏差,因此它们只能表示一小类函数,但其方差很低,因此它们在不同的随机数据样本上可以得出相似的结果。
    • 神经网络不局限于单独查看每个特征,而是学习特征之间的交互,但即使我们有比特征多得多的样本,深度神经网络也有可能过拟合。
  • 经典泛化理论认为,为了缩小训练和测试性能之间的差距,应该以简单的模型为目标。 简单性以较小维度的形式展现,简单性的另一个角度是平滑性,即函数不应该对其输入的微小变化敏感
  • 暂退法在前向传播过程中,计算每一内部层的同时注入噪声。 因为当训练一个有多层的深层网络时,注入噪声只会在输入-输出映射上增强平滑性。从表面上看是在训练过程中丢弃(drop out)一些神经元。 在整个训练过程的每一次迭代中,标准暂退法包括在计算下一层之前将当前层中的一些节点置零。
  • 神经网络过拟合与每一层都依赖于前一层激活值相关,这种情况称为共适应性,而暂退法会破坏共适应性。
  • 可以以一种无偏向的方式注入噪声,在固定住其他层时,每一层的期望值等于没有噪音时的值。
  • 标准暂退正则化通过按未丢弃的节点的分数进行规范化来消除每一层的偏差,换言之,每个中间活性值 h h h以暂退概率 p p p由随机变量 h ′ h' h替换,即:
    h ′ = { 0 概率为  p h 1 − p 其他情况 \begin{aligned} h' = \begin{cases} 0 & \text{ 概率为 } p \\ \frac{h}{1-p} & \text{ 其他情况} \end{cases} \end{aligned} h={01ph 概率为 p 其他情况
  • 通常,我们在测试时不用暂退法。 给定一个训练好的模型和一个新的样本,我们不会丢弃任何节点,因此不需要标准化。 然而也有一些例外:一些研究人员在测试时使用暂退法, 用于估计神经网络预测的不确定性: 如果通过许多不同的暂退法遮盖后得到的预测结果都是一致的,那么我们可以说网络发挥更稳定。
  • 我们可以将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率,常见的技巧是在靠近输入层的地方设置较低的暂退概率。

2. 实例解析

2.1. 实例描述

使用具有两个隐藏层的多层感知机和暂退法,拟合Fashion-MNIST数据集。

2.2. 代码实现

2.2.1. 主要代码

net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(0.2),	# 暂退概率为0.2nn.Linear(256, 256),nn.ReLU(),nn.Dropout(0.5),	# 暂退概率为0.5nn.Linear(256, 10)
).to(device)

2.2.2. 完整代码

import os
from tensorboardX import SummaryWriter
from rich.progress import track
from torchvision.transforms import Compose, ToTensor
from torchvision.datasets import FashionMNIST
import torch
from torch.utils.data import DataLoader
from torch import nn, optimdef load_dataset():"""加载数据集"""root = "./dataset"transform = Compose([ToTensor()])mnist_train = FashionMNIST(root, True, transform, download=True)mnist_test = FashionMNIST(root, False, transform, download=True)dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers,)dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,num_workers=num_workers,)return dataloader_train, dataloader_testif __name__ == '__main__':# 全局参数设置num_epochs = 10batch_size = 256num_workers = 3device = torch.device('cuda:0')lr = 0.5# 创建记录器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir=log_dir())# 数据集配置dataloader_train, dataloader_test = load_dataset()# 定义模型net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(0.2),nn.Linear(256, 256),nn.ReLU(),nn.Dropout(0.5),nn.Linear(256, 10)).to(device)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)criterion = nn.CrossEntropyLoss(reduction='none')optimizer = optim.SGD(net.parameters(), lr=lr)# 训练循环for epoch in track(range(num_epochs), description='dropout'):for X, y in dataloader_train:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()with torch.no_grad():train_loss, train_acc, num_samples = 0.0, 0.0, 0for X, y in dataloader_train:X, y = X.to(device), y.to(device)y_hat = net(X)loss = criterion(y_hat, y)train_loss += loss.sum()train_acc += (y_hat.argmax(dim=1) == y).sum()num_samples += y.numel()train_loss /= num_samplestrain_acc /= num_samplestest_acc, num_samples = 0.0, 0for X, y in dataloader_test:X, y = X.to(device), y.to(device)y_hat = net(X)test_acc += (y_hat.argmax(dim=1) == y).sum()num_samples += y.numel()test_acc /= num_sampleswriter.add_scalars('metrics', {'train_loss': train_loss,'train_acc': train_acc,'test_acc': test_acc}, epoch)writer.close()

2.2.3. 输出结果

dropout

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206749.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker构建自定义镜像

创建一个docker-demo的文件夹,放入需要构建的文件 主要是配置Dockerfile文件 第一种配置方法 # 指定基础镜像 FROM ubuntu:16.04 # 配置环境变量,JDK的安装目录 ENV JAVA_DIR/usr/local# 拷贝jdk和java项目的包 COPY ./jdk8.tar.gz $JAVA_DIR/ COPY ./docker-demo…

Java基础50题: 21.实现一个方法printArray, 以数组为参数,循环访问数组中的每个元素,打印每个元素的值.

概述 实现一个方法printArray, 以数组为参数,循环访问数组中的每个元素,打印每个元素的值. 代码 public static void printArray(int[] array) {for (int i 0; i < array.length; i) {System.out.println(array[i] " ");}System.out.println();}public static…

【日常总结】mybatis-plus WHERE BINARY 中文查不出来

目录 一、场景 二、问题 三、原因 四、解决方案 五、拓展&#xff08;全表全字段修改字符集一键更改&#xff09; 准备工作&#xff1a;做好整个库备份 1. 全表一键修改 Stage 1&#xff1a;运行如下查询 Stage 2&#xff1a;复制sql语句 Stage 3&#xff1a;执行即可…

100. 相同的树(Java)

目录 解法&#xff1a; 官方解法&#xff1a; 方法一&#xff1a;深度优先搜索 复杂度分析 时间复杂度&#xff1a; 空间复杂度&#xff1a; 方法二&#xff1a;广度优先搜索 复杂度分析 时间复杂度&#xff1a; 空间复杂度&#xff1a; 给你两棵二叉树的根节点 p 和…

L1-028:判断素数

题目描述 本题的目标很简单&#xff0c;就是判断一个给定的正整数是否素数。 输入格式&#xff1a; 输入在第一行给出一个正整数N&#xff08;≤ 10&#xff09;&#xff0c;随后N行&#xff0c;每行给出一个小于231的需要判断的正整数。 输出格式&#xff1a; 对每个需要判断的…

Kotlin(十五) 高阶函数详解

高阶函数的定义 高阶函数和Lambda的关系是密不可分的。在之前的文章中&#xff0c;我们熟悉了Lambda编程的基础知识&#xff0c;并且掌握了一些与集合相关的函数式API的用法&#xff0c;如map、filter函数等。另外&#xff0c;我们也了解了Kotlin的标准函数&#xff0c;如run、…

vuepress-----22、其他评论方案

vuepress 支持评论 本文讲述 vuepress 站点如何集成评论系统&#xff0c;选型是 valineleancloud, 支持匿名评论&#xff0c;缺点是数据没有存储在自己手里。市面上也有其他的方案, 如 gitalk,vssue 等, 但需要用户登录 github 才能发表评论, 但 github 经常无法连接,导致体验…

[wp]“古剑山”第一届全国大学生网络攻防大赛 Web部分wp

“古剑山”第一届全国大学生网络攻防大赛 群友说是原题杯 哈哈哈哈 我也不懂 我比赛打的少 Web Web | unse 源码&#xff1a; <?phpinclude("./test.php");if(isset($_GET[fun])){if(justafun($_GET[fun])){include($_GET[fun]);}}else{unserialize($_GET[…

使用cmake构建的工程的编译方法

1、克隆项目工程 2、进入到工程目录 3、执行 mkdir build && cd build 4、执行 cmake .. 5、执行 make 执行以上步骤即可完成对cmake编写的工程进行编译 &#xff0c;后面只需执行你的编译结果即可 $ git clone 你想要克隆的代码路径 $ cd 代码文件夹 $ mkdir bu…

vmware安装centos7总结

vmware安装centos7总结 文章目录 vmware安装centos7总结一、配置网络&#xff08;桥接模式&#xff09;二、配置yum源&#xff08;连网配置&#xff09;三、可视化界面四、安装Docker五、安装DockerUI 一、配置网络&#xff08;桥接模式&#xff09; 网络连接模式选择桥接模式…

Ubuntu安装nvidia GPU显卡驱动教程

Ubuntu安装nvidia显卡驱动 1.安装前安装必要的依赖 sudo apt-get install build-essential sudo apt-get install g sudo apt-get install make2.到官网下载对应驱动 https://www.nvidia.cn/Download/index.aspx?langcn 3.卸载原有驱动 sudo apt-get remove --purge nvidi…

深度学习:注意力机制(Attention Mechanism)

1 注意力机制概述 1.1 定义 注意力机制&#xff08;Attention Mechanism&#xff09;是深度学习领域中的一种重要技术&#xff0c;特别是在序列模型如自然语言处理&#xff08;NLP&#xff09;和计算机视觉中。它使模型能够聚焦于输入数据的重要部分&#xff0c;从而提高整体…

孩子都能学会的FPGA:第二十五课——用FPGA实现频率计

&#xff08;原创声明&#xff1a;该文是作者的原创&#xff0c;面向对象是FPGA入门者&#xff0c;后续会有进阶的高级教程。宗旨是让每个想做FPGA的人轻松入门&#xff0c;作者不光让大家知其然&#xff0c;还要让大家知其所以然&#xff01;每个工程作者都搭建了全自动化的仿…

基于SpringBoot+maven+Mybatis+html慢性病报销系统(源码+数据库)

一、项目简介 本项目是一套基于SpringBootmavenMybatishtml慢性病报销系统&#xff0c;主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目可以直接作为bishe使用。 项目都经过严格调试&a…

二十一章(网络通信)

计算机网络实现了多台计算机间的互联&#xff0c;使得它们彼此之间能够进行数据交流。网络应用程序就是在已连接的不同计算机上运行的程序&#xff0c;这些程序借助于网络协议&#xff0c;相互之间可以交换数据。编写网络应用程序前&#xff0c;首先必须明确所要使用的网络协议…

C++_命名空间(namespace)

目录 1、namespace的重要性 2、 namespace的定义及作用 2.1 作用域限定符 3、命名空间域与全局域的关系 4、命名空间的嵌套 5、展开命名空间的方法 5.1 特定展开 5.1 部分展开 5.2 全部展开 结语&#xff1a; 前言&#xff1a; C作为c语言的“升级版”&#xff0c;其在…

异常检测 | MATLAB实现BiLSTM(双向长短期记忆神经网络)数据异常检测

异常检测 | MATLAB实现BiLSTM(双向长短期记忆神经网络)数据异常检测 目录 异常检测 | MATLAB实现BiLSTM(双向长短期记忆神经网络)数据异常检测效果一览基本介绍模型准备模型设计参考资料效果一览 基本介绍 训练一个双向 LSTM 自动编码器来检测机器是否正常工作。 自动编码器接受…

CleanMyMac X2024最新版本软件实用性测评

信大多数MAC用户都较为了解&#xff0c;Mac虽然有着许多亮点的性能&#xff0c;但是让用户叫苦不迭的还其硬盘空间小的特色&#xff0c;至于很多人因为文件堆积以及软件缓存等&#xff0c;造成系统空间内存不够使用的情况。于是清理工具就成为了大多数MAC用户使用频率较高的实用…

二十一章网络通信

计算机网络实现了多台计算机间的互联&#xff0c;使得它们彼此之间能够进行数据交流。网络应用程序就是在已连接的不同计算机上运行的程序&#xff0c;这些程序借助于网络协议&#xff0c;相互之间可以交换数据。编写网络应用程序前&#xff0c;首先必须明确所要使用的网络协议…

数据采集工具的大全【都是免费值得收藏】

数据是推动业务成功的关键之一。为了获取准确、全面的信息&#xff0c;数据采集成为了许多企业和个人的必备工作。本文将专注于数据采集工具&#xff0c;探讨其在全网和指定网站采集方面的优势&#xff0c;为大家提供对比分析&#xff0c;以帮助大家找到最适合的数据采集利器。…