机器学习复习(1)——任务整理流程

目录

固定的随机数种子

定义predict功能

拆分数据集

定义trainer

超参数设置

数据集载入

固定的随机数种子

在大量的机器学习与深度学习实验中,如果不进行特殊设置,我们的结果将不可复现,固定的随机数种子将会解决这个问题

def same_seed(seed): '''设置随机种子(便于复现)'''torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsenp.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed_all(seed)print(f'Set Seed = {seed}')

定义predict功能

在自己进行完整的训练框架搭建时,对于结果的预测功能搭建,需要分离被预测对象,否则预测的对象的梯度会回传,扰乱模型backbone

def predict(test_loader, model, device):model.eval() # 设置成eval模式.preds = []for x in tqdm(test_loader):x = x.to(device)                        with torch.no_grad():pred = model(x)         preds.append(pred.detach().cpu())   
#detach()从GPU分离tensor, cpu()将tensor从GPU转到CPUpreds = torch.cat(preds, dim=0).numpy()  
# 将预测结果拼接成一个numpy矩阵return preds

拆分数据集

对于原始数据集(分类)的拆分函数

def train_valid_split(data_set, valid_ratio, seed):'''数据集拆分成训练集(training set)和 验证集(validation set)'''valid_set_size = int(valid_ratio * len(data_set)) train_set_size = len(data_set) - valid_set_sizetrain_set, valid_set = random_split(data_set, [train_set_size, valid_set_size],generator=torch.Generator().manual_seed(seed))return np.array(train_set), np.array(valid_set)

定义trainer

def trainer(train_loader, valid_loader, model, config, device):criterion = nn.MSELoss(reduction='mean') # 损失函数的定义# 定义优化器optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9) # tensorboard 的记录器writer = SummaryWriter()if not os.path.isdir('./models'):# 创建文件夹-用于存储模型os.mkdir('./models')n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0for epoch in range(n_epochs):model.train() # 训练模式loss_record = []# tqdm可以帮助我们显示训练的进度  train_pbar = tqdm(train_loader, position=0, leave=True)# 设置进度条的左边 : 显示第几个Epoch了train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')for x, y in train_pbar:optimizer.zero_grad()               # 将梯度置0.x, y = x.to(device), y.to(device)   # 将数据一到相应的存储位置(CPU/GPU)pred = model(x)             loss = criterion(pred, y)loss.backward()                     # 反向传播 计算梯度.optimizer.step()                    # 更新网络参数step += 1loss_record.append(loss.detach().item())# 训练完一个batch的数据,将loss 显示在进度条的右边train_pbar.set_postfix({'loss': loss.detach().item()})mean_train_loss = sum(loss_record)/len(loss_record)# 每个epoch,在tensorboard 中记录训练的损失(后面可以展示出来)writer.add_scalar('Loss/train', mean_train_loss, step)model.eval() # 将模型设置成 evaluation 模式.loss_record = []for x, y in valid_loader:x, y = x.to(device), y.to(device)with torch.no_grad():pred = model(x)loss = criterion(pred, y)loss_record.append(loss.item())mean_valid_loss = sum(loss_record)/len(loss_record)print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')# 每个epoch,在tensorboard 中记录验证的损失(后面可以展示出来)writer.add_scalar('Loss/valid', mean_valid_loss, step)if mean_valid_loss < best_loss:best_loss = mean_valid_losstorch.save(model.state_dict(), config['save_path']) # 模型保存print('Saving model with loss {:.3f}...'.format(best_loss))early_stop_count = 0else: early_stop_count += 1if early_stop_count >= config['early_stop']:print('\nModel is not improving, so we halt the training session.')return

超参数设置

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {'seed': 5201314,      # 随机种子,可以自己填写. :)'select_all': True,   # 是否选择全部的特征'valid_ratio': 0.2,   # 验证集大小(validation_size) = 训练集大小(train_size) * 验证数据占比(valid_ratio)'n_epochs': 3000,     # 数据遍历训练次数           'batch_size': 256, 'learning_rate': 1e-5,              'early_stop': 400,    # 如果early_stop轮损失没有下降就停止训练.     'save_path': './models/model.ckpt'  # 模型存储的位置
}

数据集载入

# 使用Pytorch中Dataloader类按照Batch将数据集加载
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)

 模型训练

model = My_Model(input_dim=x_train.shape[1]).to(device) 
# 将模型和训练数据放在相同的存储位置(CPU/GPU)
trainer(train_loader, valid_loader, model, config, device)

模型测试

def save_pred(preds, file):''' 将模型保存到指定位置'''with open(file, 'w') as fp:writer = csv.writer(fp)writer.writerow(['id', 'tested_positive'])for i, p in enumerate(preds):writer.writerow([i, p])model = My_Model(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device) 
save_pred(preds, 'pred.csv')   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/655720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

字符串相关函数和文件操作

文章目录 1. C/C 字符串概述1.1 字符串常量1.2 字符数组 2. 字符串函数2.1 拷贝赋值功能相关函数&#xff08;覆盖&#xff09;2.1.1 strcpy2.1.2 strncpy2.1.3 memcpy2.1.4 memmove2.1.5 memset2.1.6 注意小点2.1.7 【函数区别】 2.2 追加功能相关函数2.2.1 strcat2.2.2 strnc…

使用plotly dash 画3d圆柱(Python)

plotly3D &#xff08;3d charts in Python&#xff09;可以画3维图形 在做圆柱的3D装箱项目&#xff0c;需要装箱的可视化&#xff0c;但是Mesh &#xff08;3d mesh plots in Python&#xff09;只能画三角形&#xff0c;所以需要用多个三角形拼成一个圆柱&#xff08;想做立…

Python qt.qpa.xcb: could not connect to display解决办法

遇到问题&#xff1a;qt.qpa.xcb: could not connect to display 解决办法&#xff0c;在命令行输入&#xff1a; export DISPLAY:0 然后重新跑python程序&#xff0c;解决&#xff01; 参考博客&#xff1a;qt.qpa.xcb: could not connect to displayqt.qpa.plugin: Could …

Ubuntu搭建国标平台wvp-GB28181-pro

目录 简介安装和编译1.查看操作系统信息2.安装最新版的nodejs3.安装java环境4.安装mysql5.安装redis6.安装编译器7.安装cmake8.安装依赖库9.编译ZLMediaKit9.1.编译结果说明 10.编译wvp-GB28181-pro10.1.编译结果说明 配置1.WVP-PRO配置文件1.1.Mysql数据库配置1.2.REDIS数据库…

监听项目中指定属性数据,点击或模块显示时

当项目中&#xff0c;需要获取某个页面上、某个标签上、有指定自定义属性时&#xff0c;需要在点击该元素时进行公共逻辑处理&#xff0c;或该元素在显示的时候进行逻辑处理&#xff0c;这时可以定义一个公共的方法&#xff0c;在每个页面引用&#xff0c;并写入数据即可 &…

OpenHarmony RK3568 启动流程优化

目前rk3568的开机时间有21s&#xff0c;统计的是关机后从按下 power 按键到显示锁屏的时间&#xff0c;当对openharmony的系统进行了裁剪子系统&#xff0c;系统app&#xff0c;禁用部分服务后发现开机时间仅仅提高到了20.94s 优化微乎其微。在对init进程的log进行分析并解决其…

【Spring Boot 3】异步线程任务

【Spring Boot 3】异步线程任务 背景介绍开发环境开发步骤及源码工程目录结构总结背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工作经历中,每次学习新技术总是要花费或多或…

面向云服务的GaussDB全密态数据库

前言 全密态数据库&#xff0c;顾名思义与大家所理解的流数据库、图数据库一样&#xff0c;就是专门处理密文数据的数据库系统。数据以加密形态存储在数据库服务器中&#xff0c;数据库支持对密文数据的检索与计算&#xff0c;而与查询任务相关的词法解析、语法解析、执行计划生…

【工具】raw与jpg互转python-cpp

在工作中常常需要将图像转化为raw数据或者yuv数据&#xff0c;这里将提供 cpp 版本和 python 版本的互转代码 代码链接见文档尾部。 cpp 版本 jpg2raw.cpp #include <fstream> #include <iostream> #include <opencv2/core.hpp> #include <opencv2/hig…

oracle版本号中的i,G,C代表什么含义

大家都熟悉的 Oracle 版本号有 9i、10G、11G、12C、19C 等&#xff0c;但在早期&#xff0c;Oracle 的版本号并不包含这些字母。 最初&#xff0c;Oracle 的版本号简单地是 1、2、3、4 等&#xff0c;一直发展到 1999 年发布的 8i 版本。20 世纪末是互联网爆发式发展的时代。 …

将一个excel文件里面具有相同参数的行提取后存入新的excel

功能描述&#xff1a; 一个excel里面有很多行数据&#xff0c;其中“交易时间”这一列有很多交易日期&#xff0c;有些行的交易日期是一样的&#xff0c;那么就把所有交易日期相同的行挑出来&#xff0c;形成一个新的以交易日期命名的文件。import pandas as pd import os# 读取…

跨境ERP定制趋势预测:数字化转型助您赢得市场先机

随着全球贸易的不断融合和发展&#xff0c;跨境业务已成为许多企业拓展市场的重要途径。在这个背景下&#xff0c;ERP定制正逐渐成为企业数字化转型的关键利器。本文将为您预测跨境ERP定制的趋势&#xff0c;并探讨数字化转型如何助您赢得市场先机。 ERP定制趋势预测 1. 数据…

命令行启动Android Studio模拟器

1、sdk路径查看&#xff08;打开Android Studio&#xff09; 以上前提是安装的Android Studio并添加了模拟器&#xff01;&#xff01;&#xff01; 2、复制路径在终端进入到 cd /Users/duxi/Library/Android/sdk目录&#xff08;命令行启动不用打开Android Studio就能运行模拟…

【Java程序设计】【C00182】基于SSM的高校成绩报送管理系统(论文+PPT)

基于SSM的高校成绩报送管理系统&#xff08;论文PPT&#xff09; 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于ssm的高校成绩报送系统 本系统分为前台系统、管理员、教师以及学生4个功能模块。 前台系统&#xff1a;当游客打开系统的网址后&#xff0c;首…

25考研北大软微该怎么做?

25考研想准备北大软微&#xff0c;那肯定要认真准备了 考软微需要多少实力 现在的软微已经不是以前的软微了&#xff0c;基本上所有考计算机的同学都知道&#xff0c;已经没有什么信息优势了&#xff0c;只有实打实的有实力的选手才建议报考。 因为软微的专业课也是11408&am…

PyTorch自动微分机制的详细介绍

PyTorch深度学习框架的官方文档确实提供了丰富的信息来阐述其内部自动微分机制。在PyTorch中&#xff0c;张量&#xff08;Tensor&#xff09;和计算图&#xff08;Computation Graph&#xff09;的设计与实现使得整个系统能够支持动态的、高效的自动求导过程。 具体来说&#…

掌握Java多线程利器:ConcurrentHashMap详解

在并发编程的世界里&#xff0c;每一个微小的延迟都可能积累成为性能瓶颈。今天&#xff0c;让我们一起揭开Java中ConcurrentHashMap的神秘面纱&#xff0c;这是一个在多线程环境中不可或缺的高性能组件。从它的设计理念到底层实现&#xff0c;我们将详细探讨ConcurrentHashMap…

基于团簇阵列中的量子隧穿效应的氢气传感器

在科技日新月异的今天&#xff0c;传感器技术也在不断地发展和创新。其中&#xff0c;基于团簇阵列中的量子隧穿效应的氢气传感器&#xff0c;以其独特的优势和巨大的潜力&#xff0c;成为了气体检测技术领域的一颗新星。 一、什么是基于团簇阵列中的量子隧穿效应的氢气传感器&…

浅谈Java主流锁

浅谈Java主流锁 synchronized关键字 synchronized是Java中最基本的锁机制&#xff0c;可以用来修饰方法或代码块。 修饰方法&#xff1a; public synchronized void method() {// 代码 }修饰代码块&#xff1a; public void method() {synchronized (this) {// 代码} }synch…

年度重磅更新!“AI+可视化拖拽”实现个性化页面极速开发!组件设计器即将上线!

AI智能开发&#xff01;网站一键复刻&#xff01;设计稿秒变成品&#xff01; 相信对很多关注低代码和AI技术的小伙伴来说&#xff0c; 都觉得像这些还只是停留在概念上的技术&#xff0c;很难落地实践。 但是在「织信」已经全部都做到了&#xff01; 无图无真相&#xff0…