【时间序列篇】基于LSTM的序列分类-Pytorch实现 part1 案例复现

系列文章目录

【时间序列篇】基于LSTM的序列分类-Pytorch实现 part1 案例复现
【时间序列篇】基于LSTM的序列分类-Pytorch实现 part2 自有数据集构建
【时间序列篇】基于LSTM的序列分类-Pytorch实现 part3 化为己用

本篇文章是对已有一篇文章的整理归纳,并对文章中提及的模型用Pytorch实现。

文章目录

  • 系列文章目录
  • 前言
  • 一、任务问题和数据集
    • 1 任务问题
    • 2 数据集
    • 3 数据集读取并展示
  • 二、模型实现
    • 1 数据导入
    • 2 数据预处理
    • 3 数据集划分
    • 4 网络模型及实例化
    • 5 训练过程
  • 三、总结


前言

序列,可以是采样得到的信号样本,也可以是传感器数据。

对于序列分类任务,常用的思路有两种:
1、原理统计相关,分解序列的相关性质研究规律(人工设计特征,再分类)
2、数据挖掘思路,机器学习做特征工程,模型拟合(自动学习特征,再分类)

  • 人工设计特征方法:
    基于序列距离:计算距离进行分类(类别模板or聚类)
    基于统计特征:时序特征提取 (均值,方差,差分)再分类

  • 自动学习特征方法:
    深度学习端到端(RNN, LSTM)

本文通过LSTM来实现对序列信号的分类。


主要思想和代码框架来自参考文献[1]

一、任务问题和数据集

1 任务问题

人体运动估计:
传感器生成高频数据,对不同状态下采集的数据进行分类,可以识别其范围内对象的移动。通过设置多个传感器并对信号进行采样分析,可以识别物体的运动方向。

“ 室内用户运动预测 ”问题:
在该任务中,多个运动传感器被放置在不同房间中,目标基于运动传感器捕获的数据来识别个体是否已经移动穿过房间。

两个房间有四个运动传感器(A1,A2,A3,A4)。
下图说明了传感器在每个房间中的位置。
在这里插入图片描述
一个人可以沿着上图中所示的六个预定义路径中的任何一个移动。每个路径都生成一个 RSS 测量的轨迹样本,从轨迹的开始一直到标记点,在图中表示为 M。标记 M 对于所有运动都是相同的,因此不能仅仅根据在 M 处收集的 RSS 值来区分不同的路径。
该图还显示了所考虑的用户轨迹类型的简化说明,直线路径导致于空间变化,曲线路径导致空间不变。有在房间内移动和在房间之间移动两种类别。

2 数据集

文件含义
RSS_Position_dataset/dataset样本数据
RSS_Position_dataset/groups标签文件和组别文件(划分数据集)
RSS_Position_dataset/MovementAAL.jpg上面的示意图

数据集最重要的有316个csv文件:

  • 【dataset 文件夹】
    314 个MovementAAL csv文件,是序列样本,每个文件都包含与输入 RSS 数据的一个序列数据(每个文件记录一个用户轨迹)。该数据集包含314个序列数据(样本csv文件)。
    1个 MovementAAL_target.csv 文件,是每个MovementAAL文件对应的标签(类别)。每一个样本对应的类别,表明用户的轨迹是否会导致空间变化(例如房间的变化)。特别地,标签为+1与位置变化相关联,而标签为 -1与位置保留相关联。
  • 【groups 文件夹】
    MovementAAL_DatasetGroup.csv文件,用于划分数据集

3 数据集读取并展示

import pandas as pd
# ----------------------------------------------------#
#   路径指定,文件读取
# ----------------------------------------------------#
df1 = pd.read_csv("DATA/RSS_Position_dataset/dataset/MovementAAL_RSS_1.csv")
df2 = pd.read_csv("DATA/RSS_Position_dataset/dataset/MovementAAL_RSS_2.csv")df1.head()  # 返回一个新的DataFrame或Series对象,默认返回前5行。
df1.shape  # 返回文件的size,不同文件的len(行数)不同

二、模型实现

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

1 数据导入

'''
/****************************************************/导入数据集
/****************************************************/
'''
# ----------------------------------------------------#
#   数据集样本
# ----------------------------------------------------#
path = "DATA/RSS_Position_dataset/dataset/MovementAAL_RSS_"
sequences = list()
for i in range(1, 315):  # 315为样本数file_path = path + str(i) + '.csv'df = pd.read_csv(file_path, header=0)values = df.valuessequences.append(values)# ----------------------------------------------------#
#   数据集标签
# ----------------------------------------------------#
targets = pd.read_csv('DATA/RSS_Position_dataset/dataset/MovementAAL_target.csv')
targets = targets.values[:, 1]# ----------------------------------------------------#
#   数据集划分
# ----------------------------------------------------#
groups = pd.read_csv('DATA/RSS_Position_dataset/groups/MovementAAL_DatasetGroup.csv', header=0)
groups = groups.values[:, 1]

分析:

  1. 数据集样本:将所有的样本读入sequences列表中,列表长度为样本数,列表中每一个元素为一个样本。
  2. 数据集标签:targets 中存放。
  3. 数据集划分:数据集是在三对不同的房间中收集的,因此有三组。此信息可用于将数据集划分为训练集,测试集和验证集。

2 数据预处理

由于时间序列数据的长度不同,sequences列表中每个元素长度不一。无法直接在此数据集上构建模型。需要统一。原文中的思想是填充使相等。
这里是对样本,即sequences列表变量进行处理。

# ----------------------------------------------------#
#   Padding the sequence with the values in last row to max length
# ----------------------------------------------------#
# 函数用于填充和截断序列
def pad_truncate_sequences(sequences, max_len, dim=4, truncating='post', padding='post'):# 初始化一个空的numpy数组,用于存储填充后的序列padded_sequences = np.zeros((len(sequences), max_len, dim))for i, one_seq in enumerate(sequences):if len(one_seq) > max_len:  # 截断if truncating == 'pre':padded_sequences[i] = one_seq[-max_len:]else:padded_sequences[i] = one_seq[:max_len]else:  # 填充padding_len = max_len - len(one_seq)to_concat = np.repeat(one_seq[-1], padding_len).reshape(dim, padding_len).transpose()if padding == 'pre':padded_sequences[i] = np.concatenate([to_concat, one_seq])else:padded_sequences[i] = np.concatenate([one_seq, to_concat])return padded_sequences# 使用自定义函数进行填充和截断
seq_len = 32
# truncate or pad the sequence to seq_len
final_seq = pad_truncate_sequences(sequences, max_len=seq_len, dim=4, truncating='post', padding='post')

对数据集来说,标签 +1/-1 不利于模型输出,变为 1/0。
这里是对标签,即targets类别变量进行处理。

# 设置标签从 +1/-1 ,变为 1/0
targets = np.array(targets)
final_targets = (targets+1)/2

3 数据集划分

# ----------------------------------------------------#
#   数据集划分
# ----------------------------------------------------#
# 将numpy数组转换为PyTorch张量
final_seq = torch.tensor(final_seq, dtype=torch.float)# 划分样本为 训练集,验证集 和 测试集
train = [final_seq[i] for i in range(len(groups)) if groups[i] == 1]
validation = [final_seq[i] for i in range(len(groups)) if groups[i] == 2]
test = [final_seq[i] for i in range(len(groups)) if groups[i] == 3]
# 标签同理
train_target = [final_targets[i] for i in range(len(groups)) if groups[i] == 1]
validation_target = [final_targets[i] for i in range(len(groups)) if groups[i] == 2]
test_target = [final_targets[i] for i in range(len(groups)) if groups[i] == 3]# 转换为PyTorch张量
train = torch.stack(train)
train_target = torch.tensor(train_target).long()validation = torch.stack(validation)
validation_target = torch.tensor(validation_target).long()test = torch.stack(test)
test_target = torch.tensor(test_target).long()

4 网络模型及实例化

'''
/****************************************************/网络模型
/****************************************************/
'''
# ----------------------------------------------------#
#   LSTM 模型
# ----------------------------------------------------#
class TimeSeriesClassifier(nn.Module):def __init__(self, n_features, hidden_dim=256, output_size=1):super().__init__()self.lstm = nn.LSTM(input_size=n_features, hidden_size=hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, output_size)  # output_size classesdef forward(self, x):x, _ = self.lstm(x)  # LSTM层x = x[:, -1, :]  # 只取LSTM输出中的最后一个时间步x = self.fc(x)  # 通过一个全连接层return x# ----------------------------------------------------#
#   模型实例化 和 部署
# ----------------------------------------------------#
n_features = 4  # 根据你的特征数量进行调整
output_size = 2
model = TimeSeriesClassifier(n_features=n_features, output_size=output_size)# 打印模型结构
print(model)

5 训练过程

'''
/****************************************************/训练过程
/****************************************************/
'''
# 设置训练参数
epochs = 100  # 训练轮数,根据需要进行调整
batch_size = 4  # 批大小,根据你的硬件调整# DataLoader 加载数据集
# 将数据集转换为张量并创建数据加载器
train_dataset = torch.utils.data.TensorDataset(train, train_target)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)validation_dataset = torch.utils.data.TensorDataset(validation, validation_target)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()# 学习率和优化策略
learning_rate = 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=5e-4)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)  # 设置学习率下降策略# ----------------------------------------------------#
#   训练
# ----------------------------------------------------#
def calculate_accuracy(y_pred, y_true):_, predicted_labels = torch.max(y_pred, 1)correct = (predicted_labels == y_true).float()accuracy = correct.sum() / len(correct)return accuracyfor epoch in range(epochs):model.train()  # 将模型设置为训练模式train_epoch_loss = []train_epoch_accuracy = []for i, data in enumerate(train_loader, 0):inputs, labels = data  # 获取输入数据和标签optimizer.zero_grad()  # 清零梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)loss.backward()  # 反向传播和优化optimizer.step()# 打印统计信息# train_epoch_loss.append(loss.item())# accuracy = calculate_accuracy(outputs, labels)# train_epoch_accuracy.append(accuracy.item())## train_running_loss = np.average(train_epoch_loss)# train_running_accuracy = np.average(train_epoch_accuracy)## if i % 10 == 9:  # 每10个批次打印一次#     print("--------------------------------------------")#     print(f'Epoch {epoch + 1}, Loss: {train_running_loss}, accuracy: {train_running_accuracy}')# Validation accuracymodel.eval()valid_epoch_accuracy = []with torch.no_grad():for inputs, labels in validation_loader:  # Assuming validation_loader is definedoutputs = model(inputs)accuracy = calculate_accuracy(outputs, labels)valid_epoch_accuracy.append(accuracy.item())# 计算平均精度valid_running_accuracy = np.average(valid_epoch_accuracy)print(f'Epoch {epoch + 1}, Validation Accuracy: {valid_running_accuracy:.4f}')print('Finished Training')

三、总结

在验证集上的分类准确率最高才70%。emmm我猜是数据少。

CSDN: 进行时间序列分类实践–python实战

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/650386.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HybridA* 论文解读

本文旨在对原论文进行翻译,对混合A*有一个大概的理解 论文题目:Practical Search Techniques in Path Planning for Autonomous Driving 1 摘要 本文描述了一个实用的路径规划算法,无人驾驶汽车在未知的环境中,障碍物通过机器人…

计算机毕业设计 | SSM 凌云招聘平台(附源码)

1,绪论 人力资源是企业产生效益、创造利润的必不可少的、最重要的资源。人作为人力资源的个体可看作是一个承载着有效知识、能力的信息单元。这样的信息单元可看作是一个为企业产生价值和利润的个体。从而使得这样的信息单元所具有的信息就是一个有价值的信息。 校…

day34WEB 攻防-通用漏洞文件上传黑白盒审计逻辑中间件外部引用

目录 一,白盒审计-Finecms-代码常规-处理逻辑 黑盒思路:寻找上传点抓包修改突破获取状态码及地址 审计流程:功能点-代码文件-代码块-抓包调试-验证测试 二,白盒审计-CuppaCms-中间件-.htaccess 三,白盒审计-Metin…

银行数据仓库体系实践(11)--数据仓库开发管理系统及开发流程

数据仓库管理着整个银行或公司的数据,数据结构复杂,数据量庞大,任何一个数据字段的变化或错误都会引起数据错误,影响数据应用,同时业务的发展也带来系统不断升级,数据需求的不断增加,数据仓库需…

adb测试冷启动和热启动 Permission Denial解决

先清理日志 adb shell logcat -c 打开手机模拟器中的去哪儿网,然后日志找到包名和MainActivity adb shell logcat |grep Main com.Qunar/com.mqunar.atom.alexhome.ui.activity.MainActivity 把手机模拟器的去哪儿的进程给杀掉 执行 命令 adb shell am start -W…

专业133总分400+上海交通大学819考研经验分享上交819电子信息与通信工程

今年专业819信号系统与信号处理133,总分400,如愿考上梦中上海交通大学,通过自己将近一年的复习,实现了人生中目前为止最大的逆袭(自己本科学校很普通),总结自己的复习经历,希望可以给…

苹果Arcade会员的交易开通

arcade是苹果的游戏订阅服务,会员可以畅玩200多个苹果商店精品游戏,包括美区apple id绑卡apple tv购买内购游戏apple one、A2K、狂野飙8,同时ChatGPT也可以,并且这些游戏没有广告没有内购项目,可以在线玩也可以离线玩&…

华为云WAF,开启web网站的专属反爬虫防护罩

背景 从保护原创说起 作为一个原创技术文章分享博主,日常除了Codeing就是总结Codeing中的技术经验。 之前并没有对文章原创性的保护意识,直到在某个非入驻的平台看到了我的文章,才意识到,辛苦码字、为灵感反复试验创作出来的文…

JavaScript模块系统入门教程

🧑‍🎓 个人主页:《爱蹦跶的大A阿》 🔥当前正在更新专栏:《VUE》 、《JavaScript保姆级教程》、《krpano》、《krpano中文文档》 ​ 目录 ✨ 前言 ✨ 正文 一、模块 (Module) 简介 什么是模块 导出与导入 默…

QGIS编译(跨平台编译)之二十四:libbz2编译(Windows、Linux、MacOS环境下编译)

文章目录 1、libbz2介绍2、文件下载3、Linux下编译4、MacOS下编译5、Windows下编译1、libbz2介绍 bzip2是一个基于Burrows-Wheeler 变换的无损压缩软件,压缩效果比传统的LZ77/LZ78压缩算法来得好。它是一款免费软件。可以自由分发免费使用。 bzip2能够进行高质量的数据压缩。…

【代码随想录15】110.平衡二叉树 257. 二叉树的所有路径 404.左叶子之和

目录 110. 平衡二叉树题目描述参考代码 257. 二叉树的所有路径题目描述参考代码 404.左叶子之和题目描述参考代码 110. 平衡二叉树 题目描述 给定一个二叉树,判断它是否是高度平衡的二叉树。 本题中,一棵高度平衡二叉树定义为: 一个二叉树…

AI数字人-数字人视频创作数字人直播效果媲美真人

在科技的不断革新下,数字人技术正日益融入到人们的生活中。近年来,随着AI技术的进一步发展,数字人视频创作领域出现了一种新的创新方式——AI数字人。数字人视频通过AI算法生成虚拟主播,其外貌、动作、语音等方面可与真实人类媲美…

huggingface高速下载模型的实战代码

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

Neo4j 国内镜像下载与安装

Neo4j 5.x 简体中文版指南 社区版:https://neo4j.com/download-center/#community 链接地址(Linux版):https://neo4j.com/artifact.php?nameneo4j-community-3.5.13-unix.tar.gz 链接地址(Windows)&#x…

蓝桥杯省赛无忧 编程13 肖恩的投球游戏

#include <iostream> #include <vector> using namespace std; int main() {int n, q;cin >> n >> q;vector<int> a(n 1);vector<int> diff(n 2, 0); // 初始化差分数组// 读取初始球数&#xff0c;构建差分数组for (int i 1; i < …

Go 从标准输入读取数据

fmt.Scan系列 fmt.Scan函数定义如下&#xff1a; // Scan scans text read from standard input, storing successive space-separated values into successive arguments. // Newlines count as space. // It returns the number of items successfully scanned. // If tha…

DS:单链表的实现(超详细!!)

创作不易&#xff0c;友友们点个三连吧&#xff01; 在博主的上一篇文章中&#xff0c;很详细地介绍了顺序表实现的过程以及如何去书写代码&#xff0c;如果没看过的友友们建议先去看看哦&#xff01; DS&#xff1a;顺序表的实现&#xff08;超详细&#xff01;&#xff01;&…

JAVA大学生兼职平台后台管理

运行环境&#xff1a; tomcat7.0jdk1.7或以上 eclipse或idea 使用技术&#xff1a; springboot 功能描述&#xff1a; 求职人员 注册&#xff0c;登录 选定登录角色&#xff08;1、兼职人员2、发布兼职招聘人员&#xff09; 书写简历&#xff0c;上传学生证照片&#…

力扣每日一题 ---- 1039. 多边形三角剖分的最低得分

这题的难点在哪部分呢&#xff0c;其实是怎么思考。这道题如果之前没做过类似的话&#xff0c;还是很难看出一些性质的&#xff0c;这题原本的话是没有图片把用例显示的这么详细的。这题中有个很隐晦的点没有说出来 剖出来的三角形是否有交叉&#xff0c;这题中如果加一个三角…

网络防御——NET实验

一、实验拓扑 二、实验要求 1、生产区在工作时间&#xff08;9&#xff1a;00---18&#xff1a;00&#xff09;内可以访问服务区&#xff0c;仅可以访问http服务器&#xff1b; 2、办公区全天可以访问服务器区&#xff0c;其中&#xff0c;10.0.2.20可以访问FTP服务器和HTTP服…