【理解机器学习中的过拟合与欠拟合】

在机器学习中,模型的表现很大程度上取决于我们如何平衡“过拟合”和“欠拟合”。本文通过理论介绍和代码演示,详细解析过拟合与欠拟合现象,并提出应对策略。主要内容如下:

什么是过拟合和欠拟合?
如何防止过拟合和欠拟合?
出现过拟合或欠拟合时怎么办?
使用代码和图像辅助理解。


一、什么是过拟合和欠拟合?

1.1过拟合(Overfitting)

定义:过拟合就是模型“学得太多了”,它不仅学会了数据中的规律,还把噪声和细节当成规律记住了。这就好比一个学生在考试前死记硬背了答案,但稍微换一道题就不会了。

过拟合的表现:

训练集表现非常好:训练数据上的准确率高,误差低。
测试集表现很差:新数据上的准确率低,误差大。
模型太复杂:比如使用了不必要的高阶多项式或过深的神经网络。

1.2 欠拟合(Underfitting)

欠拟合是什么?

欠拟合就是模型“学得太少了”。它只掌握了最基本的规律,无法捕获数据中的复杂模式。这就像一个学生只学到了皮毛,考试的时候连最简单的题都答不对。

欠拟合的表现:

训练集和测试集表现都很差:无论新数据还是老数据,模型都表现不好。
模型太简单:比如使用了线性模型拟合非线性数据,或者训练时间不足。

二、如何防止过拟合和欠拟合?

2.1 防止过拟合的方法

  1. 获取更多数据

更多的数据可以帮助模型更好地学习数据的真实分布,减少对训练数据细节的依赖。

  1. 正则化

正则化通过惩罚模型的复杂度,让模型不容易“过拟合”。

from sklearn.linear_model import Ridge  # L2正则化
model = Ridge(alpha=0.1)  # alpha控制正则化强度
  1. 降低模型复杂度

简化模型,比如减少神经网络层数或多项式的阶数。

  1. 早停法(Early Stopping)

在模型训练时,监控验证集的误差,如果误差开始上升,提前停止训练。

from keras.callbacks import EarlyStopping
early_stopping = EarlyStopping(monitor='val_loss', patience=5)
  1. 数据增强(Data Augmentation)

在图像分类任务中,通过旋转、裁剪、翻转等方法增加数据的多样性,提升模型的泛化能力。

2.2 防止欠拟合的方法

  1. 增加模型复杂度

增加模型的参数,比如更多的神经元或更深的网络层。

  1. 延长训练时间

欠拟合可能是因为训练时间不够长,模型没有学到足够的规律。

3。 优化特征工程

如果模型无法拟合数据,可能是因为输入的特征不够好。尝试创建更多、更有意义的特征。

  1. 降低正则化强度

正则化强度过大可能限制了模型的学习能力,适当减小正则化系数。

三、过拟合与欠拟合时怎么办?
当你发现模型出现问题时,可以通过以下策略调整:

现象解决方法
过拟合- 获取更多数据
- 使用正则化
- 降低模型复杂度
- 使用早停法
欠拟合- 增加模型复杂度
- 延长训练时间
- 改善特征质量
- 减小正则化强度

四、代码与图像演示:多项式拟合的例子

下面通过一个简单的例子,用多项式拟合来直观感受过拟合与欠拟合。

4.1 数据生成
我们生成一个非线性数据集,并可视化:

import numpy as np
import matplotlib.pyplot as plt
import matplotlibmatplotlib.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei,显示中文
matplotlib.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 生成非线性数据
np.random.seed(42)  # 设置随机种子,保证结果可复现
X = np.random.rand(100, 1) * 6 - 3  # X范围[-3, 3]
y = 0.5 * X**3 - X**2 + 2 + np.random.randn(100, 1) * 2  # 非线性关系并添加噪声# 可视化数据
plt.scatter(X, y, color='blue', alpha=0.7, label='数据')  # 绘制散点图
plt.xlabel("X")  # 设置X轴标签
plt.ylabel("y")  # 设置Y轴标签
plt.title("生成的非线性数据")  # 设置图表标题
plt.legend()  # 显示图例
plt.show()  # 显示图表

在这里插入图片描述
结果图:
生成的数据呈现一个明显的非线性分布。

4.2 模型训练与可视化

我们训练三种模型:
线性回归(1阶):欠拟合。
4阶多项式回归:最佳拟合。
10阶多项式回归:过拟合。

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 多项式拟合
degrees = [1, 4, 20]
for degree in degrees:poly_features = PolynomialFeatures(degree=degree)  # 生成多项式特征X_poly_train = poly_features.fit_transform(X_train)X_poly_test = poly_features.transform(X_test)# 训练模型model = LinearRegression()model.fit(X_poly_train, y_train)# 预测y_train_pred = model.predict(X_poly_train)y_test_pred = model.predict(X_poly_test)# 计算误差train_error = mean_squared_error(y_train, y_train_pred)test_error = mean_squared_error(y_test, y_test_pred)# 绘制拟合曲线X_plot = np.linspace(-3, 3, 100).reshape(100, 1)X_poly_plot = poly_features.transform(X_plot)y_plot = model.predict(X_poly_plot)plt.scatter(X, y, color='blue', alpha=0.7, label='Data')plt.plot(X_plot, y_plot, color='red', label=f'Degree {degree}')plt.xlabel("X")plt.ylabel("y")plt.title(f"Degree {degree}\nTrain Error: {train_error:.2f} | Test Error: {test_error:.2f}")plt.legend()plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

结果图:

Degree 1(欠拟合):模型太简单,无法捕获数据的非线性规律。
Degree 4(最佳拟合):模型复杂度适中,能很好地拟合数据。
Degree 20(过拟合):模型过于复杂,训练误差低,但测试误差大。

4.3 误差趋势分析

绘制训练误差和测试误差随模型复杂度变化的曲线:

train_errors = []
test_errors = []for degree in degrees:poly_features = PolynomialFeatures(degree=degree)X_poly_train = poly_features.fit_transform(X_train)X_poly_test = poly_features.transform(X_test)model = LinearRegression()model.fit(X_poly_train, y_train)y_train_pred = model.predict(X_poly_train)y_test_pred = model.predict(X_poly_test)train_errors.append(mean_squared_error(y_train, y_train_pred))test_errors.append(mean_squared_error(y_test, y_test_pred))# 绘制误差曲线
plt.plot(degrees, train_errors, marker='o', label='Train Error')
plt.plot(degrees, test_errors, marker='o', label='Test Error')
plt.xlabel("Polynomial Degree")
plt.ylabel("Mean Squared Error")
plt.title("训练误差和测试误差随多项式阶数变化")
plt.legend()
plt.show()

在这里插入图片描述

结果分析:

训练误差随着复杂度增加而降低。
测试误差先下降后上升,呈现“U型趋势”。

五、总结

5.1 过拟合与欠拟合的核心区别

过拟合:模型对训练数据“学得太死”,测试数据表现很差。
欠拟合:模型对数据“学得太少”,训练和测试表现都不好。

5.2 防止方法

防止过拟合:使用正则化、数据增强、早停等方法。
防止欠拟合:增加模型复杂度、延长训练时间、优化特征。


希望这篇文章让你对过拟合与欠拟合有了更深入的理解!如果还有疑问,欢迎交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890765.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学籍管理系统:实现教育管理现代化

2.1 Tomcat 简介 只要学习Java Web项目就不得不学习Tomcat。Tomcat是一种免费的开源的一种Java Web项目的容器,完美继承了 Apache服务器的特性,并且里面添加可以自动化运行的Java Web组件,让Java Web项目可以完全的运行到Tomcat里面。对于特大…

【婚庆摄影小程序设计与实现】

摘 要 社会发展日新月异,用计算机应用实现数据管理功能已经算是很完善的了,但是随着移动互联网的到来,处理信息不再受制于地理位置的限制,处理信息及时高效,备受人们的喜爱。所以各大互联网厂商都瞄准移动互联网这个潮…

服务器如何划分空间?

服务器如何划分空间?服务器是存储和处理数据的核心,如何有效地划分服务器空间则直接关系到资源的利用效率和系统的性能。无论是大型企业的数据中心,还是小型网站的共享主机,合理的空间划分都至关重要。下面是聚名网关于服务器如何…

12.26 学习卷积神经网路(CNN)

完全是基于下面这个博客来进行学习的,感谢! ​​【深度学习基础】详解Pytorch搭建CNN卷积神经网络LeNet-5实现手写数字识别_pytorch cnn-CSDN博客 基于深度神经网络DNN实现的手写数字识别,将灰度图像转换后的二维数组展平到一维,…

Unity URP多光源支持,多光源阴影投射,多光源阴影接收(优化版)

目录 前言: 一、属性 二、SubShader 三、ForwardLitPass 定义Tags 声明变体 声明变量 定义结构体 顶点Shader 片元Shader 四、全代码 四、添加官方的LitShader代码 五、全代码 六、效果图 七、结语 前言: 哈喽啊,我又来啦。这…

如何使用React,透传各类组件能力/属性?

在23年的时候,我主要使用的框架还是Vue,当时写了一篇“如何二次封装一个Vue3组件库?”的文章,里面涉及了一些如何使用Vue透传组件能力的方法。在我24年接触React之后,我发现这种扩展组件能力的方式有一个专门的术语&am…

109.【C语言】数据结构之求二叉树的高度

目录 1.知识回顾:高度(也称深度) 2.分析 设计代码框架 返回左右子树高度较大的那个的写法一:if语句 返回左右子树高度较大的那个的写法二:三目操作符 3.代码 4.反思 问题 出问题的代码 改进后的代码 执行结果 1.知识回顾&#xf…

通过百度api处理交通数据

通过百度api处理交通数据 1、读取excel获取道路数据 //道路名称Data EqualsAndHashCode public class RoadName {ExcelProperty("Name")private String name; }/*** 获取excel中的道路名称*/private static List<String> getRoadName() {// 定义文件路径&…

分析排名靠前的一些自媒体平台,如何运用这些平台?

众所周知&#xff0c;现在做网站越来越难了&#xff0c;主要的原因还是因为流量红利时代过去了。并且搜索引擎都在给自己的平台做闭环改造。搜索引擎的流量扶持太低了。如百度投资知乎&#xff0c;给知乎带来很多流量扶持&#xff0c;也为自身内容不足做一个填补。 而我们站长…

2024大模型在软件开发中的具体应用有哪些?(附实践资料合集)

大模型在软件开发中的具体应用非常广泛&#xff0c;以下是一些主要的应用领域&#xff1a; 自动化代码生成与智能编程助手&#xff1a; AI大模型能够根据开发者的自然语言描述自动生成代码&#xff0c;减少手动编写代码的工作量。例如&#xff0c;GitHub Copilot工具就是利用AI…

webpack的说明

介绍 因为不确定打出的前端包所访问的后端IP&#xff0c;需要对项目中IP配置文件单独拿出来&#xff0c;方便运维部署的时候对IP做修改。 因此&#xff0c;需要用webpack单独打包指定文件。 CommonsChunkPlugin module.exports {entry: {app: APP_FILE // 入口文件},outpu…

HTML 画布:创意与技术的融合

HTML 画布:创意与技术的融合 HTML 画布(<canvas>)元素是现代网页设计中的一个强大工具,它为开发者提供了一个空白画布,可以在上面通过JavaScript绘制图形、图像和动画。这种技术不仅为网页增添了视觉吸引力,还极大地丰富了用户的交互体验。本文将深入探讨HTML画布…

Ubuntu网络配置(桥接模式, nat模式, host主机模式)

windows上安装了vmware虚拟机&#xff0c; vmware虚拟机上运行着ubuntu系统。windows与虚拟机可以通过三种方式进行通信。分别是桥接模式&#xff1b;nat模式&#xff1b;host模式 一、桥接模式 所谓桥接模式&#xff0c;也就是虚拟机与宿主机处于同一个网段&#xff0c; 宿主机…

【SQL】王二的100道SQL刷题进阶之路

持续更新&#xff0c;建议关注收藏&#xff01; SQL进阶看这一篇就够了&#xff01; 目录 1-datediff2-生成排序序号3-having注意4-procedure declare5-弯弯绕绕 1-datediff select id,datediff(end_date, start_date) as diff from Tasks order by diff desc limit 3;dated…

3.系统学习-熵与决策树

熵与决策树 前言1.从数学开始信息量(Information Content / Shannon information)信息熵(Information Entropy)条件熵信息增益 决策树认识2.基于信息增益的ID3决策树3.C4.5决策树算法C4.5决策树算法的介绍决策树C4.5算法的不足与思考 4. CART 树基尼指数&#xff08;基尼不纯度…

FLV视频封装格式详解

目录(?)[-] OverviewFile Structure The FLV headerThe FLV File BodyFLV Tag Definition FLVTAGAudio TagsVideo TagsSCRIPTDATA onMetaDatakeyframes Overview Flash Video(简称FLV),是一种流行的网络格式。目前国内外大部分视频分享网站都是采用的这种格式. File Structure…

Text2Reward学习笔记

1. 提示词 请问&#xff0c;“glew”是一个RL工程师常用的工具库吗&#xff1f;2. 环境配置 2.1 安装 PyTorch-1.13.1 pip install torch1.13.1cu116 torchvision0.14.1cu116 \ torchaudio0.13.1 --extra-index-url https://download.pytorch.org/whl/cu1161.2 安装工具库 …

SpringBoot + HttpSession 自定义生成sessionId

SpringBoot HttpSession 自定义生成sessionId 业务场景实现方案 业务场景 最近在做用户登录过程中&#xff0c;由于默认ID是通过UUID创建的&#xff0c;缺乏足够的安全性&#xff0c;决定要自定义生成 sessionId。 实现方案 正常的获取session方法如下&#xff1a; HttpSe…

破解海外业务困局:新加坡服务器托管与跨境组网策略

在当今全球化商业蓬勃发展的浪潮之下&#xff0c;众多企业将目光投向海外市场&#xff0c;力求拓展业务版图、抢占发展先机。而新加坡&#xff0c;凭借其卓越的地理位置、强劲的经济发展态势以及高度国际化的营商环境&#xff0c;已然成为企业海外布局的热门之选。此时&#xf…

CMS(Concurrent Mark Sweep)垃圾回收器的具体流程

引言 CMS&#xff08;Concurrent Mark Sweep&#xff09;收集器是Java虚拟机中的一款并发收集器&#xff0c;其设计目标是最小化停顿时间&#xff0c;非常适合于对响应时间敏感的应用。与传统的串行或并行收集器不同&#xff0c;CMS能够尽可能地让垃圾收集线程与用户线程同时运…