自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。

1. 创建自定义数据集

首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。

import numpy as np# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # 标签为1或0# 打印数据集的前5个样本
print(X[:5], y[:5])

2. 构建逻辑回归模型

接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module 类,能够轻松实现神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(2, 1)  # 2个输入特征,1个输出def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型
model = LogisticRegressionModel()

3. 训练模型

接下来,我们定义损失函数和优化器,并训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):for inputs, labels in dataloader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. 保存模型

模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save() 方法来保存模型的状态字典。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")

5. 加载模型并进行预测

我们可以加载保存的模型,并对新数据进行预测。

# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换到评估模式# 进行预测
with torch.no_grad():test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)prediction = loaded_model(test_data)print(f'预测值: {prediction.item():.4f}')

6. 总结

在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894005.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

​ONES 春节假期服务通知

ONES 春节假期服务通知 灵蛇贺岁,瑞气盈门。感谢大家一直以来对 ONES 的认可与支持,祝您春节快乐! 「2025年1月28日 ~ 2025年2月4日」春节假期期间,我们的值班人员将为您提供如下服务 : 紧急问题 若有紧急问…

python:洛伦兹变换

洛伦兹变换(Lorentz transformations)是相对论中的一个重要概念,特别是在讨论时空的变换时非常重要。在四维时空的背景下,洛伦兹变换描述了在不同惯性参考系之间如何变换时间和空间坐标。在狭义相对论中,洛伦兹变换通常…

LangChain:使用表达式语言优化提示词链

在 LangChain 里,LCEL 即 LangChain Expression Language(LangChain 表达式语言),本文为你详细介绍它的定义、作用、优势并举例说明,从简单示例到复杂组合示例,让你快速掌握LCEL表达式语言使用技巧。 定义 …

unity学习20:time相关基础 Time.time 和 Time.deltaTime

目录 1 unity里的几种基本时间 1.1 time 相关测试脚本 1.2 游戏开始到现在所用的时间 Time.time 1.3 时间缩放值 Time.timeScale 1.4 固定时间间隔 Time.fixedDeltaTime 1.5 两次响应时间之间的间隔:Time.deltaTime 1.6 对应测试代码 1.7 需要关注的2个基本…

世上本没有路,只有“场”et“Bravo”

楔子:电气本科“工程电磁场”电气研究生课程“高等电磁场分析”和“电磁兼容”自学”天线“、“通信原理”、“射频电路”、“微波理论”等课程 文章目录 前言零、学习历程一、Maxwells equations1.James Clerk Maxwell2.自由空间中传播的电磁波3.边界条件和有限时域…

electron typescript运行并设置eslint检测

目录 一、初始化package.json 二、安装依赖 三、项目结构 四、配置启动项 五、补充:ts转js别名问题 一、初始化package.json 我的:这里的"main"没太大影响,看后面的步骤。 {"name": "xloda-cloud-ui-pc"…

学习数据结构(3)顺序表

1.动态顺序表的实现 (1)初始化 (2)扩容 (3)头部插入 (4)尾部插入 (5)头部删除 (这里注意要保证有效数据个数不为0) (6&a…

PydanticAI应用实战

PydanticAI 是一个 Python Agent 框架,旨在简化使用生成式 AI 构建生产级应用程序的过程。 它由 Pydantic 团队构建,该团队也开发了 Pydantic —— 一个在许多 Python LLM 生态系统中广泛使用的验证库。PydanticAI 的目标是为生成式 AI 应用开发带来类似 FastAPI 的体验,它基…

deepseek R1的确不错,特别是深度思考模式

deepseek R1的确不错,特别是深度思考模式,每次都能自我反省改进。比如我让 它写文案: 【赛博朋克版程序员新春密码——2025我们来破局】 亲爱的代码骑士们: 当CtrlS的肌肉记忆遇上抢票插件,当Spring Boot的…

macbook安装go语言

通过brew来安装go语言 使用brew命令时,一般都会通过brew search看看有哪些版本 brew search go执行后,返回了一堆内容,最下方展示 If you meant "go" specifically: It was migrated from homebrew/cask to homebrew/core. Cas…

若依基本使用及改造记录

若依框架想必大家都了解得不少,不可否认这是一款及其简便易用的框架。 在某种情况下(比如私活)使用起来可谓是快得一匹。 在这里小兵结合自身实际使用情况,记录一下我对若依框架的使用和改造情况。 一、源码下载 前往码云进行…

Kafka 深入服务端 — 时间轮

Kafka中存在大量的延迟操作,比如延时生产、延时拉取和延时删除等。Kafka基于时间轮概念自定义实现了一个用于延时功能的定时器,来完成这些延迟操作。 1 时间轮 Kafka没有使用基于JDK自带的Timer或DelayQueue来实现延迟功能,因为它们的插入和…

数据分析系列--②RapidMiner导入数据和存储过程

一、下载数据 点击下载AssociationAnalysisData.xlsx数据集 二、导入数据 1. 在本地计算机中创建3个文件夹 2. 从本地选择.csv或.xlsx 三、界面说明 四、存储过程 将刚刚新建的过程存储到本地 Congratulations, you are done.

HarmonyOS简介:HarmonyOS核心技术理念

核心理念 一次开发、多端部署可分可合、自由流转统一生态、原生智能 一次开发、多端部署 可分可合 自由流转 自由流转可分为跨端迁移和多端协同两种情况 统一生态 支持业界主流跨平台开发框架,通过多层次的开放能力提供统一接入标准,实现三方框架快速…

ES6语法

一、Let、const、var变量定义 1.let 声明的变量有严格局部作用域 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"&g…

书生大模型实战营3

文章目录 L0——入门岛git基础Git 是什么&#xff1f;Git 中的一些基本概念工作区、暂存区和 Git 仓库区文件状态分支主要功能 Git 平台介绍GitHubGitLabGitee Git 下载配置验证下载 Git配置 Git验证 Git配置 Git常用操作Git简易入门四部曲Git其他指令 闯关任务任务1: 破冰活动…

前端——js高级25.1.27

复习&#xff1a;对象 问题一&#xff1a; 多个数据的封装提 一个对象对应现实中的一个事物 问题二&#xff1a; 统一管理多个数据 问题三&#xff1a; 属性&#xff1a;组成&#xff1a;属性名属性值 &#xff08;属性名为字符串&#xff0c;属性值任意&#xff09; 方…

[创业之路-270]:《向流程设计要效率》-2-企业流程架构模式 POS架构(规划、业务运营、支撑)、OES架构(业务运营、使能、支撑)

目录 一、POS架构 二、OES架构 三、POS架构与OES架构的差异 四、各自的典型示例 POS架构典型示例 OES架构典型示例 示例分析 五、各自的典型企业 POS架构典型企业 OES架构典型企业 分析 六、各自典型的流程 POS架构的典型流程 OES架构的典型流程 企业流程架构模式…

计算机的错误计算(二百二十二)

摘要 利用大模型化简计算 实验表明&#xff0c;虽然结果正确&#xff0c;但是&#xff0c;大模型既绕了弯路&#xff0c;又有数值计算错误。 与前面相同&#xff0c;再利用同一个算式看看另外一个大模型的化简与计算能力。 例1. 化简计算摘要中算式。 下面是与一个大模型的…

【现代深度学习技术】深度学习计算 | 参数管理

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上&#xff0c;结合当代大数据和大算力的发展而发展出来的。深度学习最重…