【理解机器学习中的过拟合与欠拟合】

在机器学习中,模型的表现很大程度上取决于我们如何平衡“过拟合”和“欠拟合”。本文通过理论介绍和代码演示,详细解析过拟合与欠拟合现象,并提出应对策略。主要内容如下:

什么是过拟合和欠拟合?
如何防止过拟合和欠拟合?
出现过拟合或欠拟合时怎么办?
使用代码和图像辅助理解。


一、什么是过拟合和欠拟合?

1.1过拟合(Overfitting)

定义:过拟合就是模型“学得太多了”,它不仅学会了数据中的规律,还把噪声和细节当成规律记住了。这就好比一个学生在考试前死记硬背了答案,但稍微换一道题就不会了。

过拟合的表现:

训练集表现非常好:训练数据上的准确率高,误差低。
测试集表现很差:新数据上的准确率低,误差大。
模型太复杂:比如使用了不必要的高阶多项式或过深的神经网络。

1.2 欠拟合(Underfitting)

欠拟合是什么?

欠拟合就是模型“学得太少了”。它只掌握了最基本的规律,无法捕获数据中的复杂模式。这就像一个学生只学到了皮毛,考试的时候连最简单的题都答不对。

欠拟合的表现:

训练集和测试集表现都很差:无论新数据还是老数据,模型都表现不好。
模型太简单:比如使用了线性模型拟合非线性数据,或者训练时间不足。

二、如何防止过拟合和欠拟合?

2.1 防止过拟合的方法

  1. 获取更多数据

更多的数据可以帮助模型更好地学习数据的真实分布,减少对训练数据细节的依赖。

  1. 正则化

正则化通过惩罚模型的复杂度,让模型不容易“过拟合”。

from sklearn.linear_model import Ridge  # L2正则化
model = Ridge(alpha=0.1)  # alpha控制正则化强度
  1. 降低模型复杂度

简化模型,比如减少神经网络层数或多项式的阶数。

  1. 早停法(Early Stopping)

在模型训练时,监控验证集的误差,如果误差开始上升,提前停止训练。

from keras.callbacks import EarlyStopping
early_stopping = EarlyStopping(monitor='val_loss', patience=5)
  1. 数据增强(Data Augmentation)

在图像分类任务中,通过旋转、裁剪、翻转等方法增加数据的多样性,提升模型的泛化能力。

2.2 防止欠拟合的方法

  1. 增加模型复杂度

增加模型的参数,比如更多的神经元或更深的网络层。

  1. 延长训练时间

欠拟合可能是因为训练时间不够长,模型没有学到足够的规律。

3。 优化特征工程

如果模型无法拟合数据,可能是因为输入的特征不够好。尝试创建更多、更有意义的特征。

  1. 降低正则化强度

正则化强度过大可能限制了模型的学习能力,适当减小正则化系数。

三、过拟合与欠拟合时怎么办?
当你发现模型出现问题时,可以通过以下策略调整:

现象解决方法
过拟合- 获取更多数据
- 使用正则化
- 降低模型复杂度
- 使用早停法
欠拟合- 增加模型复杂度
- 延长训练时间
- 改善特征质量
- 减小正则化强度

四、代码与图像演示:多项式拟合的例子

下面通过一个简单的例子,用多项式拟合来直观感受过拟合与欠拟合。

4.1 数据生成
我们生成一个非线性数据集,并可视化:

import numpy as np
import matplotlib.pyplot as plt
import matplotlibmatplotlib.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei,显示中文
matplotlib.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 生成非线性数据
np.random.seed(42)  # 设置随机种子,保证结果可复现
X = np.random.rand(100, 1) * 6 - 3  # X范围[-3, 3]
y = 0.5 * X**3 - X**2 + 2 + np.random.randn(100, 1) * 2  # 非线性关系并添加噪声# 可视化数据
plt.scatter(X, y, color='blue', alpha=0.7, label='数据')  # 绘制散点图
plt.xlabel("X")  # 设置X轴标签
plt.ylabel("y")  # 设置Y轴标签
plt.title("生成的非线性数据")  # 设置图表标题
plt.legend()  # 显示图例
plt.show()  # 显示图表

在这里插入图片描述
结果图:
生成的数据呈现一个明显的非线性分布。

4.2 模型训练与可视化

我们训练三种模型:
线性回归(1阶):欠拟合。
4阶多项式回归:最佳拟合。
10阶多项式回归:过拟合。

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 多项式拟合
degrees = [1, 4, 20]
for degree in degrees:poly_features = PolynomialFeatures(degree=degree)  # 生成多项式特征X_poly_train = poly_features.fit_transform(X_train)X_poly_test = poly_features.transform(X_test)# 训练模型model = LinearRegression()model.fit(X_poly_train, y_train)# 预测y_train_pred = model.predict(X_poly_train)y_test_pred = model.predict(X_poly_test)# 计算误差train_error = mean_squared_error(y_train, y_train_pred)test_error = mean_squared_error(y_test, y_test_pred)# 绘制拟合曲线X_plot = np.linspace(-3, 3, 100).reshape(100, 1)X_poly_plot = poly_features.transform(X_plot)y_plot = model.predict(X_poly_plot)plt.scatter(X, y, color='blue', alpha=0.7, label='Data')plt.plot(X_plot, y_plot, color='red', label=f'Degree {degree}')plt.xlabel("X")plt.ylabel("y")plt.title(f"Degree {degree}\nTrain Error: {train_error:.2f} | Test Error: {test_error:.2f}")plt.legend()plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

结果图:

Degree 1(欠拟合):模型太简单,无法捕获数据的非线性规律。
Degree 4(最佳拟合):模型复杂度适中,能很好地拟合数据。
Degree 20(过拟合):模型过于复杂,训练误差低,但测试误差大。

4.3 误差趋势分析

绘制训练误差和测试误差随模型复杂度变化的曲线:

train_errors = []
test_errors = []for degree in degrees:poly_features = PolynomialFeatures(degree=degree)X_poly_train = poly_features.fit_transform(X_train)X_poly_test = poly_features.transform(X_test)model = LinearRegression()model.fit(X_poly_train, y_train)y_train_pred = model.predict(X_poly_train)y_test_pred = model.predict(X_poly_test)train_errors.append(mean_squared_error(y_train, y_train_pred))test_errors.append(mean_squared_error(y_test, y_test_pred))# 绘制误差曲线
plt.plot(degrees, train_errors, marker='o', label='Train Error')
plt.plot(degrees, test_errors, marker='o', label='Test Error')
plt.xlabel("Polynomial Degree")
plt.ylabel("Mean Squared Error")
plt.title("训练误差和测试误差随多项式阶数变化")
plt.legend()
plt.show()

在这里插入图片描述

结果分析:

训练误差随着复杂度增加而降低。
测试误差先下降后上升,呈现“U型趋势”。

五、总结

5.1 过拟合与欠拟合的核心区别

过拟合:模型对训练数据“学得太死”,测试数据表现很差。
欠拟合:模型对数据“学得太少”,训练和测试表现都不好。

5.2 防止方法

防止过拟合:使用正则化、数据增强、早停等方法。
防止欠拟合:增加模型复杂度、延长训练时间、优化特征。


希望这篇文章让你对过拟合与欠拟合有了更深入的理解!如果还有疑问,欢迎交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890765.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【婚庆摄影小程序设计与实现】

摘 要 社会发展日新月异,用计算机应用实现数据管理功能已经算是很完善的了,但是随着移动互联网的到来,处理信息不再受制于地理位置的限制,处理信息及时高效,备受人们的喜爱。所以各大互联网厂商都瞄准移动互联网这个潮…

12.26 学习卷积神经网路(CNN)

完全是基于下面这个博客来进行学习的,感谢! ​​【深度学习基础】详解Pytorch搭建CNN卷积神经网络LeNet-5实现手写数字识别_pytorch cnn-CSDN博客 基于深度神经网络DNN实现的手写数字识别,将灰度图像转换后的二维数组展平到一维,…

Unity URP多光源支持,多光源阴影投射,多光源阴影接收(优化版)

目录 前言: 一、属性 二、SubShader 三、ForwardLitPass 定义Tags 声明变体 声明变量 定义结构体 顶点Shader 片元Shader 四、全代码 四、添加官方的LitShader代码 五、全代码 六、效果图 七、结语 前言: 哈喽啊,我又来啦。这…

如何使用React,透传各类组件能力/属性?

在23年的时候,我主要使用的框架还是Vue,当时写了一篇“如何二次封装一个Vue3组件库?”的文章,里面涉及了一些如何使用Vue透传组件能力的方法。在我24年接触React之后,我发现这种扩展组件能力的方式有一个专门的术语&am…

109.【C语言】数据结构之求二叉树的高度

目录 1.知识回顾:高度(也称深度) 2.分析 设计代码框架 返回左右子树高度较大的那个的写法一:if语句 返回左右子树高度较大的那个的写法二:三目操作符 3.代码 4.反思 问题 出问题的代码 改进后的代码 执行结果 1.知识回顾&#xf…

分析排名靠前的一些自媒体平台,如何运用这些平台?

众所周知,现在做网站越来越难了,主要的原因还是因为流量红利时代过去了。并且搜索引擎都在给自己的平台做闭环改造。搜索引擎的流量扶持太低了。如百度投资知乎,给知乎带来很多流量扶持,也为自身内容不足做一个填补。 而我们站长…

2024大模型在软件开发中的具体应用有哪些?(附实践资料合集)

大模型在软件开发中的具体应用非常广泛,以下是一些主要的应用领域: 自动化代码生成与智能编程助手: AI大模型能够根据开发者的自然语言描述自动生成代码,减少手动编写代码的工作量。例如,GitHub Copilot工具就是利用AI…

Ubuntu网络配置(桥接模式, nat模式, host主机模式)

windows上安装了vmware虚拟机, vmware虚拟机上运行着ubuntu系统。windows与虚拟机可以通过三种方式进行通信。分别是桥接模式;nat模式;host模式 一、桥接模式 所谓桥接模式,也就是虚拟机与宿主机处于同一个网段, 宿主机…

3.系统学习-熵与决策树

熵与决策树 前言1.从数学开始信息量(Information Content / Shannon information)信息熵(Information Entropy)条件熵信息增益 决策树认识2.基于信息增益的ID3决策树3.C4.5决策树算法C4.5决策树算法的介绍决策树C4.5算法的不足与思考 4. CART 树基尼指数(基尼不纯度…

SpringBoot + HttpSession 自定义生成sessionId

SpringBoot HttpSession 自定义生成sessionId 业务场景实现方案 业务场景 最近在做用户登录过程中,由于默认ID是通过UUID创建的,缺乏足够的安全性,决定要自定义生成 sessionId。 实现方案 正常的获取session方法如下: HttpSe…

破解海外业务困局:新加坡服务器托管与跨境组网策略

在当今全球化商业蓬勃发展的浪潮之下,众多企业将目光投向海外市场,力求拓展业务版图、抢占发展先机。而新加坡,凭借其卓越的地理位置、强劲的经济发展态势以及高度国际化的营商环境,已然成为企业海外布局的热门之选。此时&#xf…

数学课程评价系统:客户服务与教学支持

2.1 SSM框架介绍 本课题程序开发使用到的框架技术,英文名称缩写是SSM,在JavaWeb开发中使用的流行框架有SSH、SSM、SpringMVC等,作为一个课题程序采用SSH框架也可以,SSM框架也可以,SpringMVC也可以。SSH框架是属于重量级…

攻防世界web第三题file_include

<?php highlight_file(__FILE__);include("./check.php");if(isset($_GET[filename])){$filename $_GET[filename];include($filename);} ?>惯例&#xff1a; 代码审查&#xff1a; 1.可以看到include(“./check.php”);猜测是同级目录下有一个check.php文…

【深度学习环境】NVIDIA Driver、Cuda和Pytorch(centos9机器,要用到显示器)

文章目录 一 、Anaconda install二、 NIVIDIA driver install三、 Cuda install四、Pytorch install 一 、Anaconda install Step 1 Go to the official website: https://www.anaconda.com/download Input your email and submit. Step 2 Select your version, and click i…

寻找适合小户型的开源知识库open source knowledge base之路

寻找一个开源的知识库&#xff0c;为了把以前花很多时间收集的信息或是项目/课程资料放到一个容易归类和管理的私有自主系统中&#xff0c;以便更容易查阅&#xff0c;花更少时间收集、对比版本及分享等一系列管理工作&#xff0c;同时确保在需要时可以相对快速找到有用的资料&…

C语言勘破之路-最终篇 —— 预处理(上)

人无完人&#xff0c;持之以恒&#xff0c;方能见真我&#xff01;&#xff01;&#xff01; 共同进步&#xff01;&#xff01; 文章目录 一、预定义符号二、#define定义常量三.、#define定义宏四、带有副作用的宏参数五、宏替换的规则六、宏和函数的对比1.宏的优势2.函数的优…

学习threejs,THREE.RingGeometry 二维平面圆环几何体

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.RingGeometry 圆环几…

Win11系统下Oracle11g数据库下载与安装使用教程

文章目录 一、Oracle下载与安装1.1 解压安装包1.2 开始安装Oracle11g1.2.1 用户 1.3 测试数据库是否配置成功1.4 了解一下 Oracle相关服务1.5 了解Oracle体系结构 二、使用工具连接数据库2.1 PL/ SQL 连接本地oracle 三、PL/ SQL远程访问数据库3.1 可能踩坑问题&#xff08;TNS…

数据结构(Java版)第六期:LinkedList与链表(一)

目录 一、链表 1.1. 链表的概念及结构 1.2. 链表的实现 专栏&#xff1a;数据结构(Java版) 个人主页&#xff1a;手握风云 一、链表 1.1. 链表的概念及结构 链表是⼀种物理存储结构上⾮连续存储结构&#xff0c;数据元素的逻辑顺序是通过链表中的引⽤链接次序实现的。与火车…

从零开始C++棋牌游戏开发之第三篇:游戏的界面布局设计

在游戏开发的旅途中&#xff0c;界面布局设计是一个充满创意和挑战的环节。对于棋牌类游戏而言&#xff0c;界面不仅仅是功能的载体&#xff0c;更是玩家与游戏互动的桥梁。一个清晰、直观且美观的界面可以显著提升游戏的用户体验。 在这篇文章中&#xff0c;我们将从功能需求…