人工智能算法工程师(中级)课程14-神经网络的优化与设计之拟合问题及优化与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程14-神经网络的优化与设计之拟合问题及优化与代码详解。在机器学习和深度学习领域,模型的训练目标是找到一组参数,使得模型能够从训练数据中学习到有用的模式,并对未知数据做出准确预测。这一过程涉及到解决两种主要的拟合问题:欠拟合(Underfitting)和过拟合(Overfitting)。

文章目录

  • 一、拟合问题概述
    • 欠拟合现象
    • 过拟合现象
    • 解决策略
  • 二、正则化方法
    • 1. L1正则化
    • 2. L2正则化
  • 三、正则化参数的更新
  • 四、Dropout
  • 五、代码实现

一、拟合问题概述

在机器学习领域,拟合问题是指通过训练数据找到最佳模型参数,使得模型在未知数据上的表现尽可能好。拟合问题主要包括欠拟合和过拟合两种现象。

欠拟合现象

定义:欠拟合指的是机器学习模型在训练集上的表现不佳,无法充分学习到数据的内在规律,导致模型的预测能力低下。这就好比一个学生在考试中,由于知识掌握不牢固,对已知题目的解答都做不好,更不用说应对新题目了。
原因分析:
模型复杂度低:如果模型太简单,如用线性模型去拟合非线性的数据分布,那么模型就无法捕捉到数据中的复杂模式,就像用直尺去测量曲线长度一样,永远无法得到准确的结果。
训练数据不足:模型需要足够的数据来学习和概括数据的特性。如果数据量太少,模型可能没有机会接触到数据的全貌,就像从一本书中只读了几页就想理解整本书的内容一样困难。
特征选择不当:如果使用的特征与目标预测无关或相关性弱,模型就难以从中学习到有效的信息,相当于在解决问题时选择了错误的工具。

过拟合现象

定义:过拟合是指模型在训练数据上表现得过于出色,以至于对训练数据中的噪声或偶然性细节也进行了学习,这导致模型在面对未见过的数据时,泛化能力下降。这就像一个学生过分依赖于记忆特定的例题,而没有真正理解背后的原理,因此在遇到稍微变化的问题时就束手无策。
原因分析:
模型复杂度过高:如果模型过于复杂,如高阶多项式回归,它可能会过度适应训练数据中的每一个细节,包括噪声和异常值,而不是学习数据的普遍规律。
训练数据包含噪声:现实世界的数据往往带有噪声,如果模型试图学习这些噪声,就会导致过拟合。这类似于试图从嘈杂的环境中听清对话,噪声会干扰对真实信息的理解。
训练数据量不足:即使模型复杂度适中,但如果训练数据量不够,模型仍然可能过拟合。这是因为数据量不足时,模型可能会把偶然出现的模式误认为是普遍规律。

解决策略

增加模型复杂度:对于欠拟合,可以通过增加模型复杂度来提升模型的学习能力,如使用更高阶的多项式或更复杂的神经网络结构。
增加训练数据量:无论是欠拟合还是过拟合,增加训练数据量都能帮助模型更好地学习数据的分布,提高泛化能力。
特征工程:优化特征选择,确保模型能够基于有意义的特征进行学习。
正则化:使用L1或L2正则化等技术来限制模型复杂度,防止过拟合。
交叉验证:通过交叉验证来评估模型的泛化能力,确保模型不仅在训练数据上表现好,也能在未见数据上给出准确预测。
早停法:在训练过程中监控验证集的性能,一旦发现验证集上的性能不再提升,就停止训练,避免过拟合。
在这里插入图片描述

二、正则化方法

为了解决过拟合问题,通常采用正则化方法对模型进行约束。常见的正则化方法有L1正则化和L2正则化。

1. L1正则化

L1正则化的目标函数为:
J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 + α ∑ j = 1 n ∣ θ j ∣ J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})^2 + \alpha\sum_{j=1}^{n}|\theta_j| J(θ)=2m1i=1m(hθ(x(i))y(i))2+αj=1nθj
其中,第一项为损失函数,第二项为L1正则化项, α \alpha α为惩罚系数, θ j \theta_j θj为模型参数。

2. L2正则化

L2正则化的目标函数为:
J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 + α 2 ∑ j = 1 n θ j 2 J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})^2 + \frac{\alpha}{2}\sum_{j=1}^{n}\theta_j^2 J(θ)=2m1i=1m(hθ(x(i))y(i))2+2αj=1nθj2
其中,第一项为损失函数,第二项为L2正则化项, α \alpha α为惩罚系数, θ j \theta_j θj为模型参数。

三、正则化参数的更新

在优化目标函数时,我们需要对正则化参数进行更新。以下为L2正则化的参数更新公式:
θ j : = θ j − α ( 1 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) x j ( i ) + λ θ j ) \theta_j := \theta_j - \alpha\left(\frac{1}{m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})x_j^{(i)} + \lambda\theta_j\right) θj:=θjα(m1i=1m(hθ(x(i))y(i))xj(i)+λθj)
其中, λ = α m \lambda = \frac{\alpha}{m} λ=mα为正则化参数。
在这里插入图片描述

四、Dropout

Dropout是一种有效的正则化方法,通过在训练过程中随机丢弃部分神经元,来减少模型对特定训练样本的依赖。以下是Dropout的实现步骤:
(1)在训练过程中,按照一定概率随机丢弃神经元;
(2)在测试过程中,将所有神经元的输出乘以概率因子。

五、代码实现

以下是基于PyTorch的拟合问题及优化代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class LinearRegression(nn.Module):def __init__(self, input_dim, output_dim):super(LinearRegression, self).__init__()self.linear = nn.Linear(input_dim, output_dim)def forward(self, x):return self.linear(x)
# 生成数据
x = torch.randn(100, 1)
y = 3 * x + 2 + torch.randn(100, 1)
# 实例化模型
model = LinearRegression(1, 1)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)  # L2正则化
# 训练模型
num_epochs = 100
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
# 测试模型
model.eval()
with torch.no_grad():predicted = model(x).detach().numpy()print(f'预测值:{predicted}')

通过本文的介绍,相信大家对拟合问题及优化方法有了更深入的了解。在实际应用中,可根据数据特点选择合适的正则化方法,以提高模型的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/872046.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023年高教杯数学建模2023B题解析(仅从代码角度出发)

前言 最近博主正在和队友准备九月的数学建模,在做往年的题目,博主主要是负责数据处理,运算以及可视化,这里分享一下自己部分的工作,相关题目以及下面所涉及的代码后续我会作为资源上传 问题求解 第一题 第一题的思路主要如下:…

【SpringBoot】SpringCache轻松启用Redis缓存

目录: 1.前言 2.常用注解 3.启用缓存 1.前言 Spring Cache是Spring提供的一种缓存抽象机制,旨在通过简化缓存操作来提高系统性能和响应速度。Spring Cache可以将方法的返回值缓存起来,当下次调用方法时如果从缓存中查询到了数据&#xf…

基于 jenkins 部署接口自动化测试项目!

引言 在现代软件开发过程中,自动化测试是保证代码质量的关键环节。通过自动化测试,可以快速发现和修复代码中的问题,从而提高开发效率和产品质量。而 Jenkins 作为一款开源的持续集成工具,可以帮助我们实现自动化测试的自动化部署…

Mysql:解决CPU飙升至100%问题的系统诊断与优化策略

在服务器运维过程中,CPU使用率飙升到100%是一个常见且棘手的问题。这不仅会严重影响服务器的性能,还可能导致服务中断。当遇到这类情况时,首要任务是快速定位问题源头并采取相应措施。以下是一个基于操作系统命令和MySQL数据库优化的详细解决…

快排的3种方式

//(前两种时间复杂度为o(n^2) , 最后一种为o(n*logn)public static void swap(int[] arr , int i , int j){arr[i] arr[i] ^arr[j];arr[j] arr[i] ^arr[j];arr[i] arr[i] ^arr[j]; } //使数组中以arr[R]划分,返回循环后arr[R]的所在地 public…

代码随想录算法训练营Day 62| 图论 part02 | 695. 岛屿的最大面积、1020.飞地的数量、130.被围绕的区域

代码随想录算法训练营Day 62| 图论 part02 | 695. 岛屿的最大面积、1020.飞地的数量、130.被围绕的区域 文章目录 代码随想录算法训练营Day 62| 图论 part02 | 695. 岛屿的最大面积、1020.飞地的数量、130.被围绕的区域65.岛屿的最大面积一、BFS二、DFS 1020.飞地的数量一、DFS…

自动化(二正)

Java接口自动化用到的技术栈 技术栈汇总: ①Java基础(封装、反射、泛型、jdbc) ②配置文件解析(properties) ③httpclient(发送http请求) ④fastjson、jsonpath处理数据的 ⑤testng自动化测试框架重点 ⑥allure测试报…

JMeter CSV 参数文件的使用教程

在 JMeter 测试过程中,合理地使用参数化技术是提高测试逼真度的关键步骤。本文将介绍如何通过 CSV 文件实现 JMeter 中的参数化。 设定 CSV 文件 首先,构建一个包含需要参数化数据的 CSV 文件。打开任何文本编辑器,输入希望模拟的用户数据&…

Scrapy 核心组件之Spiders组件的使用

Spiders 组件是 Scrapy 框架的核心组件,它定义了网络爬虫抓取网站数据的方式,其中包 括抓取的动作,如是否跟进链接,以及如何从网页内容中提取结构化数据。换言之,Spiders 组件用于定义抓取网页数据的动作及解析网页数据…

IGBT参数学习

IGBT(绝缘栅双极晶体管(Insulated Gate Bipolar Transistor))的内部架构如下所示: IGBT是个单向的器件,电流只能朝一个方向流动,通常IGBT会并联一个续流二极管 IGBT型号:IKW40N120T2 IKW40N120T2 电路符号…

ICPC铜牌算法

铜牌算法 2021ICPC上海站 铜牌开题: D:数学思维构造 E:贪心思维 G:树形dp H:图论克鲁斯卡尔重构树 I:背包dp K:思维构造2021ICPC沈阳站 铜牌开题: B:并查集 E:字符串简单查找 F:字符串简单构造模拟 J:BFS预处理2021ICPC南京站 铜牌开题: A:思维 C:暴力均摊stl D:贪心暴力…

【代码规范】.train(False)和.eval()的相似性和区别

【代码规范】.train(False)和.eval()的相似性和区别 文章目录 一、.train(False) 和 .eval() 的功能二、.train(False) 和 .eval() 的区别2.1 .eval()2.2 .train(False)2.3 总结 三、.eval()更加规范 一、.train(False) 和 .eval() 的功能 .train(False) 和 .eval() 在功能上非…

Centos7 安装Redis6.2.6 gcc报错问题解决

Redis 报错信息 make: *** [all] 错误 2 安装gcc 修改yum源,在安装更新rpm包时获得比较理想的速度,走阿里云镜像通道 发现报错信息如下: 正在解析主机 mirrors.aliyun.com (mirrors.aliyun.com)… 失败:未知的名称或服务。 wget: 无法解析主机地址 “mi…

LLM:学习清单 ing

根据模型的数据流程方向和自己的经验列出: 一、模型输入 分词器:BPE,BBPE 位置编码:绝对位置编码,三角函数编码,ROPE 词向量模型:词袋,监督学习模型;BGE,BC…

数据中心内存RAS技术发展背景

随着数据量的爆炸性增长和云计算的普及,数据中心内存的多比特错误及由无法纠正错误(UE)导致的停机问题日益凸显,这些故障不仅影响服务质量,还会带来高昂的修复或更换成本。随着工作负载、硬件密度以及对高性能要求的增加,数据中心…

01--IptablesFirewalld详解

前言:这里写一下,前面文章里都是直接关闭然后实验,感觉这样有点草率,这里写一下大概的概念和用法,作为知识的补充,这章写轻松点,毕竟是网安毕业的,算是给自己放松一下吧。 1、iptabl…

RK3568笔记三十八:DS18B20驱动开发测试

若该文为原创文章,转载请注明原文出处。 DS18B20驱动参考的是讯为电子的单总线驱动第十四期 | 单总线_北京迅为的博客-CSDN博客 博客很详细,具体不描述。 只是记录测试下DS18B20读取温度。 一、介绍 流程基本和按键驱动差不多,主要功能是…

asio之fd_set_adapter

简介 fd_set_adapter是对fd_set的封装 fd_set_adapter 是不同平台fd_set的别名 #if defined(BOOST_ASIO_WINDOWS) || defined(__CYGWIN__) typedef win_fd_set_adapter fd_set_adapter; #else typedef posix_fd_set_adapter fd_set_adapter; #endifposix_fd_set_adapter l…

为什么要做USB转多路UART项目 - 技术角度

前言 之前专门为USB转多路UART项目写了个序,提到了技术方案原因,这个文章打算展开讲一下。 一、工业物联网关 最初是因为有个工业物联网关的项目,需要出多路RS485接口,每路外接几十个三相电表PLC之类的电力电子设备。其中一款需…

构建艺术:精通Gradle依赖替换的策略与实践

构建艺术:精通Gradle依赖替换的策略与实践 在软件开发的构建过程中,依赖管理是确保项目顺利进行的关键环节。Gradle,作为一款强大的构建工具,提供了灵活的依赖管理功能,包括依赖替换,这使得开发者能够精细…