【chatgpt】train_split_test的random_state

在使用train_test_split函数划分数据集时,random_state参数用于控制随机数生成器的种子,以确保划分结果的可重复性。这样,无论你运行多少次代码,只要使用相同的random_state值,得到的训练集和测试集划分就会是一样的。

使用 train_test_split 示例

以下是一个示例,展示如何使用train_test_split函数进行数据集划分,并设置random_state参数:
程序输出结果
Training set shape: (80, 10), (80,)
Test set shape: (20, 10), (20,)

import numpy as np
from sklearn.model_selection import train_test_split# 假设我们有一些数据
X = np.random.rand(100, 10)  # 100个样本,每个样本10个特征
y = np.random.randint(0, 2, 100)  # 100个样本的标签(0或1)# 使用train_test_split进行数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 打印划分后的数据集形状
print(f'Training set shape: {X_train.shape}, {y_train.shape}')
print(f'Test set shape: {X_test.shape}, {y_test.shape}')

在这个示例中:

  • X 是特征矩阵,包含100个样本,每个样本有10个特征。
  • y 是标签数组,包含100个样本的标签。
  • test_size=0.2 表示将数据集的20%用作测试集,剩下的80%用作训练集。
  • random_state=42 用于确保划分的可重复性。

为什么使用 random_state

使用 random_state 可以确保在多次运行代码时,得到的训练集和测试集划分是一致的,这在以下情况下特别有用:

  1. 调试和开发: 在开发和调试模型时,使用相同的 random_state 可以确保数据划分的一致性,从而使得调试更加容易。
  2. 实验的可重复性: 在进行实验时,使用相同的 random_state 可以确保实验结果的可重复性,使得其他人可以验证你的结果。
  3. 比较模型性能: 在比较不同模型的性能时,使用相同的 random_state 可以确保每个模型都使用相同的训练集和测试集,从而使比较更加公平。
    在这里插入图片描述

示例:比较大数据集和小数据集的模型性能

假设我们有一个大数据集和一个小数据集,我们想要比较它们在同一模型上的性能。我们可以使用 train_test_split 进行数据集划分,并设置 random_state 以确保划分的可重复性。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
# 创建大数据集和小数据集
X_large = np.random.rand(1000, 10)
y_large = np.random.rand(1000, 1)X_small = np.random.rand(100, 10)
y_small = np.random.rand(100, 1)# 使用train_test_split进行数据集划分
X_train_large, X_test_large, y_train_large, y_test_large = train_test_split(X_large, y_large, test_size=0.2, random_state=42)
X_train_small, X_test_small, y_train_small, y_test_small = train_test_split(X_small, y_small, test_size=0.2, random_state=42)# 转换为张量
X_train_large = torch.tensor(X_train_large, dtype=torch.float32)
y_train_large = torch.tensor(y_train_large, dtype=torch.float32)
X_test_large = torch.tensor(X_test_large, dtype=torch.float32)
y_test_large = torch.tensor(y_test_large, dtype=torch.float32)X_train_small = torch.tensor(X_train_small, dtype=torch.float32)
y_train_small = torch.tensor(y_train_small, dtype=torch.float32)
X_test_small = torch.tensor(X_test_small, dtype=torch.float32)
y_test_small = torch.tensor(y_test_small, dtype=torch.float32)# 创建数据加载器
train_loader_large = DataLoader(TensorDataset(X_train_large, y_train_large), batch_size=32, shuffle=True)
test_loader_large = DataLoader(TensorDataset(X_test_large, y_test_large), batch_size=32, shuffle=False)train_loader_small = DataLoader(TensorDataset(X_train_small, y_train_small), batch_size=32, shuffle=True)
test_loader_small = DataLoader(TensorDataset(X_test_small, y_test_small), batch_size=32, shuffle=False)# 定义简单的线性模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 1)def forward(self, x):return self.linear(x)# 训练模型的通用函数
def train_model(train_loader, num_epochs=50, learning_rate=0.01):model = SimpleModel()criterion = nn.MSELoss()optimizer = optim.SGD(model.parameters(), lr=learning_rate)train_losses = []for epoch in range(num_epochs):model.train()epoch_train_loss = 0.0for batch_x, batch_y in train_loader:outputs = model(batch_x)loss = criterion(outputs, batch_y)optimizer.zero_grad()loss.backward()optimizer.step()epoch_train_loss += loss.item()epoch_train_loss /= len(train_loader)train_losses.append(epoch_train_loss)print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f}')return model, train_losses# 训练大数据集的模型
print("Training on large dataset")
model_large, train_losses_large = train_model(train_loader_large)# 训练小数据集的模型
print("\nTraining on small dataset")
model_small, train_losses_small = train_model(train_loader_small)# 绘制训练损失曲线
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(train_losses_large) + 1), train_losses_large, label='Large Dataset Train Loss')
plt.plot(range(1, len(train_losses_small) + 1), train_losses_small, label='Small Dataset Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss Comparison')
plt.savefig("test")# 在测试集上计算最终的评估指标(例如均方误差)
def evaluate_model(model, test_loader):model.eval()test_loss = 0.0criterion = nn.MSELoss()with torch.no_grad():for batch_x, batch_y in test_loader:outputs = model(batch_x)loss = criterion(outputs, batch_y)test_loss += loss.item()test_loss /= len(test_loader)return test_loss# 评估大数据集的模型
final_test_loss_large = evaluate_model(model_large, test_loader_large)# 评估小数据集的模型
final_test_loss_small = evaluate_model(model_small, test_loader_small)print(f'Final Test Loss on Large Dataset: {final_test_loss_large:.4f}')
print(f'Final Test Loss on Small Dataset: {final_test_loss_small:.4f}')

结果分析

通过上述代码,可以得到大数据集和小数据集在训练过程中的损失曲线以及最终的测试损失。根据这些信息,可以比较它们的收敛情况和性能。

  • 损失曲线: 通过观察损失曲线,判断模型在两个数据集上的收敛速度和稳定性。如果两者曲线形状相似,并且在同一水平上趋于平稳,可以认为它们收敛到了相似的程度。

  • 最终测试损失: 最终测试损失值可以用于直接比较两个模型的性能。如果两者最终测试损失值接近,则可以认为它们的模型性能相当。

通过使用相同的 random_state 值,确保数据集划分的一致性,从而使得比较结果更加公平和具有可重复性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/858344.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

导入别人的net文件报红问题sdk

1. 使用cmd命令 dotnet --info 查看自己使用的SDK版本 2.直接找到项目中的 global.json 文件,右键打开,直接修改版本为本机的SDK版本,就可以用了

使用Flink CDC实时监控MySQL数据库变更

在现代数据架构中,实时数据处理变得越来越重要。Flink CDC(Change Data Capture)是一种强大的工具,可以帮助我们实时捕获数据库的变更,并进行处理。本文将介绍如何使用Flink CDC从MySQL数据库中读取变更数据&#xff0…

端到端自动驾驶的基础概念

欢迎大家关注我的B站: 偷吃薯片的Zheng同学的个人空间-偷吃薯片的Zheng同学个人主页-哔哩哔哩视频 (bilibili.com) 目录 1.端到端自动驾驶的定义 1.1特斯拉FSD 1.2端到端架构演进 1.3大模型 1.4世界模型 1.5纯视觉传感器 2.落地的挑战 1.端到端自动驾驶的定…

MySQL----利用Mycat配置读写分离

首先确保主从复制是正常的,具体步骤在MySQL----配置主从复制。MySQL----配置主从复制 环境 master(CtenOS7):192.168.200.131 ----ifconfig查看->ens33->inetslave(win10):192.168.207.52 ----ipconfig查看->无线局域网适配器 WLA…

whiteboard - 笔记

1 drawio draw.io GitHub - jgraph/drawio: draw.io is a JavaScript, client-side editor for general diagramming. 2 demo 可以将XML数据保存到服务器上的data目录。需要在服务器端创建一个接收和处理POST请求的脚本,该脚本将接收到的SVG数据保存到指定的文件中。下面是…

libssh-cve_2018_10933-vulfocus

1.原理 ibssh是一个用于访问SSH服务的C语言开发包,它能够执行远程命令、文件传输,同时为远程的程序提供安全的传输通道。server-side state machine是其中的一个服务器端状态机。 在libssh的服务器端状态机中发现了一个逻辑漏洞。攻击者可以MSG_USERA…

基于CDMA的多用户水下无线光通信(3)——解相关多用户检测

继续上一篇博文,本文将介绍基于解相关的多用户检测算法。解相关检测器的优点是因不需要估计各个用户的接收信号幅值而具有抗远近效应的能力。常规的解相关检测器有运算量大和实时性差的缺点,本文针对异步CDMA的MAI主要来自干扰用户的相邻三个比特周期的特…

【c2】编译预处理,gdb,makefile,文件,多线程,动静态库

文章目录 1.编译预处理:C源程序 - 编译预处理【#开头指令和特殊符号进行处理,删除程序中注释和多余空白行】- 编译2.gdb调试:多进/线程中无法用3.makefile文件:make是一个解释makefile中指令的命令工具4.文件:fprint/f…

Git分支的状态存储——stash命令的详细用法

文章目录 6.6 Git的分支状态存储6.6.1 git stash命令6.6.2 Git存储的基本使用6.6.3 Git存储的其他用法6.6.4 Git存储与暂存区6.6.5 Git存储的原理 6.6 Git的分支状态存储 有时,当我们在项目的一部分上已经工作一段时间后,所有东西都进入了混乱的状态&am…

AG32 MCU是否支持DFU下载实现USB升级

1、AG32 MCU是否支持DFU下载实现USB升级呢? 先说答案是NO. STM32 可以通过内置DFU实现USB升级,AG32 MCU目前不支持。但用户可以自己写一个DFU, 作为二次boot. 2、AG32 MCU可支持的下载方式有哪些呢? 我们AG32裸机下载只支持uart和…

四川汇聚荣科技有限公司怎么样?

在探讨一家科技公司的综合实力时,我们往往从多个维度进行考量,包括但不限于公司的发展历程、产品与服务的质量、市场表现、技术创新能力以及企业文化。四川汇聚荣科技有限公司作为一家位于中国西部的科技企业,其表现和影响力自然也受到业界和…

测试服务器端口是否打开,服务器端口开放异常的解决方法

在进行服务器端口开放性的测试时,我们通常使用网络工具来验证目标端口是否响应特定的协议请求。常用的工具包括Telnet、Nmap、nc(netcat)等。这些工具可以通过发送TCP或UDP数据包到指定的IP地址和端口,然后分析返回的数据包&#…

【FreeRTOS】任务管理与调度

文章目录 调度:总结 调度: 相同优先级的任务轮流运行最高优先级的任务先运行 可以得出结论如下: a 高优先级的任务在运行,未执行完,更低优先级的任务无法运行b 一旦高优先级任务就绪,它会马上运行&#xf…

Postman Postman接口测试工具使用简介

Postman这个接口测试工具的使用做个简单的介绍,仅供参考。 插件安装 1)下载并安装chrome浏览器 2)如下 软件使用说明

从零入手人工智能(5)—— 决策树

1.前言 在上一篇文章《从零入手人工智能(4)—— 逻辑回归》中讲述了逻辑回归这个分类算法,今天我们的主角是决策树。决策树和逻辑回归这两种算法都属于分类算法,以下是决策树和逻辑回归的相同点: 分类任务&#xff1…

椭圆的矩阵表示法

椭圆的矩阵表示法 flyfish 1. 标准几何表示法 标准几何表示法是通过椭圆的几何定义来表示的: x 2 a 2 y 2 b 2 1 \frac{x^2}{a^2} \frac{y^2}{b^2} 1 a2x2​b2y2​1其中, a a a 是椭圆的长半轴长度, b b b 是椭圆的短半轴长度。 2.…

三十八篇:架构大师之路:探索软件设计的无限可能

架构大师之路:探索软件设计的无限可能 1. 引言:架构的艺术与科学 在软件工程的广阔天地中,系统架构不仅是设计的骨架,更是灵魂所在。它如同建筑师手中的蓝图,决定了系统的结构、性能、可维护性以及未来的扩展性。本节…

AWS-PatchAsgInstance自动化定时ASG组打补丁

问题 需要给AWS的EC2水平自动扩展组AutoScaling Group(ASG)中的EC2自动定期打补丁。 创建自动化运行IAM角色 找到创建角色入口页面,如下图: 开始创建Systems Manager自动化运行的IAM角色,如下图: 设置…

2023-2024 学年第二学期小学数学六年级期末质量检测模拟(制作:王胤皓)(90分钟)

word效果预览: 一、我会填 1. 1.\hspace{0.5em} 1. 一个多位数,亿位上是次小的素数,千位上是最小的质数的立方,十万位是 10 10 10 和 15 15 15 的最大公约数,万位是最小的合数,十位上的数既不是质数也…

体验了一下AI生产3D模型有感

我的实验路子是想试试能不能帮我建一下实物模型 SO 我选择了一个成都环球中心的网图 但是生成的结果掺不忍睹,但是看demo来看,似乎如果你能给出一张干净的提示图片,他还是能做出一些东西的 这里我延申的思考是这个物体他如果没看过背面&…