深度学习笔记_7经典网络模型LSTM解决FashionMNIST分类问题

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)# 设置超参数
sequence_length = 28
input_size = 28  
hidden_size = 128
num_layers = 2 
num_classes = 10
batch_size = 64
learning_rate = 0.001
num_epochs = 50# 定义数据转换操作
transform = transforms.Compose([transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转transforms.RandomHorizontalFlip(),   # 随机水平翻转transforms.RandomCrop(size=28, padding=4),   # 随机裁剪transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,download=True, transform=transform)# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,download=True, transform=transform)# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

3、定义LSTM模型

# 定义LSTM模型
class LSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(LSTM, self).__init__()self.hidden_size = hidden_size  # LSTM隐含层神经元数self.num_layers = num_layers  # LSTM层数self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)  # LSTM层self.fc = nn.Linear(hidden_size, num_classes)  # 全连接层def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)  # 初始化状态c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)out, _ = self.lstm(x, (h0, c0))  # LSTM前向传播out = self.fc(out[:, -1, :])  # 只取序列最后一个时间步的输出return F.log_softmax(out, dim=1)  # 使用log_softmax作为输出# 初始化模型、优化器和损失函数
model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []

4、 训练循环

for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上data = data.view(-1, sequence_length, input_size)output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上data = data.view(-1, sequence_length, input_size)output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)# 计算额外的指标conf_matrix = confusion_matrix(all_labels, all_preds)conf_matrix_list.append(conf_matrix)accuracy = accuracy_score(all_labels, all_preds)accuracy_list.append(accuracy)error_rate = 1 - accuracyerror_rate_list.append(error_rate)precision = precision_score(all_labels, all_preds, average='weighted')recall = recall_score(all_labels, all_preds, average='weighted')f1 = f1_score(all_labels, all_preds, average='weighted')precision_list.append(precision)recall_list.append(recall)f1_score_list.append(f1)fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)roc_auc = auc(fpr, tpr)roc_auc_list.append(roc_auc)# 打印每个 epoch 的指标print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# 打印或绘制训练后的最终指标
print(f'Final Confusion Matrix:\n{conf_matrix_list[-1]}')
print(f'Final Accuracy: {accuracy_list[-1]:.2%}')
print(f'Final Error Rate: {error_rate_list[-1]:.2%}')
print(f'Final Precision: {precision_list[-1]:.2%}')
print(f'Final Recall: {recall_list[-1]:.2%}')
print(f'Final F1 Score: {f1_score_list[-1]:.2%}')
print(f'Final ROC AUC: {roc_auc_list[-1]:.2%}')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()# 计算混淆矩阵
class_labels = [str(i) for i in range(10)]
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/232161.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

是什么导致了我孩子的听力损失?

是什么导致了我孩子的听力损失? 有些婴儿天生就有听力损失,这被称为先天性听力损失。许多不同的因素都可能导致这种类型的听力损失,但并不总是能够确定确切的原因。在大约一半的病例中,原因是遗传的,也就是说&#xff…

深度学习 tensorflow基础介绍

深度学习是一种基于人工神经网络的机器学习方法,其目标是通过模仿人脑的结构和功能,实现对大量复杂数据的学习和理解。它可以在图像识别、语音识别、自然语言处理等领域取得惊人的成就。 深度学习的引入引出了TensorFlow,它是一个由Google Br…

DBeaver Ultimate for Mac/win:掌握数据库的终极利器,助您高效管理数据!

在当今数字化时代,数据管理变得越来越重要。而作为一款功能强大的数据库管理工具,DBeaver Ultimate(简称DBU)助您轻松应对各种复杂的数据管理任务。无论您是数据库管理员、开发人员还是数据分析师,DBU都能为您提供全面…

带你学C语言~指针(2)

目录 🏉前言 🚀 数组名的理解 🚀使用指针访问数组 ✈一维数组传参的本质 ✈冒泡排序 🏆二级指针 🏆指针数组 🏆指针数组模拟二维数组 🎉结束语 🏉前言 上一章,小…

关于“Python”的核心知识点整理大全28

目录 11.1.5 添加新测试 11.2 测试类 11.2.1 各种断言方法 unittestModule中的断言方法: ​编辑11.2.2 一个要测试的类 survey.py language_survey.py 11.2.3 测试 AnonymousSurvey 类 test_survey.py 往期快速传送门👆(在文章最后&…

计算机操作系统-第十九天

目录 调度器/调度程序 闲逛进程 调度器/调度程序 ②、③由调度程序引起,调度程序决定了:让谁运行(调度算法)运行多长时间(时间片大小) 调度时机(什么事件会触发”调度程序“)&…

proxysql读写分离组件部署

一、前言 在mysql一主两从架构的前提下,引入读写分离组件,可以极大的提高mysql性能,proxysql可以在高可用mysql架构发生主从故障时,进行自动的主从读写节点切换,即当mysql其他从节点当选新的主节点时,proxy…

HuatuoGPT

文章目录 HuatuoGPT 模型介绍LLM4Med(医疗大模型)的作用ChatGPT 存在的问题HuatuoGPT的特点ChatGPT 与真实医生的区别解决方案用于SFT阶段的混合数据基于AI反馈的RL 评估单轮问答多轮问答人工评估 HuatuoGPT 模型介绍 HuatuoGPT(华佗GPT&…

Elasticsearch 向量相似搜索

Elasticsearch 向量相似搜索的原理涉及使用密集向量(dense vector)来表示文档,并通过余弦相似性度量来计算文档之间的相似性。以下是 Elasticsearch 向量相似搜索的基本原理: 向量表示文档: 文档的文本内容经过嵌入模型(如BERT、Word2Vec等)处理,得到一个密集向量(den…

Semaphore 详解

1、Semaphore 是什么 Semaphore 通常我们叫它信号量, 可以用来控制同时访问特定资源的线程数量,通过协调各个线程,以保证合理的使用资源。 可以把它简单的理解成我们停车场入口立着的那个显示屏,每有一辆车进入停车场显示屏就会…

JDK各个版本特性讲解-JDK13特性

JDK各个版本特性讲解-JDK13特性 一、JAVA13概述二、语法层面特性1.switch表达式(预览)2.文本块(预览)2.1 概念2.2 问题2.3 目标2.4 语法细节1 基本使用2.5 语法细节2 编译器在编译时,会删除多余的空格2.6 语法细节3 转义字符2.7 语法细节4 文本块连接 三、API层次特性1.重新实现…

13、Kafka副本机制详解

Kafka 副本机制详解 1、副本定义2、副本角色3、In-sync Replicas(ISR)4、Unclean 领导者选举(Unclean Leader Election) 所谓的副本机制(Replication),也可以称之为备份机制,通常是指…

为什么我的对话框创建失败了?菜鸟错误1

对话框中的资源要么被定义为一个整数&#xff0c;要么被定义为一个字符串。 仅仅一个简单的错误将会将其中的一个类型错误的变成另一个类型。我们来看一个例子。 >> 请移步至 www.topomel.com 以查看图片 << 你是否能发现其中的两处 “菜鸟级错误” ? 如果先获…

Elasticsearch:生成 AI 中的微调与 RAG

在自然语言处理 (NLP) 领域&#xff0c;出现了两种卓越的技术&#xff0c;每种技术都有其独特的功能&#xff1a;微调大型语言模型 (LLM) 和 RAG&#xff08;检索增强生成&#xff09;。 这些方法极大地影响了我们利用语言模型的方式&#xff0c;使它们更加通用和有效。 在本文…

Linux系统管理、服务器设置、安全、云数据中心

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 我们来快速了解liunx命令 文章目录 前言解析命令提示符linux的文件和目录文件和目录管理文件操作 进程管理命令系统管理网络管理 书籍推荐 本文以服务器最常用的CentOS为例 解析命令提示…

2024年完整湖北等保测评机构名单看这里!

等保测评机构是指经公安部认证的具有资质的测评机构&#xff0c;主要从事等级测评活动。一般过等保需要找正规具有资质的等保测评机构。那你知道2024年湖北等保测评机构有哪些&#xff1f;名单有吗&#xff1f; 2024年完整湖北等保测评机构名单看这里&#xff01; 1、湖北星…

接口测试【断言设置思路】实操

1 断言设置思路 这里总结了我在项目中常用的5种断言方式&#xff0c;基本可能满足90%以上的断言场景&#xff0c;具体参见如下脑图&#xff1a; 在这里插入图片描述 下面分别解释一下图中的五种思路&#xff1a; 1&#xff09; 响应码 对于http类接口&#xff0c;有时开发人…

无损编码——Slepian-Wolf理论

在信息论中&#xff0c;无损编码是一种重要的编码技术&#xff0c;其目的是通过尽量少的比特数来表示一段信息&#xff0c;同时保证信息的完整性和准确性。传统的无损编码方法往往只考虑单个源的编码问题&#xff0c;比如哈夫曼编码和算术编码等。然而&#xff0c;在实际应用中…

RTK、PPP与RTK-PPP?一文带您认识高精定位及如何进行高精定位GNSS测试!(一)

来源&#xff1a;德思特测试测量 德思特干货丨RTK、PPP与RTK-PPP&#xff1f;一文带您认识高精定位及如何进行高精定位GNSS测试&#xff01;&#xff08;一&#xff09; 原文链接&#xff1a;https://mp.weixin.qq.com/s/6Jb3DuJEhRGqFPrH3CX8xQ 欢迎关注虹科&#xff0c;为您…

#HarmonyOS:项目结构图

.hvigor&#xff1a;存储构建配置文件信息 .idea&#xff1a;存储项目的配置信息 AppScope&#xff1a;全局的共有资源存放目录