Pytorch-Transformer轴承故障一维信号分类(三)

目录

前言

1 数据集制作与加载

1.1 导入数据

第一步,导入十分类数据

第二步,读取MAT文件驱动端数据

第三步,制作数据集

第四步,制作训练集和标签

1.2 数据加载,训练数据、测试数据分组,数据分batch

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型,采用Transformer架构中的编码器:

2.2 定义模型参数

2.3 模型结构

3 Transformer模型训练与评估

3.1 模型训练

3.2 模型评估


往期精彩内容:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

Python轴承故障诊断 (一)短时傅里叶变换STFT

Python轴承故障诊断 (二)连续小波变换CWT

Python轴承故障诊断 (三)经验模态分解EMD

Python轴承故障诊断 (四)基于EMD-CNN的故障分类

Python轴承故障诊断 (五)基于EMD-LSTM的故障分类

Pytorch-LSTM轴承故障一维信号分类(一)

Pytorch-CNN轴承故障一维信号分类(二)

前言

本文基于凯斯西储大学(CWRU)轴承数据,先经过数据预处理进行数据集的制作和加载,最后通过Pytorch实现Transformer模型对故障数据的分类,并介绍Transformer模型的超参数。凯斯西储大学轴承数据的详细介绍可以参考下文:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

1 数据集制作与加载

1.1 导入数据

参考之前的文章,进行故障10分类的预处理,凯斯西储大学轴承数据10分类数据集:

第一步,导入十分类数据

import numpy as np
import pandas as pd
from scipy.io import loadmatfile_names = ['0_0.mat','7_1.mat','7_2.mat','7_3.mat','14_1.mat','14_2.mat','14_3.mat','21_1.mat','21_2.mat','21_3.mat']for file in file_names:# 读取MAT文件data = loadmat(f'matfiles\\{file}')print(list(data.keys()))

第二步,读取MAT文件驱动端数据

# 采用驱动端数据
data_columns = ['X097_DE_time', 'X105_DE_time', 'X118_DE_time', 'X130_DE_time', 'X169_DE_time','X185_DE_time','X197_DE_time','X209_DE_time','X222_DE_time','X234_DE_time']
columns_name = ['de_normal','de_7_inner','de_7_ball','de_7_outer','de_14_inner','de_14_ball','de_14_outer','de_21_inner','de_21_ball','de_21_outer']
data_12k_10c = pd.DataFrame()
for index in range(10):# 读取MAT文件data = loadmat(f'matfiles\\{file_names[index]}')dataList = data[data_columns[index]].reshape(-1)data_12k_10c[columns_name[index]] = dataList[:119808]  # 121048  min: 121265
print(data_12k_10c.shape)
data_12k_10c

第三步,制作数据集

train_set、val_set、test_set 均为按照7:2:1划分训练集、验证集、测试集,最后保存数据

第四步,制作训练集和标签

# 制作数据集和标签
import torch# 这些转换是为了将数据和标签从Pandas数据结构转换为PyTorch可以处理的张量,
# 以便在神经网络中进行训练和预测。def make_data_labels(dataframe):'''参数 dataframe: 数据框返回 x_data: 数据集     torch.tensory_label: 对应标签值  torch.tensor'''# 信号值x_data = dataframe.iloc[:,0:-1]# 标签值y_label = dataframe.iloc[:,-1]x_data = torch.tensor(x_data.values).float()y_label = torch.tensor(y_label.values.astype('int64')) # 指定了这些张量的数据类型为64位整数,通常用于分类任务的类别标签return x_data, y_label# 加载数据
train_set = load('train_set')
val_set = load('val_set')
test_set = load('test_set')# 制作标签
train_xdata, train_ylabel = make_data_labels(train_set)
val_xdata, val_ylabel = make_data_labels(val_set)
test_xdata, test_ylabel = make_data_labels(test_set)
# 保存数据
dump(train_xdata, 'trainX_1024_10c')
dump(val_xdata, 'valX_1024_10c')
dump(test_xdata, 'testX_1024_10c')
dump(train_ylabel, 'trainY_1024_10c')
dump(val_ylabel, 'valY_1024_10c')
dump(test_ylabel, 'testY_1024_10c')

1.2 数据加载,训练数据、测试数据分组,数据分batch

import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练# 加载数据集
def dataloader(batch_size, workers=2):# 训练集train_xdata = load('trainX_1024_10c')train_ylabel = load('trainY_1024_10c')# 验证集val_xdata = load('valX_1024_10c')val_ylabel = load('valY_1024_10c')# 测试集test_xdata = load('testX_1024_10c')test_ylabel = load('testY_1024_10c')# 加载数据train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_xdata, train_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_xdata, val_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_xdata, test_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)return train_loader, val_loader, test_loaderbatch_size = 32
# 加载数据
train_loader, val_loader, test_loader = dataloader(batch_size)

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型,采用Transformer架构中的编码器:

注意:输入数据进行了堆叠 ,把一个1*1024 的序列 进行划分堆叠成形状为 32 * 32, 就使输入序列的长度降下来了

2.2 定义模型参数

# 模型参数
input_dim = 32 # 输入维度
hidden_dim = 512  # 注意力维度
output_dim  = 10  # 输出维度
num_layers = 4   # 编码器层数
num_heads = 8    # 多头注意力头数
batch_size = 32
# 模型
model = TransformerModel(input_dim, output_dim, hidden_dim, num_layers, num_heads, batch_size)  
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)  # 优化器

2.3 模型结构

3 Transformer模型训练与评估

3.1 模型训练

训练结果

100个epoch,准确率将近90%,Transformer模型分类效果良好,参数过拟合了,适当调整模型参数,降低模型复杂度,还可以进一步提高分类准确率。

注意调整参数:

  • 可以适当增加 Transforme编码器层数 和隐藏层的维度,微调学习率;

  • 调整多头注意力的头数,增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

3.2 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F# 加载模型
model =torch.load('best_model_transformer.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():correct_test = 0test_loss = 0for test_data, test_label in test_loader:test_data, test_label = test_data.to(device), test_label.to(device)test_output = model(test_data)probabilities = F.softmax(test_output, dim=1)predicted_labels = torch.argmax(probabilities, dim=1)correct_test += (predicted_labels == test_label).sum().item()loss = loss_function(test_output, test_label)test_loss += loss.item()test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')Test Accuracy: 0.9570  Test Loss: 0.12100271

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/215382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

据房间Id是否存在,判断当前房间是否到期且实时更改颜色

重点代码展示&#xff1a; <template><el-col style"width: 100%;height: 100%;"><el-col :span"20"><el-card class"room_info"><avue-data-icons :option"option"></avue-data-icons></el-…

RT-DETR算法优化改进:轻量化自研设计双卷积重新设计backbone和neck,完成涨点且计算量和参数量显著下降

💡💡💡本文自研创新改进:双卷积由组卷积和异构卷积组成,执行 33 和 11 卷积运算代替其他卷积核仅执行 11 卷积,YOLOv8 Conv,从而轻量化RT-DETR,性能如下表,GFLOPs 8.1降低至7.6,参数量6.3MB降低至5.8MB RT-DETR魔术师专栏介绍: https://blog.csdn.net/m0_637742…

ubuntu-c++-可执行模块-动态链接库-链接库搜索-基础知识

文章目录 1.动态链接库简介2.动态库搜索路径3.运行时链接及搜索顺序4.查看可运行模块的链接库5.总结 1.动态链接库简介 动态库又叫动态链接库&#xff0c;是程序运行的时候加载的库&#xff0c;当动态链接库正确安装后&#xff0c;所有的程序都可以使用动态库来运行程序。动态…

Android帝国之日志系统--logd、logcat

本文概要 这是Android系统进程系列的第四篇文章&#xff0c;本文以自述的方式来介绍logd进程&#xff0c;通过本文您将了解到logd进程存在的意义&#xff0c;以及日志系统的实现原理。&#xff08;文中的代码是基于android13&#xff09; Android系统进程系列的前三篇文章如下…

如何在OpenWRT软路由系统部署uhttpd搭建web服务器实现远程访问——“cpolar内网穿透”

文章目录 前言1. 检查uhttpd安装2. 部署web站点3. 安装cpolar内网穿透4. 配置远程访问地址5. 配置固定远程地址 前言 uhttpd 是 OpenWrt/LuCI 开发者从零开始编写的 Web 服务器&#xff0c;目的是成为优秀稳定的、适合嵌入式设备的轻量级任务的 HTTP 服务器&#xff0c;并且和…

docker-compose的介绍与使用

一、docker-compose 常用命令和指令 1. 概要 默认的模板文件是 docker-compose.yml&#xff0c;其中定义的每个服务可以通过 image 指令指定镜像或 build 指令&#xff08;需要 Dockerfile&#xff09;来自动构建。 注意如果使用 build 指令&#xff0c;在 Dockerfile 中设置…

RHEL网络服务器

目录 1.时间同步的重要性 2.配置时间服务器 &#xff08;1&#xff09;指定所使用的上层时间服务器。 (2&#xff09;指定允许访问的客户端 (3&#xff09;把local stratum 前的注释符#去掉。 3.配置chrony客户端 &#xff08;1&#xff09;修改pool那行,指定要从哪台时间…

Python常见面试知识总结(一):迭代器、拷贝、线程及底层结构

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 今天来总结一下Python和C语言中常见的面试知识&#xff0c;欢迎大家一起前来探讨学习~ 【一】Python中迭代器的概念&#xff1f; 可迭代对象是迭代器、生成器和装饰器的基础。简单来说&#xff0c;可以使用for来循环遍历…

[古剑山2023] pwn

最近这个打stdout的题真多。这个比赛没打。拿到附件作了一天。 choice 32位&#xff0c;libc-2.23-i386&#xff0c;nbytes初始值为0x14,读入0x804A04C 0x14字节后会覆盖到nbytes 1个字节。当再次向v1读入nbytes字节时会造成溢出。 先写0x14p8(0xff)覆盖到nbytes然后溢出写传…

记录一次postgresql临时表丢失问题

项目相关技术栈 springboot hikari连接池pgbouncerpostgresql数据库 背景 为了优化一个任务执行的速度&#xff0c;我将任务的sql中部分语句抽出生成临时表&#xff08;create temp table tempqw as xxxxxxxxx&#xff09;&#xff0c;再和其他表关联&#xff0c;提高查询速…

三翼鸟2023辉煌收官, 定盘2024高质量棋局

最近在不同平台上接连看到这样的热搜话题&#xff1a;用时间胶囊记录2023的自己、2023年度问答、2023十大网络流行语公布… 显然&#xff0c; 2023年进入最后一个月&#xff0c;时间匆匆&#xff0c;这也意味着又到了总结过去和规划未来的时候。拿到结果、取得成绩当然是对202…

短视频引流获客系统:引领未来营销的新潮流

在这个信息爆炸的时代&#xff0c;短视频已经成为了人们获取信息的主要渠道之一。而随着短视频的火爆&#xff0c;引流获客系统也逐渐成为了营销领域的新宠。本文将详细介绍短视频引流获客系统的开发流程以及涉及到的技术&#xff0c;让我们一起来看看这个引领未来营销的新潮流…

华清远见作业第二十四天

使用消息队列完成两个进程之间相互通信 代码 #include<stdio.h> #include<string.h> #include<stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <sys/ipc.h> #include <sys/msg.h> #in…

k8s一键部署uniswap

1、拉取uniswap源码 github地址 2、编写Dockerfile并打镜像 # Set the base image FROM node:18.10.0# WORKDIR /usr/src/app/ WORKDIR /home/gateway# Copy files COPY ./ /home/gateway/# Dockerfile author / maintainer LABEL maintainer"Michael Feng <mikehummi…

Java最全面试题专题---2、Java集合容器(2)

Map接口 说一下 HashMap 的实现原理&#xff1f; HashMap概述&#xff1a; HashMap是基于哈希表的Map接口的非同步实现。此实现提供所有可选的映射操作&#xff0c;并允许使用null值和null键。此类不保证映射的顺序&#xff0c;特别是它不保证该顺序恒久不变。 HashMap的数据…

C语言-枚举

常量符号化 用符号而不是具体的数字来表示程序中的数字 枚举 用枚举而不是定义独立的const int变量 枚举是一种用户定义的数据类型&#xff0c;他用关键词enum以如下语法来声明&#xff1a; enum枚举类型名字{名字0&#xff0c;…&#xff0c;名字n}&#xff1b; 枚举类型名…

外包干了3年,技术退步太明显了。。。。。

先说一下自己的情况&#xff0c;本科生生&#xff0c;18年通过校招进入武汉某软件公司&#xff0c;干了差不多3年的功能测试&#xff0c;今年国庆&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能…

6_CSS布局之浮动的应用

day06_CSS布局之浮动的应用 本课目标&#xff08;Objective&#xff09; 理解什么是浮动掌握浮动的三种机制掌握浮动的案例应用 1 CSS 布局的三种机制 CSS 提供了 3 种机制来设置盒子的摆放位置&#xff0c;分别是普通流&#xff08;标准流&#xff09;、浮动和定位。 普通流…

HarmonyOS开发:回调实现网络的拦截

前言 基于http封装的一个网络库&#xff0c;里面有一个知识点&#xff0c;在初始化的时候&#xff0c;可以设置请求头拦截和请求错误后的信息的拦截&#xff0c;具体案例如下&#xff1a; et.getInstance().init({netErrorInterceptor: new MyNetErrorInterceptor(), //设置全…

web网络安全

web安全 一&#xff0c;xss 跨站脚本攻击(全称Cross Site Scripting,为和CSS&#xff08;层叠样式表&#xff09;区分&#xff0c;简称为XSS)是指恶意攻击者在Web页面中插入恶意javascript代码&#xff08;也可能包含html代码&#xff09;&#xff0c;当用户浏览网页之时&…