【小笔记】用tsai库实现Rocket家族算法

2024.1.16
Rocket家族算法是用于时间序列分类的强baseline(性能比较参考【小笔记】时序数据分类算法最新小结),Rocket/MiniRocket/MultiRocket官方都有开源实现,相比较而言,用tsai来实现有三个好处:1是快速跑通模型;2是更简洁优雅;3是掌握一个框架能举一反三。
在这里插入图片描述

1.tsai简介

项目:https://github.com/timeseriesAI/tsai
在这里插入图片描述

简介:
用于处理时间序列的工具库,包含TCN、Rockert等众多时间序列处理算法
请添加图片描述
安装:

pip install tsai

2.Rocket:最优雅的实现

这个例子是基于UCR的Beef数据集,运行时,会自动下载数据集到项目的data路径下

from tsai.all import *
from sklearn.linear_model import RidgeClassifierCV
from dsets_build import get_my_dsetsdevice = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)# 加载UCR数据集
X, y, splits = get_UCR_data('Beef', return_split=False, on_disk=True, verbose=True)
tfms  = [None, [Categorize()]]
batch_tfms = [TSStandardize(by_sample=True)]
dsets = TSDatasets(X, y, tfms=tfms, splits=splits)# 标准示例
dls = TSDataLoaders.from_dsets(dsets.train, dsets.valid, bs=768, drop_last=False, shuffle_train=False,device=device,batch_tfms=[TSStandardize(by_sample=True)])
model = create_model(ROCKET, dls=dls)
# model = model.to(device)print("构造特征...")
X_train, y_train = create_rocket_features(dls.train, model, verbose=False)
X_valid, y_valid = create_rocket_features(dls.valid, model, verbose=False)
print(X_train.shape, X_valid.shape)print("基于特征开始训练...")
ridge = RidgeClassifierCV(alphas=np.logspace(-8, 8, 17))
ridge.fit(X_train, y_train)
print(f'alpha: {ridge.alpha_:.2E}  train: {ridge.score(X_train, y_train):.5f}  valid: {ridge.score(X_valid, y_valid):.5f}')

3.MiniRocket:(比Rocket更快)

待补充

4.MultiRocket:(比MiniRocket更强)

待补充

5.Hydra-MultiRocket:(Rocket家最强王者)

待补充

6.用自己的数据集训练模型

上面的例子都是用的UCR数据集,若要用自己的数据集进行训练怎么解决?
官方教程:
tsai-main\tutorial_nbs路径下的00c_Time_Series_data_preparation.ipynb
在这里插入图片描述
我总结了一下,基于单变量时间序列构建数据集就是下面这样:
dsets_build.py

from tsai.all import *
import numpy as np
import pandas as pddef get_my_dsets():# 导入数据集train_data, valid_data, test_data = [[], []], [[], []], [[], []]radio_train, radio_valid, radio_test = 0.6, 0.2, 0.2# !这是我的读取读取例子,读者需要进行替换----------------------------------path = "train.csc"data = pd.read_csv(path)	train_data[0] = data['x'].tolist()train_data[1] = data['y'].tolist()# -----------------------------------------------------------------------# 将数据转换为np.array即可,剩下的都是通用了X_2d, y = np.array(train_data[0]), np.array(train_data[1])print(X_2d.shape, y.shape)    # (4000, 4096) (4000,)splits = get_splits(y, valid_size=0.2, stratify=True, random_state=23, shuffle=True, show_plot=False)print(splits)tfms = [None, [Categorize()]]dsets = TSDatasets(X_2d, y, tfms=tfms, splits=splits, inplace=True)print(dsets)return dsets

将数据集转换为tsai的dsets后,就可以直接用于训练模型了。

from tsai.all import *
from sklearn.linear_model import RidgeClassifierCV
from dsets_build import get_my_dsetsdevice = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)# 加载UCR数据集
# X, y, splits = get_UCR_data('Beef', return_split=False, on_disk=True, verbose=True)
# tfms  = [None, [Categorize()]]
# batch_tfms = [TSStandardize(by_sample=True)]
# dsets = TSDatasets(X, y, tfms=tfms, splits=splits)# 加载自定义的数据集
dsets = get_my_dsets()         # 和Rockert例子只有这里的区别# 标准示例
dls = TSDataLoaders.from_dsets(dsets.train, dsets.valid, bs=768, drop_last=False, shuffle_train=False,device=device,batch_tfms=[TSStandardize(by_sample=True)])
model = create_model(ROCKET, dls=dls)
# model = model.to(device)print("构造rocket特征...")
X_train, y_train = create_rocket_features(dls.train, model, verbose=False)
X_valid, y_valid = create_rocket_features(dls.valid, model, verbose=False)
print(X_train.shape, X_valid.shape)print("基于特征开始训练...")
ridge = RidgeClassifierCV(alphas=np.logspace(-8, 8, 17))
ridge.fit(X_train, y_train)
print(f'alpha: {ridge.alpha_:.2E}  train: {ridge.score(X_train, y_train):.5f}  valid: {ridge.score(X_valid, y_valid):.5f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/628940.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WPF应用程序生存期以及相关事件

WPF 应用程序的生存期会通过 Application 引发的几个事件来加以标记,相关事件对应着应用程序何时启动、激活、停用和关闭。 应用程序生存期事件 • 独立应用程序(传统风格的 Windows 应用程序,这些应用程序作为要安装到客户端计算机并从客户端计算机运…

VitePress-01-从零开始的项目创建(npm版)

说明 本文介绍一下 VitePress的项目创建的步骤。 主要用到的命令工具是 npm。 本文的操作步骤是从无到有的创建一个完整的基本的【VitePress】项目。 环境准备 根据官方文档的介绍,截止本文发稿时,需要使用node.js 18 的版本。 可以使用node -v 的命令查…

关于java的封装

关于java的封装 我们在前面的文章中,了解到了类和对象的知识,以及做了创建对象的时候对内存的分析,我们本篇文章来了解一下面向对象的三大基本特征之一,封装😀。 一、初识封装 封装就好比,我们把一些物品…

【操作系统】1. 操作系统概述

文章目录 【 1. 什么是操作系统 】【 2. 操作系统软件的分类 】【 3. 操作系统内核的抽象和特征 】3.1 操作系统内核的抽象3.2 操作系统内核的特征 【 1. 什么是操作系统 】 操作系统是管理硬件资源、控制程序运行、改善人机界面和为应用软件提供服务的一种系统 软件。一个服务…

<软考高项备考>《论文专题 - 71 风险管理(3)》

3 过程2-识别风险 3.1 问题 4W1H过程做什么是识别单个项目风险以及整体项目风险的来源,并记录风险特征的过程。作用:1、记录现有的单个项目风险,以及整体项目风险的来源:2、汇总相关信息,以便项目团队能够恰当地应对已识别的风险。为什么做…

怎么修改或移除WordPress后台仪表盘概览底部的版权信息和主题信息?

前面跟大家分享『WordPress怎么把后台左上角的logo和评论图标移除?』和『WordPress后台底部版权信息“感谢使用 WordPress 进行创作”和版本号怎么修改或删除?』,其实在WordPress后台仪表盘的“概览”底部还有一个WordPress版权信息和所使用的…

项目解决方案:“ZL铁路轨行车辆”实时视频监控系统

目 录 一、建设背景 1.1 政策背景 1.2 现状 二、建设目标 三、建设依据 四、建设原则 4.1经济高效性 4.2系统开放性 4.3系统继承性 4.4系统扩展性 4.5系统经济性 4.6系统安全性 五、系统架构 5.1系统架构图 5.2技术架构 1、DVS 2、中心管理服务…

【Java SE语法篇】11.异常

📚博客主页:爱敲代码的小杨. ✨专栏:《Java SE语法》 ❤️感谢大家点赞👍🏻收藏⭐评论✍🏻,您的三连就是我持续更新的动力❤️ 文章目录 1. 异常的概念和体系结构1.1 异常的概念1.2 异常体系…

PTA——7-31 三角形判断

7-31 三角形判断 (15分) 给定平面上任意三个点的坐标(x​1​​,y​1​​)、(x​2​​,y​2​​)、(x​3​​,y​3​​),检验它们能否构成三角形。 输入格式: 输入在一行中顺序给出六个[−100,100]范围内的数字,即三个点的坐标x​1​​、y​1​​、x​2​…

SUKER书客重磅发布全新系列:书客Sun立式护眼台灯,护眼养眼新境界

近日,国内知名的光学国货品牌——SUKER书客在2024年新品发布上,正式发布了全新系列的书客Sun立式护眼台灯。 SUKER书客作为近年来快速发展的创新型光学技术品牌,曾推出的一系列产品都取得了刷新行业标准的成绩,他们坚持以创新为动…

【51单片机系列】单片机与PC进行串行通信

一、单片机与PC机串行通信的设计 工业现场的测控系统中,常使用单片机进行监测点的数据采集,然后单片机通过串口与PC通信,把采集的数据串行传送到PC机上,再在PC机上进行数据处理。 PC机配置的都是RS-232标准串口,为D型…

YOLOv5改进 | 2023主干篇 | 多种轻量化卷积优化PP-HGNetV2改进主干(全网独家创新)

一、本文介绍 Hello,大家好,上一篇博客我们讲了利用HGNetV2去替换YOLOv5的主干,经过结构的研究我们可以发现在HGNetV2的网络中有大量的卷积存在,所以我们可以用一种更加轻量化的卷积去优化HGNetV2从而达到更加轻量化的效果(亲测优化后的HGNetV2网络比正常HGNetV2精度更高…

开发知识点-java基础

java基础知识整理 windows 多版本java jar包不能直接打开 需要java -jar问题解决 windows 多版本 控制面板 java15 download 多版本 https://www.cnblogs.com/chenmingjun/p/9941191.html https://gitee.com/shixinke/JC-jEnv/repository/archive/master.zip java jar包不…

React16源码: React中的renderRoot的源码实现

renderRoot 1 )概述 renderRoot 是一个非常复杂的方法这个方法里处理很多各种各样的逻辑, 它主要的工作内容是什么?A. 它调用 workLoop 进行循环单元更新 遍历整个 Fiber Tree,把每一个组件或者 dom 节点对应的Fiber 节点拿出来单一的进行更…

万户 ezOFFICE ezflow_gd.jsp SQL注入漏洞复现

0x01 产品简介 万户OA ezoffice是万户网络协同办公产品多年来一直将主要精力致力于中高端市场的一款OA协同办公软件产品,统一的基础管理平台,实现用户数据统一管理、权限统一分配、身份统一认证。统一规划门户网站群和协同办公平台,将外网信息维护、客户服务、互动交流和日…

DC电源模块与AC电源模块的对比分析

DC电源模块与AC电源模块的对比分析 BOSHIDA DC电源模块和AC电源模块是两种常见的电源模块,它们在供电方式、稳定性、适用范围等方面有所不同,下面是它们的对比分析: 1. 供电方式: DC电源模块通过直流电源供电,通常使用…

【Linux】Linux 系统编程——which 命令

文章目录 1.命令概述2.命令格式3.常用选项4.相关描述5.参考示例 1.命令概述 which 命令用于定位执行文件的路径。当输入一个命令时,which 会在环境变量 PATH 所指定的路径中搜索每个目录,以查找指定的可执行文件。 2.命令格式 which [选项] 命令名3.常…

生产力与生产关系 —— 浅析爱泼斯坦事件 之 弱电控制强电原理

据网络文字与视频资料,爱泼斯坦事件是犹太精英阶层,为了掌控美国国家机器为犹太利益集团服务,而精心设下的一个局。本文先假设这个结论成立,并基于此展开讨论。 我们知道,弱电管理强电是电气工程中的一门专门学问&…

Mysql 数据库DDL 数据定义语言——数据库,数据表的创建

DDL:数据定义语言,用来定义数据库对象(数据库,表,字段)—Database Definition Language 1、登录数据库,输入用户名和密码 mysql -ufdd -p990107Wjl2、查看数据库 show databases;3、创建一个…

MySQL面试题 | 12.精选MySQL面试题

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…