深度学习Note.5(机器学习.6)

1.Runner类

一个任务应用机器学习方法流程:

数据集构建

模型构建

损失函数定义

优化器

模型训练

模型评价

模型预测

所以根据以上,我们把机器学习模型基本要素封装成一个Runner类(加上模型保存、模型加载等功能。)

Runner类的成员函数定义如下:

  • __init__函数:实例化Runner类时默认调用,需要传入模型、损失函数、优化器和评价指标等;
  • train函数:完成模型训练,指定模型训练需要的训练集和验证集;
  • evaluate函数:通过对训练好的模型进行评价,在验证集或测试集上查看模型训练效果;
  • predict函数:选取一条数据对训练好的模型进行预测;
  • save_model函数:模型在训练过程和训练结束后需要进行保存;
  • load_model函数:调用加载之前保存的模型。
class Runner(object):def __init__(self, model, optimizer, loss_fn, metric):self.model = model         # 模型self.optimizer = optimizer # 优化器self.loss_fn = loss_fn     # 损失函数   self.metric = metric       # 评估指标# 模型训练def train(self, train_dataset, dev_dataset=None, **kwargs):pass# 模型评价def evaluate(self, data_set, **kwargs):pass# 模型预测def predict(self, x, **kwargs):pass# 模型保存def save_model(self, save_path):pass# 模型加载def load_model(self, model_path):pass

1.2Runner类流程

①初始化:传入模型、损失函数、优化器和评价指标

②训练:基于训练集调用train()函数训练模型,基于验证集通过evaluate()函数验证模型。通过save_model()函数保存模型

③评价:基于测试集通过evaluate()函数得到指标性能。

④预测:给定样本,通过predict()函数得到该样本标签

2.案例:波士顿房价预测

波士顿房价预测基于线性回归模型和Runner类实现

2.1数据处理

2.1.1构建

开源库pandas导入。

import pandas as pd # 开源数据分析和操作工具# 利用pandas加载波士顿房价的数据集
data=pd.read_csv("/home/aistudio/work/boston_house_prices.csv")
# 预览前5行数据
data.head()

2.1.2数据集划分

训练集 和 测试集。

import paddlepaddle.seed(10)# 划分训练集和测试集
def train_test_split(X, y, train_percent=0.8):n = len(X)shuffled_indices = paddle.randperm(n) # 返回一个数值在0到n-1、随机排列的1-D Tensortrain_set_size = int(n*train_percent)train_indices = shuffled_indices[:train_set_size]test_indices = shuffled_indices[train_set_size:]X = X.valuesy = y.valuesX_train=X[train_indices]y_train = y[train_indices]X_test = X[test_indices]y_test = y[test_indices]return X_train, X_test, y_train, y_test X = data.drop(['MEDV'], axis=1)
y = data['MEDV']X_train, X_test, y_train, y_test = train_test_split(X,y)# X_train每一行是个样本,shape[N,D]

2.1.3特征化工程

避免数据之间的可比性:对特征数据进行归一化处理,将数据缩放到[0, 1]区间

import paddleX_train = paddle.to_tensor(X_train,dtype='float32')
X_test = paddle.to_tensor(X_test,dtype='float32')
y_train = paddle.to_tensor(y_train,dtype='float32')
y_test = paddle.to_tensor(y_test,dtype='float32')X_min = paddle.min(X_train,axis=0)
X_max = paddle.max(X_train,axis=0)X_train = (X_train-X_min)/(X_max-X_min)X_test  = (X_test-X_min)/(X_max-X_min)# 训练集构造
train_dataset=(X_train,y_train)
# 测试集构造
test_dataset=(X_test,y_test)

2.2模型构建

rom nndl.op import Linear# 模型实例化
input_size = 12
model=Linear(input_size)

2.3完善Runner类

测试集上使用MSE对模型性能进行评估。本案例利用飞桨框架提供的MSELoss API实现

import paddle
import os
from nndl.opitimizer import optimizer_lsmclass Runner(object):def __init__(self, model, optimizer, loss_fn, metric):# 优化器和损失函数为None,不再关注# 模型self.model=model# 评估指标self.metric = metric# 优化器self.optimizer = optimizerdef train(self,dataset,reg_lambda,model_dir):X,y = datasetself.optimizer(self.model,X,y,reg_lambda)# 保存模型self.save_model(model_dir)def evaluate(self, dataset, **kwargs):X,y = datasety_pred = self.model(X)result = self.metric(y_pred, y)return resultdef predict(self, X, **kwargs):return self.model(X)def save_model(self, model_dir):if not os.path.exists(model_dir):os.makedirs(model_dir)params_saved_path = os.path.join(model_dir,'params.pdtensor')paddle.save(model.params,params_saved_path)def load_model(self, model_dir):params_saved_path = os.path.join(model_dir,'params.pdtensor')self.model.params=paddle.load(params_saved_path)optimizer = optimizer_lsm# 实例化Runner
runner = Runner(model, optimizer=optimizer,loss_fn=None, metric=mse_loss)

2.4模型训练

组装完成Runner之后,我们将开始进行模型训练、评估和测试

# 模型保存文件夹
saved_dir = '/home/aistudio/work/models'# 启动训练
runner.train(train_dataset,reg_lambda=0,model_dir=saved_dir)columns_list = data.columns.to_list()
weights = runner.model.params['w'].tolist()
b = runner.model.params['b'].item()for i in range(len(weights)):print(columns_list[i],"weight:",weights[i])print("b:",b)

2.5模型测试

加载训练好的模型参数,在测试集上得到模型的MSE指标

# 加载模型权重
runner.load_model(saved_dir)mse = runner.evaluate(test_dataset)
print('MSE:', mse.item())

2.6模型预测

load_model函数加载保存好的模型,使用predict进行模型预测

runner.load_model(saved_dir)
pred = runner.predict(X_test[:1])
print("真实房价:",y_test[:1].item())
print("预测的房价:",pred.item())
真实房价: 33.099998474121094
预测的房价: 33.04654312133789

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/899839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux服务器专题1------redis的安装及简单配置

在 linux上安装 Redis 可以按照以下步骤进行(此处用Ubuntu 服务器进行讲解): 步骤 1: 更新系统包 打开终端并运行以下命令以确保你的系统是最新的: sudo apt update sudo apt upgrade步骤 2: 安装 Redis 使用 apt 包管理器安装 Redis: s…

面试问题总结:qt工程师/c++工程师

C 语言相关问题答案 面试问题总结:qt工程师/c工程师 C 语言相关问题答案 目录基础语法与特性内存管理预处理与编译 C 相关问题答案面向对象编程模板与泛型编程STL 标准模板库 Qt 相关问题答案Qt 基础与信号槽机制Qt 界面设计与布局管理Qt 多线程与并发编程 目录 基础…

实现实时数据推送:SpringBoot中SSE接口的两种方法

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

LXC 导入多Linux系统

前提要求 ubuntu下安装lxd 参考Rockylinux下安装lxd 参考LXC 源替换参考LXC 容器端口发布参考LXC webui 管理<

ES的文档更新机制

想获取更多高质量的Java技术文章&#xff1f;欢迎访问Java技术小馆官网&#xff0c;持续更新优质内容&#xff0c;助力技术成长 Java技术小馆官网https://www.yuque.com/jtostring ES的文档更新机制 在现代应用中&#xff0c;数据的动态性越来越强&#xff0c;我们不仅需要快…

trae.ai 编辑器:前端开发者的智能效率革命

一、为什么我们需要更智能的编辑器&#xff1f; 作为从业5年的前端开发者&#xff0c;我使用过从Sublime到VSCode的各种编辑器。但随着现代前端技术的复杂度爆炸式增长&#xff08;想想一个React组件可能涉及JSX、CSS-in-JS、TypeScript和GraphQL&#xff09;&#xff0c;传统…

MySQL篇(一):慢查询定位及索引、B树相关知识详解

MySQL篇&#xff08;一&#xff09;&#xff1a;慢查询定位及索引、B树相关知识详解 MySQL篇&#xff08;一&#xff09;&#xff1a;慢查询定位及索引、B树相关知识详解一、MySQL中慢查询的定位&#xff08;一&#xff09;慢查询日志的开启&#xff08;二&#xff09;慢查询日…

uniapp APP端在线升级(简版)

设计思路&#xff1a; 1.版本比较&#xff1a;应用程序检查其当前版本与远程服务器上可用的最新版本 2. 更新状态指示&#xff1a;如果应用程序是不是最新的版本&#xff0c;则页面提示下载最新版本。 3.下载启动&#xff1a;通过plus.downloader.createDownload()启动新应用…

基于javaweb的SpringBoot教务课程管理设计与实现(源码+文档+部署讲解)

技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;免费功能设计、开题报告、任务书、中期检查PPT、系统功能实现、代码编写、论文编写和辅导、论文…

使用大语言模型进行Python图表可视化

Python使用matplotlib进行可视化一直有2个问题&#xff0c;一是代码繁琐&#xff0c;二是默认模板比较丑。因此发展出seaborn等在matplotlib上二次开发&#xff0c;以更少的代码进行画图的和美化的库&#xff0c;但是这也带来了定制化不足的问题。在大模型时代&#xff0c;这个…

【JavaEE】MyBatis - Plus

目录 一、快速使用二、CRUD简单使用三、常见注解3.1 TableName3.2 TableFiled3.3 TableId 四、条件构造器4.1 QueryWrapper4.2 UpdateWrapper4.3 LambdaQueryWrapper4.4 LambdaUpdateWrapper 五、自定义SQL 一、快速使用 MyBatis Plus官方文档&#xff1a;MyBatis Plus官方文档…

采用前端技术开源了一个数据结构算法的可视化工具

今天要推荐的开源项目叫VisuAlgoX,是一个面向计算机科学和游戏开发的 交互式算法可视化工具&#xff0c;帮助用户通过直观的动画理解各种数据结构和算法。 项目的前身 由于最近在做一些关于游戏和图形化方面的文章&#xff0c;因此做了一部分相关算法的动态可视化来做配图展示…

体验智谱清言的AutoGLM进行自动化的操作(Chrome插件)

最近体验了很多的大模型&#xff0c;大模型我是一直关注着ChatGLM&#xff0c;因为它确实在7b和8b这档模型里&#xff0c;非常聪明&#xff01; 最近还体验了很多大模型的应用软件&#xff0c;比如Agently、5ire、 mcphost、 Dive、 NextChat等。但是这些一般都是图形界面或者…

pytorch中dataloader自定义数据集

前言 在深度学习中我们需要使用自己的数据集做训练&#xff0c;因此需要将自定义的数据和标签加载到pytorch里面的dataloader里&#xff0c;也就是自实现一个dataloader。 数据集处理 以花卉识别项目为例&#xff0c;我们分别做出图片的训练集和测试集&#xff0c;训练集的标…

Blender模型导入虚幻引擎设置

单位系统不一致 Blender默认单位是米&#xff08;Meters&#xff09;&#xff0c;而虚幻引擎默认使用**厘米&#xff08;Centimeters&#xff09;**作为单位。 当模型从Blender导出为FBX或其他格式时&#xff0c;如果没有调整单位&#xff0c;虚幻引擎会将1米&#xff08;Blen…

Docker基础详解

Docker 技术详解 一、概述 Docker官网&#xff1a;https://docs.docker.com/ 菜鸟教程&#xff1a;https://www.runoob.com/docker/docker-tutorial.html 1.1 什么是Docker&#xff1f; Docker 是一个开源的容器化平台&#xff0c;它允许开发者将应用程序和其依赖项打包到…

FastPillars:一种易于部署的基于支柱的 3D 探测器

FastPillars&#xff1a;一种易于部署的基于支柱的 3D 探测器Report issue for preceding element Sifan Zhou 1 , Zhi Tian 2 , Xiangxiang Chu 2 , Xinyu Zhang 2 , Bo Zhang 2 , Xiaobo Lu11{}^{1}start_FLOATSUPERSCRIPT 1 end_FLOATSUPERSCRIPT11footnotemark: 1 Chengji…

NLP语言模型训练里的特殊向量

1. CLS 向量和 DEC 向量的区别及训练方式 (1) CLS 向量与 DEC 向量是否都是特殊 token&#xff1f; CLS 向量&#xff08;[CLS] token&#xff09;和 DEC 向量&#xff08;Decoder Input token&#xff09;都是特殊的 token&#xff0c;但它们出现在不同类型的 NLP 模型中&am…

字节跳动 UI-TARS 汇总整理报告

1. 摘要 UI-TARS 是字节跳动开发的一种原生图形用户界面&#xff08;GUI&#xff09;代理模型 。它将感知、行动、推理和记忆整合到一个统一的视觉语言模型&#xff08;VLM&#xff09;中 。UI-TARS 旨在跨桌面、移动和 Web 平台实现与 GUI 的无缝交互 。实验结果表明&#xf…

基于Python深度学习的鲨鱼识别分类系统

摘要&#xff1a;鲨鱼是海洋环境健康的指标&#xff0c;但受到过度捕捞和数据缺乏的挑战。传统的观察方法成本高昂且难以收集数据&#xff0c;特别是对于具有较大活动范围的物种。论文讨论了如何利用基于媒体的远程监测方法&#xff0c;结合机器学习和自动化技术&#xff0c;来…