【深度学习笔记】 3_13 丢弃法

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

3.13 丢弃法

除了前一节介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)[1] 来应对过拟合问题。丢弃法有一些不同的变体。本节中提到的丢弃法特指倒置丢弃法(inverted dropout)。

3.13.1 方法

回忆一下,3.8节(多层感知机)的图3.3描述了一个单隐藏层的多层感知机。其中输入个数为4,隐藏单元个数为5,且隐藏单元 h i h_i hi i = 1 , … , 5 i=1, \ldots, 5 i=1,,5)的计算表达式为

h i = ϕ ( x 1 w 1 i + x 2 w 2 i + x 3 w 3 i + x 4 w 4 i + b i ) h_i = \phi\left(x_1 w_{1i} + x_2 w_{2i} + x_3 w_{3i} + x_4 w_{4i} + b_i\right) hi=ϕ(x1w1i+x2w2i+x3w3i+x4w4i+bi)

这里 ϕ \phi ϕ是激活函数, x 1 , … , x 4 x_1, \ldots, x_4 x1,,x4是输入,隐藏单元 i i i的权重参数为 w 1 i , … , w 4 i w_{1i}, \ldots, w_{4i} w1i,,w4i,偏差参数为 b i b_i bi。当对该隐藏层使用丢弃法时,该层的隐藏单元将有一定概率被丢弃掉。设丢弃概率为 p p p,那么有 p p p的概率 h i h_i hi会被清零,有 1 − p 1-p 1p的概率 h i h_i hi会除以 1 − p 1-p 1p做拉伸。丢弃概率是丢弃法的超参数。具体来说,设随机变量 ξ i \xi_i ξi为0和1的概率分别为 p p p 1 − p 1-p 1p。使用丢弃法时我们计算新的隐藏单元 h i ′ h_i' hi

h i ′ = ξ i 1 − p h i h_i' = \frac{\xi_i}{1-p} h_i hi=1pξihi

由于 E ( ξ i ) = 1 − p E(\xi_i) = 1-p E(ξi)=1p,因此

E ( h i ′ ) = E ( ξ i ) 1 − p h i = h i E(h_i') = \frac{E(\xi_i)}{1-p}h_i = h_i E(hi)=1pE(ξi)hi=hi

丢弃法不改变其输入的期望值。让我们对图3.3中的隐藏层使用丢弃法,一种可能的结果如图3.5所示,其中 h 2 h_2 h2 h 5 h_5 h5被清零。这时输出值的计算不再依赖 h 2 h_2 h2 h 5 h_5 h5,在反向传播时,与这两个隐藏单元相关的权重的梯度均为0。由于在训练中隐藏层神经元的丢弃是随机的,即 h 1 , … , h 5 h_1, \ldots, h_5 h1,,h5都有可能被清零,输出层的计算无法过度依赖 h 1 , … , h 5 h_1, \ldots, h_5 h1,,h5中的任一个,从而在训练模型时起到正则化的作用,并可以用来应对过拟合。在测试模型时,我们为了拿到更加确定性的结果,一般不使用丢弃法。

在这里插入图片描述

图3.5 隐藏层使用了丢弃法的多层感知机

3.13.2 从零开始实现

根据丢弃法的定义,我们可以很容易地实现它。下面的dropout函数将以drop_prob的概率丢弃X中的元素。

%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2ldef dropout(X, drop_prob):X = X.float()assert 0 <= drop_prob <= 1  #如果 drop_prob 小于0或大于1,这个断言将失败,程序将停止执行并显示一个错误信息。keep_prob = 1 - drop_prob# 这种情况下把全部元素都丢弃if keep_prob == 0:return torch.zeros_like(X)mask = (torch.rand(X.shape) < keep_prob).float()return mask * X / keep_prob

我们运行几个例子来测试一下dropout函数。其中丢弃概率分别为0、0.5和1。

X = torch.arange(16).view(2, 8)
dropout(X, 0)

在这里插入图片描述

dropout(X, 0.5)

在这里插入图片描述

dropout(X, 1.0)

在这里插入图片描述

3.13.2.1 定义模型参数

实验中,我们依然使用3.6节(softmax回归的从零开始实现)中介绍的Fashion-MNIST数据集。我们将定义一个包含两个隐藏层的多层感知机,其中两个隐藏层的输出个数都是256。

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)
b1 = torch.zeros(num_hiddens1, requires_grad=True)
W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)
b2 = torch.zeros(num_hiddens2, requires_grad=True)
W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)
b3 = torch.zeros(num_outputs, requires_grad=True)params = [W1, b1, W2, b2, W3, b3]

3.13.2.2 定义模型

下面定义的模型将全连接层和激活函数ReLU串起来,并对每个激活函数的输出使用丢弃法。我们可以分别设置各个层的丢弃概率。通常的建议是把靠近输入层的丢弃概率设得小一点。在这个实验中,我们把第一个隐藏层的丢弃概率设为0.2,把第二个隐藏层的丢弃概率设为0.5。我们可以通过参数is_training来判断运行模式为训练还是测试,并只需在训练模式下使用丢弃法。

drop_prob1, drop_prob2 = 0.2, 0.5def net(X, is_training=True):X = X.view(-1, num_inputs)H1 = (torch.matmul(X, W1) + b1).relu()if is_training:  # 只在训练模型时使用丢弃法H1 = dropout(H1, drop_prob1)  # 在第一层全连接后添加丢弃层H2 = (torch.matmul(H1, W2) + b2).relu()if is_training:H2 = dropout(H2, drop_prob2)  # 在第二层全连接后添加丢弃层return torch.matmul(H2, W3) + b3

我们在对模型评估的时候不应该进行丢弃,所以我们修改一下d2lzh_pytorch中的evaluate_accuracy函数:

# 本函数已保存在d2lzh_pytorch
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutacc_sum += (net(X).argmax(dim=1) == y).float().sum().item()net.train() # 改回训练模式else: # 自定义的模型if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n

注:将上诉evaluate_accuracy写回d2lzh_pytorch后要重启一下jupyter kernel才会生效。

3.13.2.3 训练和测试模型

这部分与之前多层感知机的训练和测试类似。

num_epochs, lr, batch_size = 5, 100.0, 256
loss = torch.nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

输出:

epoch 1, loss 0.0044, train acc 0.574, test acc 0.648
epoch 2, loss 0.0023, train acc 0.786, test acc 0.786
epoch 3, loss 0.0019, train acc 0.826, test acc 0.825
epoch 4, loss 0.0017, train acc 0.839, test acc 0.831
epoch 5, loss 0.0016, train acc 0.849, test acc 0.850

注:这里的学习率设置的很大,原因同3.9.6节。

3.13.3 简洁实现

在PyTorch中,我们只需要在全连接层后添加Dropout层并指定丢弃概率。在训练模型时,Dropout层将以指定的丢弃概率随机丢弃上一层的输出元素;在测试模型时(即model.eval()后),Dropout层并不发挥作用。

net = nn.Sequential(d2l.FlattenLayer(),nn.Linear(num_inputs, num_hiddens1),nn.ReLU(),nn.Dropout(drop_prob1),nn.Linear(num_hiddens1, num_hiddens2), nn.ReLU(),nn.Dropout(drop_prob2),nn.Linear(num_hiddens2, 10))for param in net.parameters():nn.init.normal_(param, mean=0, std=0.01)

下面训练并测试模型。

optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

输出:

epoch 1, loss 0.0045, train acc 0.553, test acc 0.715
epoch 2, loss 0.0023, train acc 0.784, test acc 0.793
epoch 3, loss 0.0019, train acc 0.822, test acc 0.817
epoch 4, loss 0.0018, train acc 0.837, test acc 0.830
epoch 5, loss 0.0016, train acc 0.848, test acc 0.839

注:由于这里使用的是PyTorch的SGD而不是d2lzh_pytorch里面的sgd,所以就不存在3.9.6节那样学习率看起来很大的问题了。

小结

  • 我们可以通过使用丢弃法应对过拟合。
  • 丢弃法只在训练模型时使用。

参考文献

[1] Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., & Salakhutdinov, R. (2014). Dropout: a simple way to prevent neural networks from overfitting. JMLR


注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/704560.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里面试:最佳线程数,如何确定?

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;最近有小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、shein 希音、百度、网易的面试资格&#xff0c;遇到很多很重要的面试题&#xff1a; 如何确定系统的最佳线程数&#xff1f; 小伙伴 没有回…

机器学习深度解析:原理、应用与前景

随着人工智能的迅速发展&#xff0c;机器学习已经成为当今时代最为引人注目的技术之一。它不仅仅是一种技术或工具&#xff0c;更是一种推动社会进步、影响人类生活的重要力量。那么&#xff0c;什么是机器学习&#xff1f;它是如何工作的&#xff1f;又在哪些领域中发挥着不可…

阿里云服务器ECS u1实例性能怎么样?

阿里云服务器ECS u1实例&#xff0c;2核4G&#xff0c;5M固定带宽&#xff0c;80G ESSD Entry盘优惠价格199元一年&#xff0c;性能很不错&#xff0c;CPU采用Intel Xeon Platinum可扩展处理器&#xff0c;购买限制条件为企业客户专享&#xff0c;实名认证信息是企业用户即可&a…

介绍一下我们:久菜盒子工作室

大数据科学团队/全网可搜索的久菜盒子工作室 我们是&#xff1a;985硕博/美国全奖doctor/计算机7年产品负责人/医学大数据公司医学研究员/SCI一区2篇/Nature子刊一篇/中文二区核心一篇/都是我们 主要领域&#xff1a;医学大数据分析/经管数据分析/金融模型/统计数理基础/统计学…

编程笔记 Golang基础 028 结构体与JSON

编程笔记 Golang基础 028 结构体与JSON 一、JSON二、结构体转JSON&#xff08;序列化&#xff09;三、JSON转结构体&#xff08;反序列化&#xff09;小结 结构体与JSON之间的相互转换是现代软件开发中数据处理的基础工具&#xff0c;极大地简化了数据在不同层次、不同组件间的…

spring boot 集成科大讯飞星火认知大模型

一、安装依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/…

Educational Codeforces Round 160 (Rated for Div. 2) D. Array Collapse(笛卡尔树+DP)

原题链接&#xff1a;D. Array Collapse 题目大意&#xff1a; 给你一个长度为 n n n 的排列 p p p &#xff0c;排列的定义为 [ 1 , 2 , 3 , . . , n ] [1,2,3,..,n] [1,2,3,..,n] 中每个数都出现 恰好 一次。 你可以做 任意多次 这样的操作&#xff1a; 选出一个任意长度…

前端导出EXCEL

步骤解析 定义了一个名为 excelDown 的函数&#xff0c;它接受两个参数&#xff1a;res 和 type。res 是包含响应数据的对象&#xff0c;type 是要导出的文件类型。如果 type 未提供&#xff0c;则默认使用 Excel 文件的 MIME 类型。 export const excelDown (res, type) >…

unity导航网格无法烘培到台阶和斜坡

如图是我在b站学Unity导航网格时建的一个示例场景&#xff0c;本场景使用的为棱长1m的立方体&#xff0c;读者可以以此为参照度量其他物体大小。 可见导航网格根本无法烘焙到斜坡和台阶上&#xff0c;为解决问题我做了不少尝试&#xff0c;调整最大坡度和步高都没办法解决问题…

AI新纪元:可能的盈利之道

本文来源于Twitter大神宝玉&#xff08;dotey&#xff09;在聊 Sora 的时候&#xff0c;总结了 Sora 的价值和可能的盈利方向&#xff0c;我把这部分内容单独摘出来再整理一下。现在的生成式 AI 大家应该不陌生&#xff0c;用它总结文章、翻译、写作、画图&#xff0c;当然真正…

搭建私有Git服务器:GitLab部署详解

引言&#xff1a; 为了方便团队协作和代码管理&#xff0c;许多组织选择搭建自己的私有Git服务器。GitLab是一个集成了Git版本控制、项目管理、代码审查等功能的开源平台&#xff0c;是搭建私有Git服务器的理想选择。 目录 引言&#xff1a; 一、准备工作 在开始部署GitLab之…

Dockerfile和jar包不同目录处理

如果Dockerfile的全路径为/srm/myDockerfile/Dockerfile&#xff0c;而JAR文件位于/srm目录下&#xff0c;你可以在Dockerfile中使用相对路径引用JAR文件。以下是如何编写Dockerfile的示例&#xff1a; 假设你的项目结构如下&#xff1a; luaCopy code /srm |-- myDockerfile …

Map集合的遍历方式

遍历Map集合的几种方式 迭代器(Iterator)forlambdaStream 代码示例 package com.haimeng.Array;import java.security.Key; import java.util.HashMap; import java.util.Iterator; import java.util.Map;public class Lambda1 {public static void main(String[] args) {//…

MySQL数据库基础(十五):PyMySQL使用介绍

文章目录 PyMySQL使用介绍 一、为什么要学习PyMySQL 二、安装PyMySQL模块 三、PyMySQL的使用 1、导入 pymysql 包 2、创建连接对象 3、获取游标对象 4、pymysql完成数据的查询操作 5、pymysql完成对数据的增删改 PyMySQL使用介绍 提前安装MySQL数据库&#xff08;可以…

shell脚本介绍及基本功能

目录 1. 什么是shell 2. hello word 2.1 echo 2.2 第一个脚本 3. Bash的基本功能 3.1别名 3.2 常用快捷键 3.3 输入输出 3.4 输出重定向 3.5 多命令执行 3.6 管道符 3.7 通配符和特殊符号 1. 什么是shell Shell 是一个用 C 语言编写的程序&#xff0c;它是用户使用…

数据分析---常见处理逻辑

目录 数据清洗数据转换数据聚合数据筛选增删改查(以查为例)数据清洗 去除重复值:使用DISTINCT关键字去除重复行。//这将返回一个包含所有不重复城市的结果集 SELECT DISTINCT city FROM students;处理缺失值:使用IS NULL或IS NOT NULL判断是否为空值,并使用COALESCE或CASE…

STM32--低功耗模式详解

一、PWR简介 正常模式与睡眠模式耗电是mA级&#xff0c;停机模式与待机模式是uA级。 二、电源框图 供电区域有三处&#xff0c;分别是模拟部分供电&#xff08;VDDA&#xff09;&#xff0c;数字部分供电&#xff0c;包括VDD供电区域和1.8V供电区域&#xff0c;后备供电&…

mysql和redis双写一致性策略分析

mysql和redis双写一致性策略分析 一.什么是双写一致性 当我们更新了mysql中的数据后也可以同时保证redis中的数据同步更新&#xff1b; 数据读取的流程&#xff1a; 1.读取redis,如果value!null,直接返回&#xff1b; 2.如果redis中valuenull&#xff0c;读取mysql中数据对应的…

【程序员养生延寿系列-万人关注的养生指南】

一.程序员面临的健康问题 应该说不只程序员&#xff0c;大部分互联网从业者&#xff0c;都会遇到很多类似的健康问题&#xff0c;比如&#xff1a; 心理压力大&#xff0c;失眠长期加班久坐不动熬夜&#xff0c;甚至通宵作息不规律饮食不均衡 短期可能不会表现出来&#xff…

MMDetection3D v1.1.0安装教程

MMDetection3D v1.1.0安装 1. 系统环境2. 安装2.1 基本环境安装2.2 验证2.3 安装MinkowskiEngine和TorchSparse 3. 最终环境配置 根据 v1.1.0版本官方手册&#xff0c;测试后的安装配置&#xff0c;亲测可行 1. 系统环境 项目版本日期Ubuntu18.04.06 LTS-显卡RTX 2070-显卡驱…