21.过拟合和欠拟合示例

1. 背景介绍

在机器学习和深度学习中,过拟合和欠拟合是两个非常重要的概念。过拟合指的是模型在训练数据上表现很好,但在新的测试数据上效果变差的情况。欠拟合则是指模型无法很好地拟合训练数据的情况。这两种情况都会导致模型无法很好地泛化,影响最终的预测和应用效果。

为了帮助大家更好地理解过拟合和欠拟合的概念及其应对方法,我将通过一个基于PyTorch的代码示例来演示这两种情况的具体表现。我们将生成一个抛物线数据集,并定义三种不同复杂度的模型,分别对应欠拟合、正常拟合和过拟合的情况。通过可视化训练和测试误差的曲线图,以及预测结果的散点图,我们可以直观地观察到这三种情况下模型的拟合效果。

2. 核心概念与联系

过拟合和欠拟合是机器学习和深度学习中两个相互对应的概念:

1. 过拟合(Overfitting): 模型在训练数据上表现很好,但在新的测试数据上效果变差的情况。这通常是由于模型过于复杂,过度拟合了训练数据中的噪声和细节,导致无法很好地推广到未知数据。

2. 欠拟合(Underfitting): 模型无法很好地拟合训练数据的情况。这通常是由于模型过于简单,无法捕捉训练数据中的复杂模式和关系。

这两种情况都会导致模型在实际应用中无法很好地泛化,因此需要采取相应的措施来防止和缓解过拟合和欠拟合。常见的应对方法包括:

- 增加训练样本数量
- 减少模型复杂度(比如调整网络层数、神经元个数等)
- 使用正则化技术(如L1/L2正则化、Dropout等)
- 调整超参数(如学习率、批量大小等)
- 特征工程(如特征选择、降维等)

通过合理的模型设计和超参数调优,我们可以寻找到一个恰当的模型复杂度,使其既能很好地拟合训练数据,又能在新数据上保持良好的泛化性能。这就是机器学习中的**bias-variance tradeoff**,也是我们在实际应用中需要权衡的一个关键点。

 3. 核心算法原理和具体操作步骤

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 生成数据
np.random.seed(42)
X = np.random.uniform(-5, 5, 500)
y = X**2 + 1 + np.random.normal(0, 1, 500)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 定义三种不同复杂度的模型
class UnderFitModel(nn.Module):def __init__(self):super(UnderFitModel, self).__init__()self.fc = nn.Linear(1, 1)def forward(self, x):return self.fc(x)class NormalFitModel(nn.Module):def __init__(self):super(NormalFitModel, self).__init__()self.fc1 = nn.Linear(1, 8)self.fc2 = nn.Linear(8, 1)self.activation = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.activation(x)x = self.fc2(x)return xclass OverFitModel(nn.Module):def __init__(self):super(OverFitModel, self).__init__()self.fc1 = nn.Linear(1, 32)self.fc2 = nn.Linear(32, 32)self.fc3 = nn.Linear(32, 1)self.activation = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.activation(x)x = self.fc2(x)x = self.activation(x)x = self.fc3(x)return x# 训练模型并记录误差
def train_and_evaluate(model, train_loader, test_loader):optimizer = torch.optim.SGD(model.parameters(), lr=0.005)criterion = nn.MSELoss()train_losses = []test_losses = []for epoch in range(100):model.train()train_loss = 0.0for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader)train_losses.append(train_loss)model.eval()test_loss = 0.0with torch.no_grad():for inputs, targets in test_loader:outputs = model(inputs)loss = criterion(outputs, targets)test_loss += loss.item()test_loss /= len(test_loader)test_losses.append(test_loss)return train_losses, test_losses# 训练三种模型并可视化
under_fit_model = UnderFitModel()
normal_fit_model = NormalFitModel()
over_fit_model = OverFitModel()under_fit_train_losses, under_fit_test_losses = train_and_evaluate(under_fit_model, train_loader, test_loader)
normal_fit_train_losses, normal_fit_test_losses = train_and_evaluate(normal_fit_model, train_loader, test_loader)
over_fit_train_losses, over_fit_test_losses = train_and_evaluate(over_fit_model, train_loader, test_loader)plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(under_fit_train_losses, label='Under-fit Train Loss')
plt.plot(under_fit_test_losses, label='Under-fit Test Loss')
plt.plot(normal_fit_train_losses, label='Normal-fit Train Loss')
plt.plot(normal_fit_test_losses, label='Normal-fit Test Loss')
plt.plot(over_fit_train_losses, label='Over-fit Train Loss')
plt.plot(over_fit_test_losses, label='Over-fit Test Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training and Test Loss Curves')
plt.legend()plt.subplot(1, 2, 2)
plt.scatter(X_test, y_test, label='True')
plt.scatter(X_test, under_fit_model(X_test).detach().numpy(), label='Under-fit Prediction')
plt.scatter(X_test, normal_fit_model(X_test).detach().numpy(), label='Normal-fit Prediction')
plt.scatter(X_test, over_fit_model(X_test).detach().numpy(), label='Over-fit Prediction')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Test Set Predictions')
plt.legend()plt.show()

这个代码示例涵盖了我们之前讨论的各个步骤:

数据生成: 我们生成了一个抛物线形状的数据集,并使用train_test_split函数将其划分为训练集和测试集。
模型定义: 我们定义了三种不同复杂度的PyTorch模型,分别对应欠拟合、正常拟合和过拟合的情况。
训练与评估: 我们实现了一个train_and_evaluate函数,该函数负责训练模型并记录训练集和测试集上的损失。
可视化: 最后,我们使用matplotlib绘制了训练损失和测试损失的曲线图,以及在测试集上的预测结果。

欠拟合模型:训练误差和测试误差都较大,说明模型无法很好地拟合数据。在测试集上的预测结果也存在较大偏差。
正常拟合模型:训练误差和测试误差较为接近,说明模型的拟合效果较好。在测试集上的预测也比较准确。
过拟合模型:训练误差很小,但测试误差较大,说明模型在训练集上表现很好,但在新数据上泛化能力较差。在测试集上的预测结果存在一定偏差。
通过这个实例,我们可以直观地观察到不同复杂度模型在训练和泛化性能上的差异。欠拟合模型在训练集和测试集上的损失都较大,说明模型无法很好地拟合数据。正常拟合模型在训练集和测试集上的损失较为接近,说明模型具有较好的泛化能力。而过拟合模型在训练集上的损失很小,但在测试集上的损失较大,说明模型过于复杂,在新数据上泛化性能较差。

通过这种观察训练误差和测试误差的方法,我们可以及时发现模型存在的问题,并针对性地调整模型结构、添加正则化等手段来优化模型性能。这是机器学习和深度学习中非常基础和重要的实践技能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/22297.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频号小店,常见的违规条例!98%的商家必犯的违规细节!

哈喽~我是电商月月 做电商,不管哪个平台都有属于自己的规则条例,这些违规细节,一定要提前了解 所以今天,月月就给大家分享一下,做视频号小店的话,有哪些常见的违规细节 这里我们分三点讲解 一&#xff…

【分享】两种方法禁止修改Word文档

对于比较重要的Word文件,不想被随意编辑修改,可以试试以下两个方法,不清楚的小伙伴,一起来看看吧! 方法1:设置“只读方式” 我们可以给Word文档设置以“只读方式”打开,这样就算编辑修改了文档…

如何通过SD-WAN提升企业沟通效率

在数字化飞速发展的今天,企业对大数据和实时商业数据传输的需求日益增长。传统的专线连接技术已无法满足企业对快速部署商业应用和高效网络连接的需求。在这种背景下,SD-WAN成为提升企业网络沟通效率的关键技术。 SD-WAN的灵活部署模式 SD-WAN提供了高度…

6月软考新通知:24下集成大概率是中级蕞简单的一门

2024下半年软考6月新通知: 一、24下软考考试时间安排: 24下半年软考报名时间:8月19日-9月15日 24下半年软考考试时间:11月9-12日 24下半年软考成绩查询:12月中(预计) 二、考情分析 24上软考…

09_JavaWeb会话

1.会话 HTTP是一种无状态协议; HTTP协议对于发送过请求或者响应都不做持久化处理具体来说就是客户端发送请求,服务器接收请求,但是服务器自身不会记录每一条请求都是由哪一个客户端发出的; 会话管理是通过Cookie和Session配合解…

【排序】插入排序,希尔排序

前面我们讲述了冒泡排序和选择排序,我们本章讲的排序方法是插入排序,插入排序是希尔排序实现的基础函数,大家一定要好好理解插入排序的逻辑,这样才能在后面学习希尔排序的时候,更容易的去理解,我们直接开始…

关于无法通过脚本启动Kafka集群的解决办法

启动Kafka集群时,需要在每台个节点上启动启动服务,比较麻烦,通过写了以下脚本来进行启停;发现能正常使用停止功能,不能正常启动Kafka; Kafka启停脚本: ## 以防不能通过shell脚本启动Kafka服务…

富格林:揭露黑幕平台保障安全

富格林指出,很多黑幕平台都会将自己包装得光鲜亮丽后,再出来诱惑投资者,使得投资者资金安全得不到保障,有苦说不出。富格林表示,黑幕平台的套路其实是非常常见的,只要投资者熟知并能够分辨出,就…

C盘扩容——只能删除C盘右边的磁盘对C盘进行扩展

winR弹出命令框 输入:compmgmt.msc 进入磁盘管理页面 注意:被删除盘如果有重要数据信息,请备份。 或者删除之前转移至其他盘,否则删除之后,则无法找回。 尤其是安装的软件。 规范安装目录十分重要。 将C盘右边的磁盘&a…

最全 Inno Setup 教程-[FILE] Flag参数

【1】此参数是一个附加选项的集合。可以使用空格将多个选项分隔开。 【2】支持以下选项: 32位 当在“Source”和“DestDir”参数中使用{sys}常量时,将该常量映射到32位系统目录。将“regserver”和“regtypelib”标志设置为将文件视为32位,…

安防综合管理系统EasyCVR视频汇聚平台GA/T 1400协议中的关键消息交互示例

在当今的信息化时代,公共安全防范日益成为保障社会和谐稳定的关键。视频监控系统作为现代安全防范的重要手段,正不断在公安、交通、城市管理等领域发挥着越来越重要的作用。而GA/T 1400协议视图库,作为公安视频图像信息应用系统的标准&#x…

Vue3 子组件访问父组件的方法 - 父组件访问子组件的属性或方法 - 子组件修改父组件的值

一。子组件访问父组件的方法 //父组件 <DialogEditing close-dialog"handleClose" /> const handleClose () > {};//子组件 const emit defineEmits(["closeDialog"]); const close () > {emit("closeDialog"); // 使用 };二。父…

健身日记之倒立俯卧撑学习——起始日2024.6.4

文章目录 前言 自我介绍 昔日计划 新目标计划 瓶颈突破尝试 参考视频及文章 前言 有轻微健身基础&#xff0c;正式接触街健五大神技&#xff0c;立志在两年内解锁全部&#xff0c;将有机会的进行日常训练和目标肌群锻炼&#xff0c;这里向大家展示我的计划和安排&#xf…

opencv-python(五)

opencv的颜色通道中顺序是B&#xff0c;G&#xff0c;R。 图像属性 import cv2img cv2.imread(jk.jpg) print(fshape{img.shape}) print(fsize{img.size}) print(fdtype{img.dtype}) shape&#xff1a;图像像素的行&#xff0c;列&#xff0c;通道 size&#xff1a;行数 X …

YonSuite收款通,助力企业618更快收款

随着电商节日“618”的临近&#xff0c;各大企业纷纷摩拳擦掌&#xff0c;准备在这场年中大促中大展身手。然而&#xff0c;随着销售额的激增&#xff0c;收款管理问题也愈发凸显&#xff0c;成为制约企业快速发展的重要瓶颈。在这个关键时刻&#xff0c;YonSuite收款通凭借其卓…

Python实现登录到远程主机,然后在远程主机上继续连接远程主机

实现功能 登录到远程主机&#xff0c;然后在远程主机上继续连接远程主机&#xff0c;执行命令。 import paramiko import time# 第二个远程主机的连接信息&#xff08;在第一个远程主机上执行SSH连接时使用&#xff09; second_remote_host 192.168.xx.xxx # 创建SSH客…

通过命令行将tar压缩文件解压缩到指定目录|Linux

要将all.tar文件解压缩到指定目录下&#xff0c;你可以使用Linux命令行中的tar命令。以下是具体步骤&#xff1a; 打开终端&#xff08;Terminal&#xff09;。 使用cd命令切换到你想要解压缩文件的目标目录。例如&#xff1a; cd /path/to/your/directory将/path/to/your/dir…

echarts图例formatter配置添加百分比

echarts图例如何添加百分比 const pieChart async () > {const myChart echarts.init(piepic.value)const piedata await getPieData(); // 等待数据返回myChart.setOption({title: {},grid: {},tooltip: {trigger: item,},legend: {top: middle,align:left,icon: circl…

都可以写好后端接口

在后端工程师的日常开发中&#xff0c;我们都曾想过 怎么设计一个良好的接口呢&#xff1f;需要考虑的点有哪些。来 给您。 1、请求参数校验 这个是大家都能想到的&#xff0c;也是一个良好的接口必备的前提条件&#xff0c;通过入参的校验我们可以过滤掉许多无效的请求&…

零基础学Java第二十七天之前端-HTML5详解

前端-HTML5详解 一、概述 HTML5是HTML的第五个版本&#xff0c;它对HTML进行了许多改进和扩展&#xff0c;使得网页开发更加丰富和便利。HTML5是Web标准的重要组成部分&#xff0c;旨在提高浏览器兼容性&#xff0c;统一网页开发标准。HTML5不仅包括了HTML的基本元素和标签&am…