深度学习笔记_6经典预训练网络LeNet-18解决FashionMNIST数据集

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
from torchvision import models# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 50# 定义数据转换操作
transform = transforms.Compose([transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转transforms.RandomHorizontalFlip(),   # 随机水平翻转transforms.Resize((224, 224)),  # 调整图像大小transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,download=True, transform=transform)# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,download=True, transform=transform)# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=test_batch_size, shuffle=False)

3、使用预训练的ResNet-18模型

# 使用预训练的ResNet-18模型
model = models.resnet18(pretrained=True)
# 修改最后一层,使其适应FashionMNIST的输出类别数
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)# 冻结预训练模型的参数
for param in model.parameters():param.requires_grad = False# 只训练模型的最后一层
for param in model.fc.parameters():param.requires_grad = True
# 初始化优化器和损失函数
optimizer = optim.Adam(model.fc.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

4、 训练循环

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []# 训练循环
for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)# 计算额外的指标conf_matrix = confusion_matrix(all_labels, all_preds)conf_matrix_list.append(conf_matrix)accuracy = accuracy_score(all_labels, all_preds)accuracy_list.append(accuracy)error_rate = 1 - accuracyerror_rate_list.append(error_rate)precision = precision_score(all_labels, all_preds, average='weighted')recall = recall_score(all_labels, all_preds, average='weighted')f1 = f1_score(all_labels, all_preds, average='weighted')precision_list.append(precision)recall_list.append(recall)f1_score_list.append(f1)fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)roc_auc = auc(fpr, tpr)roc_auc_list.append(roc_auc)# 打印每个 epoch 的指标print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.show()# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.show()# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
class_labels = [str(i) for i in range(10)]
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/230280.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis——Redis常用命令

Redis提供了丰富的命令,可以对数据库和各种数据类型进行操作,这些命令可以在Windows和Linux中使用。 1、键值相关命令 1.1、KEYS KEYS用于返回满足pattern的所有key,pattern支持以下通配符: *:匹配任意字符。&…

Python教程81:函数的位置参数、默认参数、动态参数、关键字参数(入门必看)

1.形式参数(Formal Parameters)和实际参数(Actual Parameters)是函数或方法定义和调用过程中的两个重要概念。举个例子,在下面的greet函数中,当我们调用greet(“李白”)时,"李白"就是…

electron这样使用更安全

背景: electron大家平时为了方便使用,或是一些网上demo的引导,会让渲染进程的业务界面支持直接使用nodejs,这种开发方式有一定的安全隐患,如果业务界面因为xss之类的漏洞被注入其他代码,危害非常大&#x…

Spring之容器:IOC(3)

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。各位小伙伴,如果您: 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持,想组团高效学习… 想写博客但无从下手,急需…

华为云CodeArts Repo常见问答汇总

1.【Repo】codearts Repo最大支持上传文件大小 答&#xff1a;参考链接 https://support.huaweicloud.com/productdesc-codehub/codehub_pdtd_0005.html • 单文件上传大小限制&#xff08;评论中上传附件&#xff09;<50MB。 • 单文件上传大小限制&#xff08;代码…

某联合产权交易所持续购买监控易产品的维保服务,提升IT运维保障能力

在信息化时代&#xff0c;企业信息化的程度已经成为影响其核心竞争力的重要因素。某联合产权交易所&#xff08;以下简称“交易所”&#xff09;作为行业领导者&#xff0c;一直以来都积极推进信息化建设&#xff0c;致力于提升运维管理水平&#xff0c;以适应日益激烈的市场竞…

Rust 嵌入式开发

Rust 进行嵌入式开发: https://xxchang.github.io/book/intro/index.html # 列出所有目标平台 rustup target list# 安装目标平台工具链 rustup target add thumbv7m-none-eabi# 创建工程 cargo new demo && cd demo cargo add cortex-m-rt cargo add panic-halt carg…

二十九、获取文件属性及相关信息

二十九、获取文件属性及相关信息QFileInfo QFileInfo 提供有关文件在文件系统中的名称 位置 &#xff08;路径&#xff09;、访问权限及它是目录还是符号链接、等信息。文件的大小、最后修改/读取时间也是可用的。QFileInfo 也可以被用于获取信息有关 Qt resource . QFileInf…

科技的成就(五十四)

511、线路板按层数来分的话分为单面板&#xff0c;双面板&#xff0c;和多层线路板三个大的分类。线路板按特性来分的话分为软板(FPC)&#xff0c;硬板(PCB)&#xff0c;软硬结合板(FPCB)。是当代电子元件业中最活跃的产业&#xff0c;其行业增长速度一般都高于电子元件产业3个…

算法模板之双链表图文详解

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;算法模板、数据结构 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. ⛳️使用数组模拟双链表讲解1.1 &#x1f514;为什么我们要使用数组去模拟双链表…

使用java调用python批处理将pdf转为图片

你可以使用Java中的ProcessBuilder来调用Python脚本&#xff0c;并将PDF转换为图片。以下是一个简单的Java代码示例&#xff0c;假设你的Python脚本名为pdf2img.py&#xff1a; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader…

Powershell summaries with types of scales of summaries

tiny,small,medium, large and huge scale of Powershell summaries I) many kinds of Tiny summaries of Powershell1.1) Powershell能干嘛&#xff1f; I) many kinds of Tiny summaries of Powershell 1.1) Powershell能干嘛&#xff1f; 此外&#xff0c;关于PowerShell脚…

Java数组(2)

我是南城余&#xff01;阿里云开发者平台专家博士证书获得者&#xff01; 欢迎关注我的博客&#xff01;一同成长&#xff01; 一名从事运维开发的worker&#xff0c;记录分享学习。 专注于AI&#xff0c;运维开发&#xff0c;windows Linux 系统领域的分享&#xff01; 本…

Kotlin 笔记 -- Kotlin 语言特性的理解(二)

都是编译成字节码&#xff0c;为什么 Kotlin 能支持 Java 中没有的特性&#xff1f; kotlin 有哪些 Java 中没有的特性&#xff1a; 类型推断、可变性、可空性自动拆装箱、泛型数组高阶函数、DSL顶层函数、扩展函数、内联函数伴生对象、数据类、密封类、单例类接口代理、inter…

图像与视频压缩算法

图像压缩是通过减少图像数据量来降低图像文件的大小&#xff0c;从而减少存储空间和传输带宽。有多种图像压缩算法&#xff0c;它们可以分为两大类&#xff1a;有损压缩和无损压缩。 无损压缩算法&#xff1a; Run-Length Encoding (RLE): 这是一种简单的无损压缩方法&#x…

西蒙子S7协议介绍

西门子的S7 协议&#xff0c;没有仍何关于S7协议的官方文档&#xff0c;是一个不透明的协议。关于S7的协议介绍&#xff0c;大都是非官方的一些七零八落的文档。 1. S7的通信模型 西蒙子S7 通讯遵从着基于TCP 的 Master&#xff08;client&#xff09; & Slave&#xff0…

读取小数部分

1.题目描述 2.题目分析 //假设字符串为 char arr[] "123.4500"; 1. 找到小数点位置和末尾位置 代码如下&#xff1a; char* start strchr(arr, .);//找到小数点位置char* end start strlen(start) - 1;//找到末尾位置 如果有不知道strchr()用法的同学&#xf…

Pytorch神经网络的模型架构(nn.Module和nn.Sequential的用法)

一、层和块 在构造自定义块之前&#xff0c;我们先回顾一下多层感知机的代码。下面的代码生成一个网络&#xff0c;其中包含一个具有256个单元和ReLU激活函数的全连接隐藏层&#xff0c;然后是一个具有10个隐藏单元且不带激活函数的全连接输出层。 import torch from torch im…

Linux 基本语句_16_Udp网络聊天室

代码&#xff1a; 服务端代码&#xff1a; #include <stdio.h> #include <arpa/inet.h> #include <sys/types.h> #include <sys/socket.h> #include <netinet/in.h> #include <stdlib.h> #include <unistd.h> #include <string…

中小型企业网络综合实战案例分享

实验背景 某公司总部在厦门&#xff0c;北京、上海都有分部&#xff0c;网络结构如图所示&#xff1a; 一、网络连接描述&#xff1a; 厦门总部&#xff1a;内部网络使用SW1、SW2、SW3三台交换机&#xff0c;SW1为作为核心交换机&#xff0c;SW2、SW3作为接入层交换机&#x…