Pytorch-CNN轴承故障一维信号分类(二)

目录

前言

1 数据集制作与加载

1.1 导入数据

1.2 数据加载,训练数据、测试数据分组,数据分batch

2 CNN-2D分类模型和训练、评估

2.1 定义CNN-2d分类模型

2.2 定义模型参数

2.3 模型结构

2.4 模型训练

2.5 模型评估

3 CNN-1D分类模型和训练、评估

3.1 定义CNN-1d分类模型

3.2 定义模型参数

3.3 模型结构

3.4 模型训练

3.5 模型评估

4 模型对比


往期精彩内容:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

Python轴承故障诊断 (一)短时傅里叶变换STFT

Python轴承故障诊断 (二)连续小波变换CWT-CSDN博客

Python轴承故障诊断 (三)经验模态分解EMD-CSDN博客

Python轴承故障诊断 (四)基于EMD-CNN的故障分类-CSDN博客

Python轴承故障诊断 (五)基于EMD-LSTM的故障分类-CSDN博客

Pytorch-LSTM轴承故障一维信号分类(一)-CSDN博客

前言

本文基于凯斯西储大学(CWRU)轴承数据,先经过数据预处理进行数据集的制作和加载,最后通过Pytorch实现CNN模型一维卷积和二维卷积对故障数据的分类,然后进行对比。凯斯西储大学轴承数据的详细介绍可以参考下文:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

1 数据集制作与加载

1.1 导入数据

参考之前的文章,进行故障10分类的预处理,凯斯西储大学轴承数据10分类数据集:

第一步,导入十分类数据

import numpy as np
import pandas as pd
from scipy.io import loadmatfile_names = ['0_0.mat','7_1.mat','7_2.mat','7_3.mat','14_1.mat','14_2.mat','14_3.mat','21_1.mat','21_2.mat','21_3.mat']for file in file_names:# 读取MAT文件data = loadmat(f'matfiles\\{file}')print(list(data.keys()))

第二步,读取MAT文件驱动端数据

# 采用驱动端数据
data_columns = ['X097_DE_time', 'X105_DE_time', 'X118_DE_time', 'X130_DE_time', 'X169_DE_time','X185_DE_time','X197_DE_time','X209_DE_time','X222_DE_time','X234_DE_time']
columns_name = ['de_normal','de_7_inner','de_7_ball','de_7_outer','de_14_inner','de_14_ball','de_14_outer','de_21_inner','de_21_ball','de_21_outer']
data_12k_10c = pd.DataFrame()
for index in range(10):# 读取MAT文件data = loadmat(f'matfiles\\{file_names[index]}')dataList = data[data_columns[index]].reshape(-1)data_12k_10c[columns_name[index]] = dataList[:119808]  # 121048  min: 121265
print(data_12k_10c.shape)
data_12k_10c

第三步,制作数据集

train_set、val_set、test_set 均为按照7:2:1划分训练集、验证集、测试集,最后保存数据

第四步,制作训练集和标签

# 制作数据集和标签
import torch# 这些转换是为了将数据和标签从Pandas数据结构转换为PyTorch可以处理的张量,
# 以便在神经网络中进行训练和预测。def make_data_labels(dataframe):'''参数 dataframe: 数据框返回 x_data: 数据集     torch.tensory_label: 对应标签值  torch.tensor'''# 信号值x_data = dataframe.iloc[:,0:-1]# 标签值y_label = dataframe.iloc[:,-1]x_data = torch.tensor(x_data.values).float()y_label = torch.tensor(y_label.values.astype('int64')) # 指定了这些张量的数据类型为64位整数,通常用于分类任务的类别标签return x_data, y_label# 加载数据
train_set = load('train_set')
val_set = load('val_set')
test_set = load('test_set')# 制作标签
train_xdata, train_ylabel = make_data_labels(train_set)
val_xdata, val_ylabel = make_data_labels(val_set)
test_xdata, test_ylabel = make_data_labels(test_set)
# 保存数据
dump(train_xdata, 'trainX_1024_10c')
dump(val_xdata, 'valX_1024_10c')
dump(test_xdata, 'testX_1024_10c')
dump(train_ylabel, 'trainY_1024_10c')
dump(val_ylabel, 'valY_1024_10c')
dump(test_ylabel, 'testY_1024_10c')

1.2 数据加载,训练数据、测试数据分组,数据分batch

import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练# 加载数据集
def dataloader(batch_size, workers=2):# 训练集train_xdata = load('trainX_1024_10c')train_ylabel = load('trainY_1024_10c')# 验证集val_xdata = load('valX_1024_10c')val_ylabel = load('valY_1024_10c')# 测试集test_xdata = load('testX_1024_10c')test_ylabel = load('testY_1024_10c')# 加载数据train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_xdata, train_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_xdata, val_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_xdata, test_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)return train_loader, val_loader, test_loaderbatch_size = 32
# 加载数据
train_loader, val_loader, test_loader = dataloader(batch_size)

2 CNN-2D分类模型和训练、评估

2.1 定义CNN-2d分类模型

注意:输入数据进行了堆叠 ,把一个1*1024 的序列 进行划分堆叠成形状为1 * 32 * 32, 就使输入序列的长度降下来了,(channels, seq_length, H_in)

2.2 定义模型参数

# 定义模型参数
batch_size = 32
# 先用浅层试一试
conv_arch = ((2, 32), (1, 64), (1, 128))  
input_channels = 1
num_classes = 10
model = CNN2DModel(conv_arch, num_classes, batch_size)  
# 定义损失函数和优化函数
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

2.3 模型结构

2.4 模型训练

训练结果

50个epoch,准确率将近97%,CNN-2D网络分类模型效果良好。

2.5 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练# 加载模型
model =torch.load('best_model_cnn2d.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():correct_test = 0test_loss = 0for test_data, test_label in test_loader:test_data, test_label = test_data.to(device), test_label.to(device)test_output = model(test_data)probabilities = F.softmax(test_output, dim=1)predicted_labels = torch.argmax(probabilities, dim=1)correct_test += (predicted_labels == test_label).sum().item()loss = loss_function(test_output, test_label)test_loss += loss.item()test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')Test Accuracy: 0.9313  Test Loss: 0.04866932

3 CNN-1D分类模型和训练、评估

3.1 定义CNN-1d分类模型

注意:与2d模型的信号长度堆叠不同,CNN-1D模型直接在一维序列上进行卷积池化操作;形状为(batch,H_in, seq_length),利用平均池化 使CNN-1D和CNN-2D模型最后输出维度相同,保持着相近的参数量。

3.2 定义模型参数

# 定义模型参数
batch_size = 32
# 先用浅层试一试
conv_arch = ((2, 32), (1, 64), (1, 128))  
input_channels = 1
num_classes = 10
model = CNN1DModel(conv_arch, num_classes, batch_size)  
# 定义损失函数和优化函数
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

3.3 模型结构

3.4 模型训练

训练结果

50个epoch,准确率将近95%,CNN-1D网络分类模型效果良好。

3.5 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练# 加载模型
model =torch.load('best_model_cnn1d.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():correct_test = 0test_loss = 0for test_data, test_label in test_loader:test_data, test_label = test_data.to(device), test_label.to(device)test_output = model(test_data)probabilities = F.softmax(test_output, dim=1)predicted_labels = torch.argmax(probabilities, dim=1)correct_test += (predicted_labels == test_label).sum().item()loss = loss_function(test_output, test_label)test_loss += loss.item()test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')Test Accuracy: 0.9185  Test Loss: 0.14493044

4 模型对比

对比CNN-2D模型 和CNN-1D模型:

模型参数量训练集准确率验证集准确率测试集准确率
CNN1D61565496.5694.6491.85
CNN2D68343098.3896.8893.13

由于CNN-2D模型参数量稍微多一点,所以模型表现得也略好一点,适当调整参数,两者模型准确率相近。但是CNN-2D推理速度要快于CNN-1D,在轴承故障数据集上,应该更考虑CNN-2D模型在堆叠后的一维信号上进行卷积池化。

注意调整参数:

  • 可以适当增加 CNN层数 和每层神经元个数,微调学习率;

  • 增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/220575.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

flink找不到隐式项

增加 import org.apache.flink.streaming.api.scala._ 即可

改进的A*算法的路径规划(2)

子节点优化选择策略 (1)子节点选择方式 为了找到从起始点到终点的路径,需定义一种可以选择后续节点的方式。在 A*算法中两种常见的方法为4-邻接(见图5-7(a) 和8-邻接(见图5-7(b)), 但考虑到 在复杂越野环境上,我们希望智能车辆允许更多的自由运动来更…

MSF学习

之前的渗透测试中 其实很少用到 cs msf 但是在实际内网的时候 可以发现 msf cs 都是很好用的 所以现在我来学习一下 msf的使用方法 kali自带msf https://www.cnblogs.com/bmjoker/p/10051014.html 使用 msfconsole 启动即可 首先就是最正常的木马生成 所以这里其实只需…

分类预测 | Matlab实现DBO-SVM蜣螂算法优化支持向量机的数据分类预测【23年新算法】

分类预测 | Matlab实现DBO-SVM蜣螂算法优化支持向量机的数据分类预测【23年新算法】 目录 分类预测 | Matlab实现DBO-SVM蜣螂算法优化支持向量机的数据分类预测【23年新算法】分类效果基本描述程序设计参考资料 分类效果 基本描述 1.Matlab实现DBO-SVM蜣螂算法优化支持向量机的…

数理统计基础:参数估计与假设检验

在学习机器学习的过程中,我充分感受到概率与统计知识的重要性,熟悉相关概念思想对理解各种人工智能算法非常有意义,从而做到知其所以然。因此打算写这篇笔记,先好好梳理一下参数估计与假设检验的相关内容。 1 总体梳理 先从整体结…

串口通信(4)-C#串口通信入门实例

本文通过实例讲解C#串口通信。 入门实例设计一个串口助手,能够很好的涵盖串口要点的使用。 目录 一、成品图 二、界面文件 三、后台代码 四、实例中要点 一、成品图 如下: 实现的过程 创建winform项目,将Form1文件的名称改为MainForm&…

Windows汇编调用printf

VS2022 汇编 项目右键 生成依赖项 生成自定义 勾选masm 链接器 高级 入口点 main X86 .686 .model flat,stdcall option casemap:none includelib ucrt.lib includelib legacy_stdio_definitions.libEXTERN printf:proc.data szFormat db %s,0 szStr db hello,0.code main…

关于职场伪勤奋

前段时间看了一些关于勤奋学习、职场成长类的书籍,就在思考勤奋学习和职场的关系时,结合个人的理解,我定义了一种勤奋叫职场“伪勤奋”。那关于职场“伪勤奋”的定义和理解,与大家分享: 1、选择性任务完成 伪勤奋特征…

vue 图片等比例缩放上传

需求:上传图片之前按比例缩小图片分辨率,宽高不超过1920不处理图片,宽高超过1920则缩小图片分辨率,如果是一张图片请参考这篇博客:js实现图片压缩、分辨率等比例缩放 我根据这篇博主的分享,写下了我的循环上…

HarmonyOS使用Web组件

Web组件的使用 1 概述 相信大家都遇到过这样的场景,有时候我们点击应用的页面,会跳转到一个类似浏览器加载的页面,加载完成后,才显示这个页面的具体内容,这个加载和显示网页的过程通常都是浏览器的任务。 ArkUI为我…

chatGPT 国内版,嵌入midjourney AI创作工具

聊天GPT国内入口,免切网直达,可直接多语言对话,操作简单,无需复杂注册,智能高效,即刻使用.可以用作个人助理,学习助理,智能创作、新媒体文案创作、智能创作等各种应用场景! 地址: https://ai.wboat.cn/

【51单片机系列】直流电机使用

本文是关于直流电机使用的相关介绍。 文章目录 一、直流电机介绍二、ULN2003芯片介绍三、在proteus中仿真实现对电机的驱动 51单片机的应用中,电机控制方面的应用也很多。在学习直流电机(PWM)之前,先使用GPIO控制电机的正反转和停止。但不能直接使用GPIO…

06 python 文件基础操作

6.1 .1文件读取操作 演示对文件的读取 # 打开文件 import timef open(02_word.txt, r, encoding"UTF-8") print(type(f))# #读取文件 - read() # print(f读取10个字节的结果{f.read(10)}) # print(f读取全部字节的结果{f.read()})# #读取文件 - readLines() # lines…

面试官:说说你对 linux 用户管理的理解?相关的命令有哪些?

面试官:说说你对 linux 用户管理的理解?相关的命令有哪些? 一、是什么 Linux是一个多用户的系统,允许使用者在系统上通过规划不同类型、不同层级的用户,并公平地分配系统资源与工作环境 而与 Windows 系统最大的不同…

基于MyBatis二级缓存深入装饰器模式

视频地址 学习文档 文章目录 一、示意代码二、装饰器三、经典案例—MyBatis二级缓存1、Cache 标准定义2、PerpetualCache 基础实现3、增强实现3-1、ScheduledCache3-2、LruCache 先来说说我对装饰器理解:当你有一个基础功能的代码,但你想在不改变原来代…

高效营销系统集成:百度营销的API无代码解决方案,提升电商与广告效率

百度营销API连接:构建无代码开发的高效集成体系 在数字营销的高速发展时代,企业追求的是快速响应市场的能力以及提高用户运营的效率。百度营销API连接正是为此而生,它通过无代码开发的方式,实现了电商平台、营销系统和CRM的一站式…

墒情监测FDS-400 土壤温湿电导率盐分传感器

墒情监测FDS-400 土壤温湿电导率盐分传感器产品概述 土壤温度部分是由精密铂电阻和高精度变送器两部分组成。变送器部分由电源模块、温度传感模块、变送模块、温度补偿模块及数据处理模块等组成,解决铂电阻因自身特点导入的测量误差,变送器内有零漂电路…

Redis队列原理解析:让你的应用程序运行更加稳定!

一、消息队列简介 消息队列(Message Queue),字面意思就是存放消息的队列。最简单的消息队列模型包括 3 个角色: 消息队列:存储和管理消息,也被称为消息代理(Message Broker)生产者…

Turtle绘制菱形-第11届蓝桥杯选拔赛Python真题精选

[导读]:超平老师的Scratch蓝桥杯真题解读系列在推出之后,受到了广大老师和家长的好评,非常感谢各位的认可和厚爱。作为回馈,超平老师计划推出《Python蓝桥杯真题解析100讲》,这是解读系列的第16讲。 Turtle绘制菱形&a…

六.聚合函数

聚合函数 1.什么是聚合函数1.1AVG和SUM函数1.2MIN和MAX函数1.3COUNT函数 2.GROUP BY2.1基本使用2.2使用多个列分组2.3GROUP BY中使用WITH ROLLUP 3.HAVING3.1基本使用3.2WHERE和HAVING的区别 4.SELECT的执行过程4.1查询的结构4.2SELECT执行顺序4.3SQL执行原理 1.什么是聚合函数…