深度学习(过拟合 欠拟合)

过拟合: 

深度学习模型由于其复杂性,往往容易出现过拟合的问题。以下是一些深度学习中常见的过拟合原因和解决方法:

1. 数据量不足:深度学习模型通常需要大量的数据来进行训练,如果数据量不足,模型容易过度拟合训练集。解决方法包括增加数据集的规模,或者使用数据增强技术来生成更多的数据样本。

2. 模型复杂度过高:如果深度学习模型的层数或参数过多,模型容易过度拟合训练数据。解决方法包括减少模型层数,减少模型参数数量,或者通过正则化(如L1、L2正则化)加入额外的约束限制模型的复杂度。

3. 缺乏正则化:正则化是一种常用的降低模型过拟合的方法,可以通过在损失函数中加入正则化项来约束模型的复杂度。常见的正则化方法包括L1/L2正则化、Dropout等。

4. 数据标签错误或不平衡:数据集中的标签错误或不均衡会影响模型的学习和泛化能力,导致过拟合。解决方法包括仔细检查数据集并修正标签错误,或者采用数据平衡技术,如欠采样、过采样等。

5. 训练集与测试集分布不一致:如果训练集与测试集的分布不一致,模型将无法很好地泛化到新的数据上。解决方法包括确保训练集和测试集的数据来源和分布相似,或者使用领域适应技术来使模型更好地适应新的数据。

6. 提前停止:通过监控模型在验证集上的性能,当模型在验证集上的性能开始下降时,及时停止训练,可以避免过拟合。

综上所述,深度学习中的过拟合问题可以通过增加数据量、降低模型复杂度、添加正则化、修正数据标签、平衡数据分布、提前停止等方法来解决。在实践中,需要根据具体情况选择合适的方法来降低过拟合的风险。

欠拟合: 

欠拟合是指模型无法充分拟合训练数据的情况,导致模型在训练集上的性能不佳,也无法在测试集或新的样本上良好地泛化。

以下是一些深度学习中常见的欠拟合原因和解决方法:

1. 模型复杂度不足:深度学习模型可能过于简单,无法捕捉到数据中的复杂关系。解决方法包括增加模型的层数,增加模型的宽度(增加隐藏层的神经元数量),或者使用更复杂的模型架构(如使用更多的卷积核、更深的网络结构)。

2. 数据量不足:如果训练数据太少,模型可能无法学习到充分的特征表示。解决方法包括增加数据集的规模,或者使用数据增强技术来生成更多的数据样本。

3. 特征选择不当:如果选择的特征不足以表示数据的复杂性,模型无法充分学习数据的特征。解决方法包括增加更多的特征,或者使用更好的特征工程技术(如使用更高级的特征提取方法、使用领域专业知识进行特征选择等)。

4. 学习率过高或过低:学习率是指模型在每次更新参数时的步长,过高或过低的学习率都会导致模型无法达到良好的拟合效果。解决方法包括适当调整学习率,可以通过网格搜索或使用自适应学习率算法(如Adam等)来寻找最佳的学习率。

5. 过拟合的解决方法:很多过拟合解决方法也可用于欠拟合问题,如增加数据量、降低模型复杂度、添加正则化等。

综上所述,欠拟合问题可以通过增加模型复杂度、增加数据量、优化特征选择、调整学习率等方法来解决。在实践中,需要根据具体情况选择合适的方法来提高模型的性能。

代码:

#@tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
import numpy as np
import math#@tab all 生成数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])#真实权重features = np.random.normal(size=(n_train + n_test, 1))#随机特征
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [d2l.tensor(x, dtype=d2l.float32) for x in [true_w, features, poly_features, labels]]#实现一个函数来评估模型在给定数据集上的损失
def evaluate_loss(net, data_iter, loss):  """评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:out = net(X)y = d2l.reshape(y, out.shape)l = loss(out, y)metric.add(d2l.reduce_sum(l), d2l.size(l))return metric[0] / metric[1]#训练模型
def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss()#损失函数input_shape = train_features.shape[-1]# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))#单层线性回归batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())

 正常模型:

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

欠拟合模型:

# 欠拟合,欠拟合是指模型无法继续减少训练误差
# 从多项式特征中选择前2个维度,即1和x,实际上有四个特征
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

过拟合模型:

#@tab all,过拟合是指训练误差远小于验证误差
# 从多项式特征中选取所有维度(20个),实际只有四个
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:],num_epochs=1250)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/766314.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3怎么使用reactive赋值

使用ref赋值: const list ref([]) const getList async () > {const res await axios.get(/list)list.value res.data } 如何使用reactive来替换呢? //const list ref([]) const list reactive([]) const getList async () > {const res…

NLP 笔记:LDA(训练篇)

1 前言:吉布斯采样 吉布斯采样的基本思想是,通过迭代的方式,逐个维度地更新所有变量的状态 1.1 举例 收拾东西 假设我们现在有一个很乱的屋子,我们不知道东西应该放在哪里(绝对位置),但知道哪…

iOS模拟器 Unable to boot the Simulator —— Ficow笔记

本文首发于 Ficow Shen’s Blog,原文地址: iOS模拟器 Unable to boot the Simulator —— Ficow笔记。 内容概览 前言终结模拟器进程命令行改权限清除模拟器缓存总结 前言 iOS模拟器和Xcode一样不靠谱,问题也不少。😂 那就有病治…

鸿蒙Harmony应用开发—ArkTS-ForEach:循环渲染

ForEach基于数组类型数据执行循环渲染。 说明: 从API version 9开始,该接口支持在ArkTS卡片中使用。 接口描述 ForEach(arr: Array,itemGenerator: (item: Array, index?: number) > void,keyGenerator?: (item: Array, index?: number): string …

【wails】(10):研究go-llama.cpp项目,但是发现不支持最新的qwen大模型,可以运行llama-2-7b-chat

1,视频演示地址 2,项目地址go-llama.cpp 下载并进行编译: git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp cd go-llama.cpp make libbinding.a项目中还打了个补丁: 给 编译成功,虽然有…

深度学习 线性神经网络(线性回归 从零开始实现)

介绍: 在线性神经网络中,线性回归是一种常见的任务,用于预测一个连续的数值输出。其目标是根据输入特征来拟合一个线性函数,使得预测值与真实值之间的误差最小化。 线性回归的数学表达式为: y w1x1 w2x2 ... wnxn …

【隐私计算实训营——004上手隐语SecretFlow和SecretNote安装部署】

1. SecretFlow安装 1.1 环境要求 Python>3.8操作系统 Ubuntu18 资源:>8核16GB安装包 secretflow-lite 安装方式 docker(推荐) 2. SecretFlow部署模式 SecretFlow使用Ray作为分布式计算调度框架。 Ray集群由一个主节点和零或若干个…

Fabric Measurement

Fabric Measurement 布料测量

分布式组件 Nacos

1.在之前的文章写过的就不用重复写。 写一些没有写过的新东西 2.细节 2.1命名空间 : 配置隔离 默认: public (默认命名空间):默认新增所有的配置都在public空间下 2.1.1 开发 、测试 、生产:有不同的配置文件 比如…

docker 数据卷 (二)

1,为什么使用数据卷 卷是在一个或多个容器内被选定的目录,为docker提供持久化数据或共享数据,是docker存储容器生成和使用的数据的首选机制。对卷的修改会直接生效,当提交或创建镜像时,卷不被包括在镜像中。 总结为两…

Orbit 使用指南 10|在机器人上安装传感器 | Isaac Sim | Omniverse

如是我闻: 资产类(asset classes)允许我们创建和模拟机器人,而传感器 (sensors) 则帮助我们获取关于环境的信息,获取不同的本体感知和外界感知信息。例如,摄像头传感器可用于获取环境的视觉信息&#xff0c…

ADB环境配置和基础使用

目录 一、ADB简介工作原理 二、安装ADB驱动程序配置环境变量验证ADB安装 三、启用USB调试模式四、连接设备到计算机五、使用ADB命令安装/卸载包Android 设备与电脑传输文件exit 退出目录日志操作指令系统操作指令adb ps命令 一、ADB简介 ADB全称是Android Debug Bridge&#x…

CentOS系统部署YesPlayMusic播放器并实现公网访问本地音乐资源

文章目录 1. 安装Docker2. 本地安装部署YesPlayMusic3. 安装cpolar内网穿透4. 固定YesPlayMusic公网地址 本篇文章讲解如何使用Docker搭建YesPlayMusic网易云音乐播放器,并且结合cpolar内网穿透实现公网访问音乐播放器。 YesPlayMusic是一款优秀的个人音乐播放器&am…

校园大数据平台的顶层设计与微观应用PDF下载

校园大数据平台的顶层设计与微观应用文档,是一份全面深入的解决方案,旨在构建一个集数据收集、存储、处理、分析及可视化于一体的综合平台。该设计以提升教育教学质量、优化资源配置、增强学生服务体验和提高管理效率为核心目标,通过大数据分…

c++的学习之路:3、入门(2)

一、引用 1、引用的概念 引用不是新定义一个变量,而是给已存在变量取了一个别名,编译器不会为引用变量开辟内存空 间,它和它引用的变量共用同一块内存空间。 怎么说呢,简单点理解就是你的小名,家里人叫你小名&#…

基于springboot和vue的旅游资源网站的设计与实现

环境以及简介 基于vue, springboot旅游资源网站的设计与实现,Java项目,SpringBoot项目,含开发文档,源码,数据库以及ppt 环境配置: 框架:springboot JDK版本:JDK1.8 服务器&#xf…

谷歌seo营销服务有哪些服务?

以我们举例,如果你在做B2B外贸建站,这里有全套保姆式托管服务,让你既省心又省力,七天就能搞定网站建设,快速上线,再来就是谷歌白帽SEO,我们这边强调的是纯白帽操作,专注于高质量的原…

今天聊聊新零售

一、什么是新零售? 2016年,在杭州举行的“云栖大会”上,马云发表了讲话,首次提出了“新零售”这一概念。 1.1 新零售概念 新零售,英文是New Retailing,新零售是对人货场的重构。人是消费者、销售人员、…

CISP 4.2备考之《物理与网络通信安全》知识点总结

文章目录 第 1 节 物理与环境安全第 2 节 网络安全基础第 3 节 网络安全技术与设备第 1 部分 防火墙第 2 部分 入侵检测系统第 3 部分 其他安全产品 第 4 节 网络安全设计规划 第 1 节 物理与环境安全 1.场地选择 1.1 场地选择:自然条件、社会条件、其他条件。1.2 抗震和承重&…

【操作系统】进程基础知识

目录 1、进程的介绍 2、进程的五个基本特性 3、进程的组成 4、进程的并行和并发执行 5、进程的状态 6、进程的通信 7、线程 1、进程的介绍 进程(Process)是程序在某个数据集合上的一次运行活动,也是操作系统进行资源分配和保护的基本单…