动手学深度学习(Pytorch版)代码实践 -深度学习基础-11暂退法Dropout

11暂退法Dropout

#Dropout 是一种正则化技术,主要用于防止过拟合,
#通过在训练过程中随机丢弃神经元来提高模型的泛化能力。
import torch
from torch import nn
from d2l import torch as d2l
import liliPytorch as lpdef dropout_layer(X, dropout):assert 0 <= dropout <= 1#该情况下,所有元素都被丢弃if dropout == 1:return torch.zeros_like(X)#该情况下所有元素都被保留if dropout == 0:return Xmask = (torch.rand(X.shape) > dropout).float()"""生成一个与 X 形状相同的掩码张量 mask。torch.rand(X.shape) 生成一个元素值在 [0, 1) 范围内的均匀分布的随机张量。mask 中的每个元素与 dropout 进行比较,若大于 dropout 则为 1(保留),否则为 0(丢弃)。最后将布尔值转换为浮点数。"""#将 mask 和 X 元素逐位相乘,以应用掩码效果,即丢弃部分神经元。#为了保持输出的期望值不变,结果除以 (1.0 - dropout) 进行缩放补偿。return mask * X / (1.0 - dropout)#测试dropout_layer函数
X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))
"""
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],[ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  0.,  4.,  6.,  0.,  0., 12.,  0.],[16.,  0., 20.,  0.,  0., 26., 28.,  0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]])"""
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256dropout1, dropout2 = 0.2, 0.5
#dropout1:第一个隐藏层之后的 dropout 比例为 20%。
#dropout2:第二个隐藏层之后的 dropout 比例为 50%。class Net(nn.Module):def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training = True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)self.lin3 = nn.Linear(num_hiddens2, num_outputs)self.relu = nn.ReLU()def forward(self, X):#将输入张量X重塑为 (batch_size, num_inputs) 的形状。H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 只有在训练模型时才使用dropoutif self.training == True:# 在第一个全连接层之后添加一个dropout层H1 = dropout_layer(H1, dropout1)H2 = self.relu(self.lin2(H1))if self.training == True:# 在第二个全连接层之后添加一个dropout层H2 = dropout_layer(H2, dropout2)out = self.lin3(H2)return outnet = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
# lp.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
# d2l.plt.show() #Dropout 简洁实现
net = nn.Sequential(nn.Flatten(),#它会将多维的输入张量展平成一维。nn.Linear(784,256),nn.ReLU(),#在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256,256),nn.ReLU(),#在第二个全连接层后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256,10)
)#函数接受一个参数 m,通常是一个神经网络模块(例如,线性层,卷积层等)
def init_weights(m):
#这行代码检查传入的模块 m 是否是 nn.Linear 类型,即线性层(全连接层)if type(m) == nn.Linear:nn.init.normal_(m.weight,std=0.01)
#m.weight 是线性层的权重矩阵。
#std=0.01 指定了初始化权重的标准差为 0.01,表示权重将从均值为0,标准差为0.01的正态分布中随机采样。#model.apply(init_weights) 会遍历模型的所有模块,并对每个模块调用 init_weights 函数。
#如果模块是 nn.Linear 类型,则初始化它的权重。
net.apply(init_weights)trainer = torch.optim.SGD(net.parameters(), lr = lr)
lp.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show() 

运行结果:

epoch: 1,train_loss: 1.1632388224283854,train_acc: 0.55005,test_acc: 0.7137
<Figure size 350x250 with 1 Axes>
epoch: 2,train_loss: 0.5765969015757243,train_acc: 0.7862833333333333,test_acc: 0.7971
<Figure size 350x250 with 1 Axes>
epoch: 3,train_loss: 0.5013401063283285,train_acc: 0.8177166666666666,test_acc: 0.6976
<Figure size 350x250 with 1 Axes>
epoch: 4,train_loss: 0.46441060066223144,train_acc: 0.8299666666666666,test_acc: 0.837
<Figure size 350x250 with 1 Axes>
epoch: 5,train_loss: 0.4177045190811157,train_acc: 0.8482,test_acc: 0.8348
<Figure size 350x250 with 1 Axes>
epoch: 6,train_loss: 0.4039476199467977,train_acc: 0.8522,test_acc: 0.8376
<Figure size 350x250 with 1 Axes>
epoch: 7,train_loss: 0.38559712861378986,train_acc: 0.8593333333333333,test_acc: 0.8499
<Figure size 350x250 with 1 Axes>
epoch: 8,train_loss: 0.37514646828969317,train_acc: 0.86315,test_acc: 0.8587
<Figure size 350x250 with 1 Axes>
epoch: 9,train_loss: 0.36000535113016763,train_acc: 0.8681166666666666,test_acc: 0.853
<Figure size 350x250 with 1 Axes>
epoch: 10,train_loss: 0.3473748308181763,train_acc: 0.8719333333333333,test_acc: 0.85

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/855197.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据—“西游记“全集文本数据挖掘分析实战教程

项目背景介绍 四大名著&#xff0c;又称四大小说&#xff0c;是汉语文学中经典作品。这四部著作历久不衰&#xff0c;其中的故事、场景&#xff0c;已经深深地影响了国人的思想观念、价值取向。四部著作都有很高的艺术水平&#xff0c;细致的刻画和所蕴含的思想都为历代读者所…

0元体验苹果macOS系统,最简单的虚拟机部署macOS教程

前言 最近发现小伙伴热衷于在VMware上安装体验macOS系统&#xff0c;所以就有了今天的帖子。 正文开始 首先&#xff0c;鉴于小伙伴们热衷macOS&#xff0c;所以小白搜罗了一圈macOS系统&#xff0c;并开启了分享通道。 本次更新的系统版本是&#xff1a; macOS 10.13.6 ma…

【靶场搭建】-01- 在kali上搭建DVWA靶机

1.DVWA靶机 DVWA&#xff08;Damn Vulnerable Web Application&#xff09;是使用PHPMysql编写的web安全测试框架&#xff0c;主要用于安全人员在一个合法的环境中测试技能和工具。 2.下载DVWA 从GitHub上将DVWA的源码clone到kali上 git clone https://github.com/digininj…

温湿度采集与OLED显示

目录 一、什么是软件I2C 二、什么是硬件I2C 三、STM32CubeMX配置 1、RCC配置 2、SYS配置 3、I2C1配置 3、I2C2配置 4、USART1配置 5、TIM1配置 6、时钟树配置 7、工程配置 四、设备链接 1、OLED连接 2、串口连接 3、温湿度传感器连接 五、每隔2秒钟采集一次温湿…

第二十三篇——香农第二定律(二):到底要不要扁平化管理?

目录 一、背景介绍二、思路&方案三、过程1.思维导图2.文章中经典的句子理解3.学习之后对于投资市场的理解4.通过这篇文章结合我知道的东西我能想到什么&#xff1f; 四、总结五、升华 一、背景介绍 对于企业的理解&#xff0c;扁平化的管理&#xff0c;如果从香农第二定律…

Qt 实战(5)布局管理器 | 5.2、深入解析Qt布局管理器

文章目录 一、深入解析Qt布局管理器1、为什么要使用布局管理器&#xff1f;2、布局管理器类型3、布局管理器用法详解3.1、QBoxLayout&#xff08;垂直与水平布局&#xff09;3.2、QGridLayout&#xff08;网格布局&#xff09;3.3、QFormLayout&#xff08;表单布局&#xff09…

特斯拉、路特斯、中国一汽、毕博、博世等企业将出席中国汽车供应链降碳和可持续国际峰会

由ECV International 举办的2024中国汽车供应链脱碳与可持续国际峰会将于2024年9月23-24日在上海召开。 在本次峰会上&#xff0c;来自全球各地的行业领袖、政策制定者、研究人员和利益相关者将齐聚一堂&#xff0c;商讨对于减少碳排放和促进整个汽车供应链可持续实践至关重要…

教学资源共享平台的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;管理员管理&#xff0c;老师管理&#xff0c;用户管理&#xff0c;成绩管理&#xff0c;教学资源管理&#xff0c;作业管理 老师账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用…

什么是拷贝?我:Ctrl + C ...

前言 当谈及拷贝&#xff0c;你的第一印象会不会和我一样&#xff0c;ctrl c ctrl v ... &#xff1b;虽然效果和拷贝是一样的&#xff0c;但是你知道拷贝的原理以及它的实现方法吗&#xff1f;今天就让我们一起探究一下拷贝中深藏的知识点吧。 拷贝 首先来看下面一段代码…

MySQL数据库回顾(1)

数据库相关概念 关系型数据库 概念: 建立在关系模型基础上&#xff0c;由多张相互连接的二维表组成的数据库。 特点&#xff1a; 1.使用表存储数据&#xff0c;格式统一&#xff0c;便于维护 2.使用SQL语言操作&#xff0c;标准统一&#xff0c;使用方便 SOL SQL通用语法 …

粒子群算法PSO优化BP神经网络预测MATLAB代码实现(PSO-BP预测)

本文以MATLAB自带的脂肪数据集为例&#xff0c;数据集为EXCEL格式&#xff0c;接下来介绍粒子群算法优化BP神经网络预测的MATLAB代码步骤&#xff0c;主要流程包括1. 读取数据 2.划分训练集和测试集 3.归一化 4.确定BP神经网络的隐含层最优节点数量 5. 使用粒子群算法优化BP的神…

强得离谱,AI音乐的 Stable Diffusion: MusicGen

节前&#xff0c;我们星球组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、参加社招和校招面试的同学。 针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 合集&#x…

打造完美Mac多屏视界,BetterDisplay Pro一键掌控!

BetterDisplay Pro for Mac是一款专为Mac用户打造的显示器管理与优化软件&#xff0c;旨在为用户带来卓越的视觉体验和工作效率。它凭借强大的功能和简洁易用的界面&#xff0c;成为了Mac用户优化显示器设置的得力助手。 一、全方位管理与优化 BetterDisplay Pro for Mac支持…

探索Python的多媒体解决方案:ffmpy库

文章目录 探索Python的多媒体解决方案&#xff1a;ffmpy库一、背景&#xff1a;数字化时代的多媒体处理二、ffmpy&#xff1a;Python与ffmpeg的桥梁三、安装ffmpy&#xff1a;轻松几步四、ffmpy的五项基本功能1. 转换视频格式2. 调整视频质量3. 音频转换4. 视频截图5. 视频合并…

java架构设计-COLA

参考&#xff1a;https://github.com/alibaba/COLA 架构 要素&#xff1a;组成架构的重要元素 结构&#xff1a;要素直接的关系 意义&#xff1a;定义良好的结构&#xff0c;治理应用复杂度&#xff0c;降低系统熵值&#xff0c;改善混乱状态 创建COLA应用&#xff1a; mvn …

Git的3个主要区域

一般来说&#xff0c;日常使用只要记住下图6个命令&#xff0c;就可以了。但是熟练使用&#xff0c;恐怕要记住60&#xff5e;100个命令。 下面是我整理的常用 Git 命令清单。几个专用名词的译名如下。 Workspace&#xff1a;工作区 Index / Stage&#xff1a;暂存区 Reposito…

你的企业真的适合做私域吗?

现在&#xff0c;都在提倡企业做私域&#xff0c;可是所有的企业都适合做私域吗&#xff1f;看看市场上成功的案例&#xff0c;显然&#xff0c;并不是所有企业都适合做私域&#xff0c;所以&#xff0c;做私域之前&#xff0c;企业也应该充分的分析&#xff0c;自己的优势是什…

spark常见问题

写文章只是为了学习总结或者工作内容备忘&#xff0c;不保证及时性和准确性&#xff0c;看到的权当个参考哈&#xff01; 1. 执行Broadcast大表时&#xff0c;等待超时异常&#xff08;awaitResult&#xff09; 现象&#xff1a;org.apache.spark.SparkException: Exception…

玩转OurBMC第八期:OpenBMC webui之通信交互

栏目介绍&#xff1a;“玩转OurBMC”是OurBMC社区开创的知识分享类栏目&#xff0c;主要聚焦于社区和BMC全栈技术相关基础知识的分享&#xff0c;全方位涵盖了从理论原理到实践操作的知识传递。OurBMC社区将通过“玩转OurBMC”栏目&#xff0c;帮助开发者们深入了解到社区文化、…

【网络】序列化和反序列化

一、序列化和反序列化 序列化和反序列化是计算机中用于数据存储和传输的重要概念。 1.序列化 &#xff08;Serialization&#xff09; 是将数据结构或对象转换成一种可存储或可传输格式的过程。在序列化后&#xff0c;数据可以被写入文件、发送到网络或存储在数据库中&…