深度学习——过拟合和Dropout

基本概念

什么是过拟合?

过拟合(Overfitting)是机器学习和深度学习中常见的问题之一,它指的是模型在训练数据上表现得很好,但在未见过的新数据上表现较差的现象。
当一个模型过度地学习了训练数据的细节和噪声,而忽略了数据中的一般规律和模式时,就会发生过拟合。过拟合是由于模型过于复杂或者训练数据过少,导致模型记住了训练数据中的每个细节,从而无法泛化到新数据。

解决方法

1.增加训练数据量:通过增加更多的训练数据,使得模型能够更好地学习数据的一般规律,而不是过多地依赖于少量的数据样本。
2.简化模型:减少模型的复杂度,如减少网络的层数、减少节点数、减少参数量等,从而降低过拟合的风险。
3.使用正则化技术:如L1正则化、L2正则化等,通过在损失函数中添加正则化项,惩罚过大的权重,防止模型过度拟合训练数据。
4.使用Dropout:在训练过程中随机丢弃部分神经元,减少模型的复杂性,有助于防止过拟合。
5.交叉验证:使用交叉验证来评估模型的性能,通过不同子集的训练集和测试集来评估模型的泛化能力。

Dropout

Dropout是一种用于减少过拟合问题的正则化技术,常用于深度神经网络训练中。是一种随机丢弃(drop)神经元的方法。
在正常的神经网络中,每个神经元都会对输入进行权重计算和传递,这样每个神经元都可能贡献过多,导致网络过拟合训练数据。Dropout通过在训练过程中随机丢弃一部分神经元,即在前向传播过程中以一定的概率将某些神经元的输出置为0,这样可以强制神经网络学习到更加鲁棒的特征。

对比加Dropout层和不加Dropout层

import torch
import matplotlib.pyplot as plt# 用于复现
# torch.manual_seed(1)    # reproducible# 20个数据点
N_SAMPLES = 20
# 隐藏层的个数为300
N_HIDDEN = 300# training data
# 在-1到1之间等差取N_SAMPLES个点,然后再加维度,最终的数据变为N_SAMPLES行、1列的向量
x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
# 在均值为0、标准差为1的正态分布中采样N_SAMPLES个点的值,然后乘0.3,加上x,最后得到x对应的y值
y = x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))# test data
test_x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
test_y = test_x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))# show data
plt.scatter(x.data.numpy(), y.data.numpy(), c='magenta', s=50, alpha=0.5, label='train')
plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='cyan', s=50, alpha=0.5, label='test')
plt.legend(loc='upper left')
plt.ylim((-2.5, 2.5))
plt.show()# 快速搭建神经网络,不加dropout层
net_overfitting = torch.nn.Sequential(torch.nn.Linear(1, N_HIDDEN),torch.nn.ReLU(),torch.nn.Linear(N_HIDDEN, N_HIDDEN),torch.nn.ReLU(),torch.nn.Linear(N_HIDDEN, 1),
)# 加了dropout层的
net_dropped = torch.nn.Sequential(torch.nn.Linear(1, N_HIDDEN),torch.nn.Dropout(0.5),  # drop 50% of the neurontorch.nn.ReLU(),torch.nn.Linear(N_HIDDEN, N_HIDDEN),torch.nn.Dropout(0.5),  # drop 50% of the neurontorch.nn.ReLU(),torch.nn.Linear(N_HIDDEN, 1),
)print(net_overfitting)  # net architecture
print(net_dropped)# 使用Adam优化神经网络的参数
optimizer_ofit = torch.optim.Adam(net_overfitting.parameters(), lr=0.01)
optimizer_drop = torch.optim.Adam(net_dropped.parameters(), lr=0.01)
# 误差函数使用MSELoss
loss_func = torch.nn.MSELoss()# 开启交互式绘图
plt.ion()   # something about plotting# 训练五百步
for t in range(500):# 将x输入到不加dropout层的神经网络中,得预测值pred_ofit = net_overfitting(x)# 将x输入到加了dropout层的神经网络中,得预测值pred_drop = net_dropped(x)# 计算lossloss_ofit = loss_func(pred_ofit, y)# 计算lossloss_drop = loss_func(pred_drop, y)# 梯度清零optimizer_ofit.zero_grad()optimizer_drop.zero_grad()# 误差反向传播loss_ofit.backward()loss_drop.backward()# 优化器逐步优化optimizer_ofit.step()optimizer_drop.step()# 每10步进行更新if t % 10 == 0:"""net_overfitting.eval()和net_dropped.eval()是将两个神经网络模型切换到评估模式,用于在测试数据上进行稳定的前向传播,得到准确的预测结果。"""# change to eval mode in order to fix drop out effectnet_overfitting.eval()net_dropped.eval()  # parameters for dropout differ from train mode# plottingplt.cla()test_pred_ofit = net_overfitting(test_x)test_pred_drop = net_dropped(test_x)plt.scatter(x.data.numpy(), y.data.numpy(), c='magenta', s=50, alpha=0.3, label='train')plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='cyan', s=50, alpha=0.3, label='test')plt.plot(test_x.data.numpy(), test_pred_ofit.data.numpy(), 'r-', lw=3, label='overfitting')plt.plot(test_x.data.numpy(), test_pred_drop.data.numpy(), 'b--', lw=3, label='dropout(50%)')plt.text(0, -1.2, 'overfitting loss=%.4f' % loss_func(test_pred_ofit, test_y).data.numpy(), fontdict={'size': 20, 'color':  'red'})plt.text(0, -1.5, 'dropout loss=%.4f' % loss_func(test_pred_drop, test_y).data.numpy(), fontdict={'size': 20, 'color': 'blue'})plt.legend(loc='upper left')plt.ylim((-2.5, 2.5))plt.pause(0.1)# change back to train mode"""在训练模式下,神经网络中的Dropout层将会生效,即在前向传播过程中会随机丢弃一部分神经元。这是为了在训练阶段增加模型的鲁棒性,避免过拟合。"""net_overfitting.train()net_dropped.train()# 关闭交互模式
plt.ioff()
plt.show()

运行效果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/6486.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【多模态】17、CORA | 将 CLIP 使用到开集目标检测

文章目录 一、背景二、方法2.1 总体结构2.2 region prompting2.3 anchor pre-matching 三、效果 论文:CORA: Adapting CLIP for Open-Vocabulary Detection with Region Prompting and Anchor Pre-Matching 代码:https://github.com/tgxs002/CORA 出处…

Qt/C++音视频开发48-推流到rtsp服务器

一、前言 之前已经打通了rtmp的推流,理论上按照同样的代码,只要将rtmp推流地址换成rtsp推流地址,然后格式将flv换成rtsp就行,无奈直接遇到协议不支持的错误提示,网上说要换成rtp,换了也没用,而…

Linux 学习记录54(ARM篇)

Linux 学习记录54(ARM篇) 本文目录 Linux 学习记录54(ARM篇)一、框图分析1. 芯片手册内部框图2. 操作GPIO过程 二、通过汇编完成GPIO操作1. 常用的汇编指令2. GPIO初始化流程3. 查找相关寄存器(1. RCC寄存器(2. GPIO寄存器>1. 模式配置寄存器>2. 输出模式配置寄存器>3…

Jenkins常用管理功能配置 - 插件管理

Jenkins插件介绍 Jenkins是一个流行的开源持续集成/持续交付(CI/CD)工具,它有大量的插件来扩展其功能。这些插件可以用于构建、测试、部署和监控软件项目。下面是一些常用的Jenkins插件及其简单介绍和使用方法: 1. Git插件:允许Jenkins从Gi…

vue2如何将页面生成 pdf 导出 html2Canvas + jspdf

1.引入两个依赖 npm i html2canvas npm i jspdf 2.在utils文件夹下新建html2pdf.js文件 import html2canvas from html2canvas; import jsPDF from jspdf export const htmlToPDF async (htmlId, title "报表", bgColor "#fff") > { let pdfDom do…

【LeetCode每日一题合集】2023.7.17-2023.7.23(离线算法 环形子数组的最大和 接雨水)

文章目录 415. 字符串相加(高精度计算、大数运算)1851. 包含每个查询的最小区间⭐⭐⭐⭐⭐解法1——按区间长度排序 离线询问 并查集解法2——离线算法 优先队列 874. 模拟行走机器人(哈希表 方向数组)918. 环形子数组的最大和…

sentinel深入讲解流量控制/熔断降级

文章目录 sentinelsentinel介绍重要的核心概念引入依赖限流的规则熔断规则yaml 项目配置使用注解 SentinelResource讲解类的静态方法 sentinel sentinel介绍 随着微服务的流行,服务和服务之间的稳定性变得越来越重要。Sentinel 是面向分布式、多语言异构化服务架构…

【深度学习之YOLO8】环境部署

目录 一、确定版本CUDA toolkit、cuDNN版本Python、PyTorch版本 二、安装Python下载环境变量验证安装 三、安装Anaconda安装环境变量验证安装创建conda虚拟环境常用命令 四、安装CUDA toolkit下载环境变量验证安装 五、配置cuDNN下载 六、安装PyTorch(torchtorchversiontorchau…

华为、阿里巴巴、字节跳动 100+ Python 面试问题总结(五)

系列文章目录 个人简介:机电专业在读研究生,CSDN内容合伙人,博主个人首页 Python面试专栏:《Python面试》此专栏面向准备面试的2024届毕业生。欢迎阅读,一起进步!🌟🌟🌟 …

RUST腐蚀基因种植

RUST腐蚀基因种植 试验地址:www.xiaocao.cloud RUST基因: RUST基因计算器,腐蚀基因计算器,前后端分离架构,前端目录/resouce/ui/rust,欢迎大佬评价,

算法笔记(java)——回溯篇

回溯算法解决问题最有规律性,借用一下卡哥的图: 只要遇到上述问题就可以考虑使用回溯,回溯法的效率并不高,是一种暴力解法,其代码是嵌套在for循环中的递归,用来解决暴力算法解决不了的问题,即…

Tensorflow无人车使用移动端的SSD(单发多框检测)来识别物体及Graph的认识

环境是树莓派3B,当然这里安装tensorflow并不是一定要在树莓派环境,只需要是ARM架构就行,也就是目前市场上绝大部分的嵌入式系统都是用这套精简指令集。 在电脑端的检测,有兴趣的可以查阅SSD(Single Shot MultiBox Detector)系列&a…

19 QListWidget控件

Tips: 对于列表式数据可以使用QStringList进行左移一块输入。 代码: //listWidget使用 // QListWidgetItem * item new QListWidgetItem("锄禾日当午"); // QListWidgetItem * item2 new QListWidgetItem("汗滴禾下土"); // ui->…

十、正则表达式详解:掌握强大的文本处理工具(二)

文章目录 🍀多字符匹配🍀匹配规则的代替🍀特殊的匹配🍀特殊的匹配plus🍀总结 🍀多字符匹配 星号(*):匹配0个或者多个字符 import retext 111-222-333 result re.matc…

苹果的Apple GPT要来了?

据外媒消息,苹果正在内部开发类 ChatGPT 的产品,与微软、OpenAI、谷歌、Meta 等科技巨头在生成式 AI 赛道展开竞争。该消息使得苹果股价上涨了 2%。据苹果工程师透露,苹果在内部构建了代号为“Ajax”的大语言模型开发框架,并构建了…

Unity自定义后处理——Bloom效果

大家好,我是阿赵。   继续介绍屏幕后处理效果,这一期讲一下Bloom效果。 一、Bloom效果介绍 还是拿这个模型作为背景。 Bloom效果,就是一种全屏泛光的效果,让模型和特效有一种真的在发光的感觉。 根据参数不一样,可…

Packet Tracer – 实施静态 NAT 和动态 NAT

Packet Tracer – 实施静态 NAT 和动态 NAT 拓扑图 目标 第 1 部分:利用 PAT 配置动态 NAT 第 2 部分:配置静态 NAT 第 3 部分:验证 NAT 实施 第 1 部分: 利用 PAT 配置动态 NAT 步骤 1: 配置允许用于 NAT …

【基于CentOS 7 的iscsi服务】

目录 一、概述 1.简述 2.作用 3. iscsi 4.相关名称 二、使用步骤 - 构建iscsi服务 1.使用targetcli工具进入到iscsi服务器端管理界面 2.实现步骤 2.1 服务器端 2.2 客户端 2.2.1 安装软件 2.2.2 在认证文件中生成iqn编号 2.2.3 开启客户端服务 2.2.4 查找可用的i…

AJAX-day03-AJAX进阶

(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,请留下您的足迹) 目录 同步代码和异步代码 回调函数地狱 Promise - 链式调用 Promise 链式应用 async函数和await async函…

Stable Diffusion入门笔记(自用)

学习视频:20分钟搞懂Prompt与参数设置,你的AI绘画“咒语”学明白了吗? | 零基础入门Stable Diffusion保姆级新手教程 | Prompt关键词教学_哔哩哔哩_bilibili 1.图片提示词模板 2.权重(提示词) 无数字 (flower)//花的…