Pytorch-Transformer轴承故障一维信号分类(三)

目录

前言

1 数据集制作与加载

1.1 导入数据

第一步,导入十分类数据

第二步,读取MAT文件驱动端数据

第三步,制作数据集

第四步,制作训练集和标签

1.2 数据加载,训练数据、测试数据分组,数据分batch

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型,采用Transformer架构中的编码器:

2.2 定义模型参数

2.3 模型结构

3 Transformer模型训练与评估

3.1 模型训练

3.2 模型评估


往期精彩内容:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

Python轴承故障诊断 (一)短时傅里叶变换STFT

Python轴承故障诊断 (二)连续小波变换CWT

Python轴承故障诊断 (三)经验模态分解EMD

Python轴承故障诊断 (四)基于EMD-CNN的故障分类

Python轴承故障诊断 (五)基于EMD-LSTM的故障分类

Pytorch-LSTM轴承故障一维信号分类(一)

Pytorch-CNN轴承故障一维信号分类(二)

前言

本文基于凯斯西储大学(CWRU)轴承数据,先经过数据预处理进行数据集的制作和加载,最后通过Pytorch实现Transformer模型对故障数据的分类,并介绍Transformer模型的超参数。凯斯西储大学轴承数据的详细介绍可以参考下文:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

1 数据集制作与加载

1.1 导入数据

参考之前的文章,进行故障10分类的预处理,凯斯西储大学轴承数据10分类数据集:

第一步,导入十分类数据

import numpy as np
import pandas as pd
from scipy.io import loadmatfile_names = ['0_0.mat','7_1.mat','7_2.mat','7_3.mat','14_1.mat','14_2.mat','14_3.mat','21_1.mat','21_2.mat','21_3.mat']for file in file_names:# 读取MAT文件data = loadmat(f'matfiles\\{file}')print(list(data.keys()))

第二步,读取MAT文件驱动端数据

# 采用驱动端数据
data_columns = ['X097_DE_time', 'X105_DE_time', 'X118_DE_time', 'X130_DE_time', 'X169_DE_time','X185_DE_time','X197_DE_time','X209_DE_time','X222_DE_time','X234_DE_time']
columns_name = ['de_normal','de_7_inner','de_7_ball','de_7_outer','de_14_inner','de_14_ball','de_14_outer','de_21_inner','de_21_ball','de_21_outer']
data_12k_10c = pd.DataFrame()
for index in range(10):# 读取MAT文件data = loadmat(f'matfiles\\{file_names[index]}')dataList = data[data_columns[index]].reshape(-1)data_12k_10c[columns_name[index]] = dataList[:119808]  # 121048  min: 121265
print(data_12k_10c.shape)
data_12k_10c

第三步,制作数据集

train_set、val_set、test_set 均为按照7:2:1划分训练集、验证集、测试集,最后保存数据

第四步,制作训练集和标签

# 制作数据集和标签
import torch# 这些转换是为了将数据和标签从Pandas数据结构转换为PyTorch可以处理的张量,
# 以便在神经网络中进行训练和预测。def make_data_labels(dataframe):'''参数 dataframe: 数据框返回 x_data: 数据集     torch.tensory_label: 对应标签值  torch.tensor'''# 信号值x_data = dataframe.iloc[:,0:-1]# 标签值y_label = dataframe.iloc[:,-1]x_data = torch.tensor(x_data.values).float()y_label = torch.tensor(y_label.values.astype('int64')) # 指定了这些张量的数据类型为64位整数,通常用于分类任务的类别标签return x_data, y_label# 加载数据
train_set = load('train_set')
val_set = load('val_set')
test_set = load('test_set')# 制作标签
train_xdata, train_ylabel = make_data_labels(train_set)
val_xdata, val_ylabel = make_data_labels(val_set)
test_xdata, test_ylabel = make_data_labels(test_set)
# 保存数据
dump(train_xdata, 'trainX_1024_10c')
dump(val_xdata, 'valX_1024_10c')
dump(test_xdata, 'testX_1024_10c')
dump(train_ylabel, 'trainY_1024_10c')
dump(val_ylabel, 'valY_1024_10c')
dump(test_ylabel, 'testY_1024_10c')

1.2 数据加载,训练数据、测试数据分组,数据分batch

import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练# 加载数据集
def dataloader(batch_size, workers=2):# 训练集train_xdata = load('trainX_1024_10c')train_ylabel = load('trainY_1024_10c')# 验证集val_xdata = load('valX_1024_10c')val_ylabel = load('valY_1024_10c')# 测试集test_xdata = load('testX_1024_10c')test_ylabel = load('testY_1024_10c')# 加载数据train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_xdata, train_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_xdata, val_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_xdata, test_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)return train_loader, val_loader, test_loaderbatch_size = 32
# 加载数据
train_loader, val_loader, test_loader = dataloader(batch_size)

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型,采用Transformer架构中的编码器:

注意:输入数据进行了堆叠 ,把一个1*1024 的序列 进行划分堆叠成形状为 32 * 32, 就使输入序列的长度降下来了

2.2 定义模型参数

# 模型参数
input_dim = 32 # 输入维度
hidden_dim = 512  # 注意力维度
output_dim  = 10  # 输出维度
num_layers = 4   # 编码器层数
num_heads = 8    # 多头注意力头数
batch_size = 32
# 模型
model = TransformerModel(input_dim, output_dim, hidden_dim, num_layers, num_heads, batch_size)  
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)  # 优化器

2.3 模型结构

3 Transformer模型训练与评估

3.1 模型训练

训练结果

100个epoch,准确率将近90%,Transformer模型分类效果良好,参数过拟合了,适当调整模型参数,降低模型复杂度,还可以进一步提高分类准确率。

注意调整参数:

  • 可以适当增加 Transforme编码器层数 和隐藏层的维度,微调学习率;

  • 调整多头注意力的头数,增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

3.2 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F# 加载模型
model =torch.load('best_model_transformer.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():correct_test = 0test_loss = 0for test_data, test_label in test_loader:test_data, test_label = test_data.to(device), test_label.to(device)test_output = model(test_data)probabilities = F.softmax(test_output, dim=1)predicted_labels = torch.argmax(probabilities, dim=1)correct_test += (predicted_labels == test_label).sum().item()loss = loss_function(test_output, test_label)test_loss += loss.item()test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')Test Accuracy: 0.9570  Test Loss: 0.12100271

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/215382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

据房间Id是否存在,判断当前房间是否到期且实时更改颜色

重点代码展示&#xff1a; <template><el-col style"width: 100%;height: 100%;"><el-col :span"20"><el-card class"room_info"><avue-data-icons :option"option"></avue-data-icons></el-…

RT-DETR算法优化改进:轻量化自研设计双卷积重新设计backbone和neck,完成涨点且计算量和参数量显著下降

💡💡💡本文自研创新改进:双卷积由组卷积和异构卷积组成,执行 33 和 11 卷积运算代替其他卷积核仅执行 11 卷积,YOLOv8 Conv,从而轻量化RT-DETR,性能如下表,GFLOPs 8.1降低至7.6,参数量6.3MB降低至5.8MB RT-DETR魔术师专栏介绍: https://blog.csdn.net/m0_637742…

ubuntu-c++-可执行模块-动态链接库-链接库搜索-基础知识

文章目录 1.动态链接库简介2.动态库搜索路径3.运行时链接及搜索顺序4.查看可运行模块的链接库5.总结 1.动态链接库简介 动态库又叫动态链接库&#xff0c;是程序运行的时候加载的库&#xff0c;当动态链接库正确安装后&#xff0c;所有的程序都可以使用动态库来运行程序。动态…

Android帝国之日志系统--logd、logcat

本文概要 这是Android系统进程系列的第四篇文章&#xff0c;本文以自述的方式来介绍logd进程&#xff0c;通过本文您将了解到logd进程存在的意义&#xff0c;以及日志系统的实现原理。&#xff08;文中的代码是基于android13&#xff09; Android系统进程系列的前三篇文章如下…

C#基础与进阶扩展合集-基础篇(持续更新)

目录 本文分两篇&#xff0c;进阶篇点击&#xff1a;C#基础与进阶扩展合集-进阶篇 一、基础入门 Ⅰ 关键字 Ⅱ 特性 Ⅲ 常见异常 Ⅳ 基础扩展 1、哈希表 2、扩展方法 3、自定义集合与索引器 4、迭代器与分部类 5、yield return 6、注册表 7、不安全代码 8、方法…

MATLAB中cell函数的用法

cell用法 在MATLAB中&#xff0c;cell 是一种特殊的数据类型&#xff0c;用于存储不同大小和类型的数据。cell 数组是一种容器&#xff0c;每个元素可以包含任意类型的数据&#xff0c;包括数值、字符串、矩阵、甚至其他的 cell 数组。 以下是 cell 数组的基本语法和示例&…

gitblit自建git仓库

安装 java sudo apt-get update sudo apt-get install openjdk-8-jdk # 或者其它你喜欢的版本 验证&#xff1a; java -version 下载 gitblit https://github.com/gitblit-org/gitblit/releases 解压/usr/local tar -zxvf gitblit-1.9.3.tar.gz 修改配置文件 nano /usr/local/…

【React】useCallback 使用的说明

文章目录 useCallback的优缺点优点缺点JavaScript 的内联优化 使用场景 用了两年多的react&#xff0c;今天抽空写点小内容 useCallback的优缺点 缓存了每次渲染时候 inline callback的实例 优点 关键点&#xff1a;利用memoize减少无效的re-render&#xff0c;通常配合shouldC…

ElasticSearch之cat trained model API

命令样例如下&#xff1a; curl -X GET "https://localhost:9200/_cat/ml/trained_models?vtrue&pretty" --cacert $ES_HOME/config/certs/http_ca.crt -u "elastic:ohCxPHQBEs5*lo7F9"执行结果输出如下&#xff1a; id heap_size …

如何在OpenWRT软路由系统部署uhttpd搭建web服务器实现远程访问——“cpolar内网穿透”

文章目录 前言1. 检查uhttpd安装2. 部署web站点3. 安装cpolar内网穿透4. 配置远程访问地址5. 配置固定远程地址 前言 uhttpd 是 OpenWrt/LuCI 开发者从零开始编写的 Web 服务器&#xff0c;目的是成为优秀稳定的、适合嵌入式设备的轻量级任务的 HTTP 服务器&#xff0c;并且和…

docker-compose的介绍与使用

一、docker-compose 常用命令和指令 1. 概要 默认的模板文件是 docker-compose.yml&#xff0c;其中定义的每个服务可以通过 image 指令指定镜像或 build 指令&#xff08;需要 Dockerfile&#xff09;来自动构建。 注意如果使用 build 指令&#xff0c;在 Dockerfile 中设置…

RHEL网络服务器

目录 1.时间同步的重要性 2.配置时间服务器 &#xff08;1&#xff09;指定所使用的上层时间服务器。 (2&#xff09;指定允许访问的客户端 (3&#xff09;把local stratum 前的注释符#去掉。 3.配置chrony客户端 &#xff08;1&#xff09;修改pool那行,指定要从哪台时间…

Python常见面试知识总结(一):迭代器、拷贝、线程及底层结构

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 今天来总结一下Python和C语言中常见的面试知识&#xff0c;欢迎大家一起前来探讨学习~ 【一】Python中迭代器的概念&#xff1f; 可迭代对象是迭代器、生成器和装饰器的基础。简单来说&#xff0c;可以使用for来循环遍历…

[古剑山2023] pwn

最近这个打stdout的题真多。这个比赛没打。拿到附件作了一天。 choice 32位&#xff0c;libc-2.23-i386&#xff0c;nbytes初始值为0x14,读入0x804A04C 0x14字节后会覆盖到nbytes 1个字节。当再次向v1读入nbytes字节时会造成溢出。 先写0x14p8(0xff)覆盖到nbytes然后溢出写传…

初次参加软考就想报高级,哪个相对容易考?

如果你想第一次参加软考时就报考高级科目&#xff0c;但是却不知道该报考高级中的哪个科目好、 ​ ​那么今天的这篇文章你一定不要错过&#xff01;首先&#xff0c;我们一起来了解一下&#xff0c;软考高级中的5个科目。 ​ ​软考高级科目 ​ 信息系统项目管理师 ​ …

记录一次postgresql临时表丢失问题

项目相关技术栈 springboot hikari连接池pgbouncerpostgresql数据库 背景 为了优化一个任务执行的速度&#xff0c;我将任务的sql中部分语句抽出生成临时表&#xff08;create temp table tempqw as xxxxxxxxx&#xff09;&#xff0c;再和其他表关联&#xff0c;提高查询速…

三翼鸟2023辉煌收官, 定盘2024高质量棋局

最近在不同平台上接连看到这样的热搜话题&#xff1a;用时间胶囊记录2023的自己、2023年度问答、2023十大网络流行语公布… 显然&#xff0c; 2023年进入最后一个月&#xff0c;时间匆匆&#xff0c;这也意味着又到了总结过去和规划未来的时候。拿到结果、取得成绩当然是对202…

算法通关村第十五关 | 白银 | 海量数据场景下的热门算法题

1.从 40 个亿中产生一个不存在的整数 可以采用位图存储数据&#xff0c;申请一个 bit 类型的数组 bitArr &#xff0c;每个位置只表示 0 或者 1 状态&#xff0c;可以将占用内存缩小为使用哈希表的 1/32 。 遍历给定的 40 亿个数&#xff0c;遇到数时就将 bitArr 相应位置设置…

短视频引流获客系统:引领未来营销的新潮流

在这个信息爆炸的时代&#xff0c;短视频已经成为了人们获取信息的主要渠道之一。而随着短视频的火爆&#xff0c;引流获客系统也逐渐成为了营销领域的新宠。本文将详细介绍短视频引流获客系统的开发流程以及涉及到的技术&#xff0c;让我们一起来看看这个引领未来营销的新潮流…

华清远见作业第二十四天

使用消息队列完成两个进程之间相互通信 代码 #include<stdio.h> #include<string.h> #include<stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <sys/ipc.h> #include <sys/msg.h> #in…