如何理解和区分训练集、测试集和验证集

如何理解和区分训练集、测试集和验证集

在这里插入图片描述

🌈 个人主页:高斯小哥
🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】
💡 创作高质量博文,分享更多关于深度学习、PyTorch、Python领域的优质内容!(希望得到您的关注~)


🌵文章目录🌵

  • 一、概念回顾 🧠
  • 二、基于PyTorch划分训练集、验证集、测试集 🔥
  • 三、模型训练与评估 🔥
  • 四、总结 🎉
  • 五、最后 🤝

  👋大家好,欢迎来到我的博客!在机器学习和深度学习的世界里,数据集被划分为训练集、验证集和测试集是非常重要的。这些集合各自扮演着不同的角色,确保我们的模型能够准确地学习和泛化。今天,我将通过PyTorch的示例代码来详细解释如何理解和区分这三个集合。

关键词:#机器学习 #深度学习 #训练集 #验证集 #测试集 #PyTorch #数据划分 #模型训练与评估

一、概念回顾 🧠

  • 训练集(Training Set):用于训练模型的数据集。模型通过学习训练集中的数据来拟合数据分布并学习规律。
  • 验证集(Validation Set):用于验证模型性能的数据集。在模型训练过程中,我们使用验证集来调整模型参数和超参数,以优化模型性能。验证集帮助我们在调整模型时避免过拟合。
  • 测试集(Test Set):用于评估模型性能的数据集。在模型训练完成后,我们使用测试集来评估模型的泛化能力,即模型在未知数据上的表现。测试集应该是完全独立的,从未参与模型的训练或验证。

二、基于PyTorch划分训练集、验证集、测试集 🔥

  在PyTorch中,我们通常使用torch.utils.data.Datasettorch.utils.data.DataLoader来处理数据集。首先,我们需要创建一个继承自Dataset的自定义数据集类,然后使用DataLoader来加载数据并提供批量处理、打乱等功能。

  下面是一个简单的例子,展示了如何创建一个自定义数据集类,并划分为训练集、验证集和测试集。

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split# 假设我们有一个简单的数据集,包含特征和标签
features = torch.randn(1000, 10)  # 1000个样本,每个样本10个特征
labels = torch.randint(0, 2, (1000,))  # 1000个样本的二分类标签# 划分数据集为训练集和临时集(验证集+测试集)
X_train, X_temp, y_train, y_temp = train_test_split(features, labels, test_size=0.4, random_state=42)# 进一步划分临时集为验证集和测试集
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)# 自定义数据集类
class MyDataset(Dataset):def __init__(self, features, labels):self.features = featuresself.labels = labelsdef __len__(self):return len(self.labels)def __getitem__(self, idx):return self.features[idx], self.labels[idx]# 创建数据集实例
train_dataset = MyDataset(X_train, y_train)
val_dataset = MyDataset(X_val, y_val)
test_dataset = MyDataset(X_test, y_test)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

三、模型训练与评估 🔥

  现在,我们有了训练集、验证集和测试集的数据加载器,接下来是训练模型并使用验证集进行调整,最后使用测试集评估模型的性能。

import torch.nn as nn
import torch.optim as optim# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 实例化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()  # 设置模型为训练模式train_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数train_loss += loss.item() * inputs.size(0)  # 累加损失train_loss /= len(train_loader.dataset)  # 计算平均损失print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")# 使用验证集评估模型性能model.eval()  # 设置模型为评估模式val_loss = 0.0val_accuracy = 0.0with torch.no_grad():  # 不需要计算梯度for inputs, labels in val_loader:outputs = model(inputs)loss = criterion(outputs, labels)val_loss += loss.item() * inputs.size(0)# 计算准确率_, predicted = torch.max(outputs, 1)correct = (predicted == labels).sum().item()val_accuracy += correct / inputs.size(0)val_loss /= len(val_loader.dataset)val_accuracy /= len(val_loader)print(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}")# 使用测试集评估模型性能
model.eval()  # 保持模型为评估模式
test_loss = 0.0
test_accuracy = 0.0with torch.no_grad():  # 不需要计算梯度for inputs, labels in test_loader:outputs = model(inputs)loss = criterion(outputs, labels)test_loss += loss.item() * inputs.size(0)# 计算准确率_, predicted = torch.max(outputs, 1)correct = (predicted == labels).sum().item()test_accuracy += correct / inputs.size(0)test_loss /= len(test_loader.dataset)
test_accuracy /= len(test_loader)print(f"Test Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.4f}")

四、总结 🎉

  通过上面的代码和解释,我们了解了如何在PyTorch中创建数据集、划分训练集、验证集和测试集,并使用这些集合来训练和评估模型。在实际应用中,通常还需要进行更多的数据预处理步骤,如数据清洗、特征工程等。此外,模型的性能也可以通过其他指标来评估,如精确度、召回率、F1分数等,具体取决于问题的性质和目标。

  希望这篇博客能帮助你更好地理解和区分训练集、验证集和测试集,并在实践中应用它们来构建和评估机器学习模型!🚀


五、最后 🤝

  亲爱的读者,感谢您每一次停留和阅读,这是对我们最大的支持和鼓励!🙏在茫茫网海中,您的关注让我们深感荣幸。您的独到见解和建议,如明灯照亮我们前行的道路。🌟若在阅读中有所收获,一个赞或收藏,对我们意义重大。

  我们承诺,会不断自我挑战,为您呈现更精彩的内容。📚有任何疑问或建议,欢迎在评论区畅所欲言,我们时刻倾听。💬让我们携手在知识的海洋中航行,共同成长,共创辉煌!🌱🌳感谢您的厚爱与支持,期待与您共同书写精彩篇章!

  您的点赞👍、收藏🌟、评论💬和关注💖,是我们前行的最大动力!

  🎉 感谢阅读,祝你编程愉快! 🎉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/697183.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

靡语IT:Vue精讲(一)

Vue简介 发端于2013年的个人项目,已然成为全世界三大前端框架之一,在中国大陆更是前端首选。 它的设计思想、编码技巧也被众多的框架借鉴、模仿。 纪略 2013年,在Google工作的尤雨溪,受到Angular的启发,从中提取自…

【前端素材】推荐优质后台管理系统Protable平台模板(附源码)

一、需求分析 后台管理系统是一种用于管理和监控网站、应用程序或系统的在线工具。它通常是通过网页界面进行访问和操作,用于管理网站内容、用户权限、数据分析等。当我们从多个层次来详细分析后台管理系统时,可以将其功能和定义进一步细分,…

华为配置CAPWAP双栈覆盖业务示例

配置CAPWAP双栈覆盖业务示例 组网图形 图1 配置CAPWAP双栈覆盖业务示例组网图 业务需求组网需求数据规划配置思路配置注意事项操作步骤配置文件 业务需求 企业用户接入WLAN网络,以满足移动办公的最基本需求。且在覆盖区域内移动发生漫游时,不影响用户的业…

【selenium】三大切换 iframe 弹窗alert 句柄window 和 鼠标操作

目录 一、iframe 1、切换方式: 1、第一种情况: 2、第二种情况: 方式1: 先找到iframe,定位iframe元素(可以通过元素定位的各种方式:xpath,css等等),用对象接收&…

[HTML]Web前端开发技术27(HTML5、CSS3、JavaScript )JavaScript基础——喵喵画网页

希望你开心,希望你健康,希望你幸福,希望你点赞! 最后的最后,关注喵,关注喵,关注喵,佬佬会看到更多有趣的博客哦!!! 喵喵喵,你对我真的…

activeMq将mqtt发布订阅转成消息队列

1、activemq.xml置文件新增如下内容 2、mqttx测试发送: 主题(配置的模糊匹配,为了并发):VirtualTopic/device/sendData/12312 3、mqtt接收的结果 4、程序处理 package comimport cn.hutool.core.date.DateUtil; imp…

【AIGC】基于深度学习的图像生成与增强技术

摘要: 本论文探讨基于深度学习的图像生成与增强技术在图像处理和计算机视觉领域的应用。我们综合分析了主流的深度学习模型,特别是生成对抗网络(GAN)和变分自编码器(VAE)等,并就它们在实际应用中…

小程序性能优化

背景 在开发小程序的过程中我们发现,小程序的经常会遇到性能问题,尤其是在微信开发者工具的时候更是格外的卡,经过排查发现,卡顿的页面有这么多的js代码需要加载,而且都是在进入这个页面的时候加载,这就会…

Android 仿信号格子强度动画效果实现

效果图 在 Android 中,如果你想要绘制一个圆角矩形并使其居中显示,你可以使用 Canvas 类 drawRoundRect 方法。要使圆角矩形居中,你需要计算矩形的位置,这通常涉及到确定矩形左上角的位置(x, y)&#xff0…

第3部分 原理篇2去中心化数字身份标识符(DID)(2)

3.2.2. DID相关概念 3.2.2.1. 去中心化标识符 (Decentralized identifier,DID) 本聪老师:DID有两个含义,一是Decentralized identity,就是去中心化身份,是广泛意义的DID。另外一个是Decentralized identifier&#xf…

Web性能优化-浏览器工作原理-MDN文档学习笔记

浏览器工作原理 查看更多学习笔记:GitHub:LoveEmiliaForever MDN中文官网 导航 导航是加载 web 页面的第一步:输入 URL、点击一个链接、提交表单等等 DNS查询 导航的第一步是要去寻找页面资源的位置 例如访问https://example.com&#x…

qt-动画圆圈等待-LED数字

qt-动画圆圈等待-LED数字 一、演示效果二、关键程序三、下载链接 一、演示效果 二、关键程序 #include "LedNumber.h" #include <QLabel>LEDNumber::LEDNumber(QWidget *parent) : QWidget(parent) {//设置默认宽高比setScale((float)0.6);//设置默认背景色se…

websocket与Socket的区别

概念讲解 网络&#xff1a;通俗意义上&#xff0c;也就是连接两台计算器 五层网络模型&#xff1a;应用层、传输层、网络层、数据链路层、物理层 应用层 (application layer)&#xff1a;直接为应用进程提供服务。应用层协议定义的是应用进程间通讯和交互的规则&#xff0c;不…

排序第三篇 直接插入排序

插入排序的基本思想是&#xff1a; 每次将一个待排序的记录按其关键字的大小插入到前面已排好序的文件中的适当位置&#xff0c; 直到全部记录插入完为止。 一 简介 插入排序可分为2类 本文介绍 直接插入排序 它的基本操作是&#xff1a; 假设待排充序的记录存储在数组 R[1……

电路设计(27)——交通信号灯的multisim仿真

1.功能要求 使用数字芯片设计一款交通信号灯&#xff0c;使得&#xff1a; 主干道的绿灯时间为60S&#xff0c;红灯时间为45S 次干道的红灯时间为60S&#xff0c;绿灯时间为45S 主、次干道&#xff0c;绿灯的最后5S内&#xff0c;黄灯闪烁 使用数码管显示各自的倒计时时间。 按…

JavaScript 数组、遍历

数组 多维数组&#xff1a;数组里面嵌套 一层数组为二维数组。一维数组的使用频率是最高的。 如果数组访问越界会返回undefined。 数组遍历 数组方法Array.isArray() 这个方法可以去判定一个内容是否是数组。

AndroidStudio 2024-2-21 Win10/11最新安装配置(Kotlin快速构建配置,gradle镜像源)

AndroidStudio 2024 Win10/11最新安装配置 教程目的&#xff1a; (从安装到卸载) &#xff0c;针对Kotlin开发配置&#xff0c;gradle-8.2-src/bin下载慢&#xff0c;以及Kotlin构建慢的解决 好久没玩AS了,下载发现装个AS很麻烦,就觉得有必要出个教程了(就是记录一下:嘻嘻) 因…

java 时间格式 YYYY 于yyyy的区别

java formatDate 时间时&#xff0c;经常需要输入格式比如 YYYYMMDD,yyyyMMdd 这两个是有区别的 具体每个参数可以看下面

igolang学习1,dea的golang-1.22.0

参考&#xff1a;使用IDEA配置GO的开发环境备忘录-CSDN博客 1.下载All releases - The Go Programming Language (google.cn) 2.直接next 3.window环境变量配置 4.idea的go插件安装 5.新建go项目找不到jdk解决 https://blog.csdn.net/ouyang111222/article/details/1361657…

【js】无限虚拟列表的原理及实现

什么是虚拟列表 虚拟列表是长列表按需显示思路的一种实现&#xff0c;即虚拟列表是一种根据滚动容器元素的可视区域来渲染长列表数据中某一个部分数据的技术。 简而言之&#xff0c;虚拟列表指的就是「可视区域渲染」的列表。有三个概念需要了解一下&#xff1a; 视口容器元…