「ResNet-18」70 个犬种的图片分类

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 数据集与 Notebook
    • 环境准备
    • 数据集
    • 可视化
    • 模型
    • 预测
    • Loss 与评价指标


数据集与 Notebook

数据集:70 Dog Breeds-Image Data Set
Notebook:「ResNet-18」70 Dog Breeds-Image Classification


环境准备

import warnings
warnings.filterwarnings('ignore')

禁用警告,防止干扰。

!pip install lightning --quiet

安装 PyTorch Lightning。

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as snssns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})

导入常用的库,设置绘图风格。

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

导入 PyTorch 相关的库。

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

导入 PyTorch Lightning 相关的库。

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
pl.seed_everything(seed, workers=True)

设置随机种子。


数据集

batch_size = 64

设置批次大小。

train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),
])test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
])

设置数据集的预处理。

train_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/train", transform=train_transform)
val_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/valid", transform=test_transform)
test_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/test", transform=test_transform)

读取数据集。

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

加载数据集。


可视化

class_names = train_dataset.classes
class_count = [train_dataset.targets.count(i) for i in range(len(class_names))]
df = pd.DataFrame({"Class": class_names, "Count": class_count})plt.figure(figsize=(12, 20), dpi=100)
sns.barplot(x="Count", y="Class", data=df)
plt.tight_layout()
plt.show()

绘制训练集的类别分布。

训练集的类别分布

plt.figure(figsize=(12, 20), dpi=100)
images, labels = next(iter(val_loader))
for i in range(8):ax = plt.subplot(8, 4, i + 1)plt.imshow(images[i].permute(1, 2, 0).numpy())plt.title(class_names[labels[i]])plt.axis("off")
plt.tight_layout()
plt.show()

绘制训练集的样本。

训练集的样本


模型

class LitModel(pl.LightningModule):def __init__(self, num_classes=1000):super().__init__()self.model = models.resnet18(weights="IMAGENET1K_V1")# for param in self.model.parameters():#     param.requires_grad = Falseself.model.fc = nn.Linear(self.model.fc.in_features, num_classes)self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)self.precision = torchmetrics.Precision(task="multiclass", average="macro", num_classes=num_classes)self.recall = torchmetrics.Recall(task="multiclass", average="macro", num_classes=num_classes)self.f1score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)def forward(self, x):x = self.model(x)return xdef configure_optimizers(self):optimizer = optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5)return optimizerdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)self.log_dict({"train_acc": self.accuracy(y_hat, y),"train_prec": self.precision(y_hat, y),"train_recall": self.recall(y_hat, y),"train_f1score": self.f1score(y_hat, y),},on_step=True,on_epoch=False,logger=True,)return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)self.log_dict({"val_acc": self.accuracy(y_hat, y),"val_prec": self.precision(y_hat, y),"val_recall": self.recall(y_hat, y),"val_f1score": self.f1score(y_hat, y),},on_step=False,on_epoch=True,logger=True,)def test_step(self, batch, batch_idx):x, y = batchy_hat = self(x)self.log_dict({"test_acc": self.accuracy(y_hat, y),"test_prec": self.precision(y_hat, y),"test_recall": self.recall(y_hat, y),"test_f1score": self.f1score(y_hat, y),})def predict_step(self, batch, batch_idx, dataloader_idx=None):x, y = batchy_hat = self(x)preds = torch.argmax(y_hat, dim=1)return preds

定义模型。

num_classes = len(class_names)
model = LitModel(num_classes=num_classes)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
)
trainer = pl.Trainer(max_epochs=20,enable_progress_bar=True,logger=logger,callbacks=[early_stop_callback],deterministic=True,
)
trainer.fit(model, train_loader, val_loader)

训练模型。

trainer.test(model, val_loader)

测试模型。


预测

pred = trainer.predict(model, test_loader)
pred = torch.cat(pred, dim=0)
pred = pd.DataFrame(pred.numpy(), columns=["Class"])
pred["Class"] = pred["Class"].apply(lambda x: class_names[x])plt.figure(figsize=(12, 20), dpi=100)
sns.countplot(y="Class", data=pred)
plt.tight_layout()
plt.show()

绘制预测结果的类别分布。

预测结果的类别分布


Loss 与评价指标

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="train_loss", data=metrics, label="Train Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_loss", data=metrics, label="Valid Loss", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()plt.figure(figsize=(14, 12), dpi=100)plt.subplot(2,2,1)
sns.lineplot(x=x_name, y="train_acc", data=metrics, label="Train Accuracy", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_acc", data=metrics, label="Valid Accuracy", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")plt.subplot(2,2,2)
sns.lineplot(x=x_name, y="train_prec", data=metrics, label="Train Precision", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_prec", data=metrics, label="Valid Precision", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Precision")plt.subplot(2,2,3)
sns.lineplot(x=x_name, y="train_recall", data=metrics, label="Train Recall", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_recall", data=metrics, label="Valid Recall", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Recall")plt.subplot(2,2,4)
sns.lineplot(x=x_name, y="train_f1score", data=metrics, label="Train F1-Score", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_f1score", data=metrics, label="Valid F1-Score", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("F1-Score")plt.tight_layout()
plt.show()

绘制 Loss 与评价指标的变化。

Loss

评价指标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155261.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode【45】跳跃游戏2

题目: 思路: 注意和跳跃游戏【55】不同的是,题目保证可以跳到nums[n-1];那么每次跳到最大即可 代码: public class LeetCode45 {public static int jump(int[] nums) {int jumps 0;int currentEnd 0;int farthest 0;for(int…

Android Serializable / Parcelable

Serializable 序列化,将对象转为二进制序列 Parcelable 不是序列化,属于进程间通信,不需要IO/操作,没有拷贝内存的操作, Object -> ShareMemory -> Object 不需要IO,使用内存共享等方式 Kotlin inline fun 内联函数 TCP协议将数据包拆分,进行发送,保证网络数据的可…

基于纹理特征的kmeas聚类的图像分割方案

Gabor滤波器简介 在图像处理中,以Dennis Gabor命名的Gabor滤波器是一种用于纹理分析的线性滤波器,本质上是指在分析点或分析区域周围的局部区域内,分析图像中是否存在特定方向的特定频率内容。Gabor滤波器的频率和方向表示被许多当代视觉科学…

二十一、数组(3)

本章概要 Arrays的setAll方法增量生成 Arrays的setAll方法 在Java 8中, 在RaggedArray.java 中引入并在 ArrayOfGenerics.java.Array.setAll() 中重用。它使用一个生成器并生成不同的值,可以选择基于数组的索引元素(通过访问当前索引&…

Android项目更新依赖和打包步骤和问题汇总

目录 1、Android 项目打包,32位包升级到64位包问题一:ERROR: Conflicting configuration : armeabi-v7a,x86-64,arm64-v8a,x86 in ndk abiFilters cannot be present when splits abi filters are set : x86,armeabi-v7a,arm64-v8a 2、Android项目依赖升…

pytest-base-url插件之配置可选的项目系统URL

前言 ①当我们的自动化代码完成之后,通常期望可以在不同的环境进行测试,此时可以将项目系统的URL单独拿出来,并且可以通过pytest.ini配置文件和支持pytest命令行方式执行。 ② pytest-base-url 是一个简单的pytest插件,它通过命…

【数据结构】HashMap 和 HashSet

目录 1.哈希表概念 2冲突 2.1概念 2.2 冲突-避免 2.3冲突-避免-哈希函数设计 2.4 冲突-避免-负载因子调节 ​编辑 2.5 冲突-解决-开散列/哈希桶 2.5冲突严重时的解决办法 3.实现 4.性能分析 5.与Java集合类的关系 1.哈希表概念 在顺序结构中,元素关键码和存…

【vue+eltable】修改表格滚动条样式

<style lang"scss" scoped> ::v-deep .el-table__body-wrapper::-webkit-scrollbar {width: 10px; /*纵向滚动条的宽度*/height: 10px; /*横向滚动条的高度*/ } /*定义滚动条轨道 内阴影圆角*/ ::v-deep .el-table__body-wrapper::-webkit-scrollbar-track {bo…

Java 多线程之 volatile(可见性/重排序)

文章目录 一、概述二、使用方法三、测试程序3.1 验证可见性的示例3.2 验证指令重排序的示例 一、概述 在Java中&#xff0c;volatile 关键字用于修饰变量&#xff0c;其作用是确保多个线程之间对该变量的可见性和禁止指令重排序优化。 当一个变量被声明为volatile时&#xff0…

高德地图点击搜索触发输入提示

减少调用次数&#xff0c;不用每输入一次调用一次&#xff0c;输入完后再触发搜索 效果图&#xff1a; ![Alt](https://img-home.csdnimg.cn/images/20220524100510.png dom结构 <div class"seach"><van-searchshow-actionv-model"addressVal"…

【使用vscode在线web搭建开发环境--code-server搭建】

官方版本下载 https://github.com/coder/code-server/releases?q4.0.0&expandedtrue使用大于版本3.8.0,因为旧版本有插件市场不能访问的情况版本太高需要更新环境依赖 拉取安装包 []# wget "https://github.com/coder/code-server/releases/download/v4.0.0/code-…

探访九牧绿色黑灯工厂,找寻“科技卫浴 世界九牧”的答案

文 | 螳螂观察 作者 | 余一 你所想象中的工厂是怎么样的&#xff1f;灯火通明、人声鼎沸、人来人往&#xff1f;如果告诉你一座工厂既没有灯&#xff0c;也没有人&#xff0c;但却还在持续生产&#xff0c;你会不会觉得这是不可思议的事&#xff1f; 如果不是亲眼见证&#…

Simulink 自动代码生成:手写代码替换生成代码Code Replacement Tool使用

目录 前言 代码替换库操作步骤 代码生成验证 总结 前言 在实际工程开发过程中&#xff0c;Simulink生成的代码都是构建的算法实现的&#xff0c;纯软件实现&#xff0c;生成的代码大多也是直接在CPU上运行的。实际还有一些MCU集成了像Cordic&#xff0c;协处理器等。有些代…

WinEdt 11.1编辑器的尝鲜体验

WinEdt 11.1编辑器的尝鲜体验 2023年5月19日&#xff0c;WinEdt 11.1版本发布了&#xff0c;相比WinEdt 10.3, 最新版更加漂亮&#xff0c;更加友好&#xff0c;更加好用了&#xff01; 最大的改变是WinEdt 11.1 有了自带的WinEdtPDF阅读器&#xff0c;所以不再需要下载第三方…

ros2工作空间

我们先不管ros2工作空间是什么样子的&#xff0c;如果是我自己来搞一个工作空间&#xff0c;我一定是这样安排 一个文件夹用来放自己存放的文件&#xff0c;。。。。。。。。。。对应src文件夹 一个文件夹用来放编译后的文件&#xff0c;。。。。。。。。。。。对应intall文件…

2024测试工程师必学的Jmeter:利用jmeter插件收集性能测试结果汇总报告和聚合报告

利用jmeter插件收集性能测试结果 汇总报告&#xff08;Summary Report &#xff09; 用来收集性能测试过程中的请求以及事务各项指标。通过监听器--汇总报告 可以添加该元件。界面如下图所示 汇总报告界面介绍&#xff1a; 所有数据写入一个文件&#xff1a;保存测试结果到本地…

ZYNQ_project:LCD

模块框图&#xff1a; 时序图&#xff1a; 代码&#xff1a; /* // 24h000000 4324 9Mhz 480*272 // 24h800000 7084 33Mhz 800*480 // 24h008080 7016 50Mhz 1024*600 // 24h000080 4384 33Mhz 800*480 // 24h800080 1018 70Mhz 1280*800 */ module rd_id(i…

解决java在idea运行正常,但是打成jar包后中文乱码问题

目录 比如&#xff1a; 打包命令使用utf-8编码&#xff1a; 1.当在idea中编写的程序,运行一切正常.但是当被打成jar包时,执行的程序会中文乱码.产生问题的原因和解决方案是什么呢? 一.问题分析 分别使用idea和jar包形式打印出System中所有的jvm参数---代码如下: public static…

【设计模式】行为型设计模式

行为型设计模式 文章目录 行为型设计模式一、概述二、责任链模式&#xff08;Chain of Responsibility Pattern&#xff09;三、命令模式&#xff08;Command Pattern&#xff09;四、解释器模式&#xff08;Interpreter Pattern&#xff09;五、迭代器模式&#xff08;Iterato…

Stable Diffusion专场公开课

从SD原理、本地部署到其二次开发 分享时间&#xff1a;11月25日14&#xff1a;00-17&#xff1a;00 分享大纲 从扩散模型DDPM起步理解SD背后原理 SD的本地部署:在自己电脑上快速搭建、快速出图如何基于SD快速做二次开发(以七月的AIGC模特生成系统为例) 分享人简介 July&#…