「MobileNet V3」70 个犬种的图片分类

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 数据集与 Notebook
    • 环境准备
    • 数据集
    • 可视化
    • 模型
    • 预测
    • Loss 与评价指标


数据集与 Notebook

数据集:70 Dog Breeds-Image Data Set
Notebook:「MobileNet V3」70 Dog Breeds-Image Classification


环境准备

import warnings
warnings.filterwarnings('ignore')

禁用警告,防止干扰。

!pip install lightning --quiet

安装 PyTorch Lightning。

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as snssns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})

导入常用的库,设置绘图风格。

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

导入 PyTorch 相关的库。

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

导入 PyTorch Lightning 相关的库。

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
pl.seed_everything(seed, workers=True)

设置随机种子。


数据集

batch_size = 64

设置批次大小。

train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),
])test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
])

设置数据集的预处理。

train_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/train", transform=train_transform)
val_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/valid", transform=test_transform)
test_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/test", transform=test_transform)

读取数据集。

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

加载数据集。


可视化

class_names = train_dataset.classes
class_count = [train_dataset.targets.count(i) for i in range(len(class_names))]
df = pd.DataFrame({"Class": class_names, "Count": class_count})plt.figure(figsize=(12, 20), dpi=100)
sns.barplot(x="Count", y="Class", data=df)
plt.tight_layout()
plt.show()

绘制训练集的类别分布。

训练集的类别分布

plt.figure(figsize=(12, 20), dpi=100)
images, labels = next(iter(val_loader))
for i in range(8):ax = plt.subplot(8, 4, i + 1)plt.imshow(images[i].permute(1, 2, 0).numpy())plt.title(class_names[labels[i]])plt.axis("off")
plt.tight_layout()
plt.show()

绘制训练集的样本。

训练集的样本


模型

class LitModel(pl.LightningModule):def __init__(self, num_classes=1000):super().__init__()self.model = models.mobilenet_v3_large(weights="IMAGENET1K_V2")# for param in self.model.parameters():#     param.requires_grad = Falseself.model.classifier[3] = nn.Linear(self.model.classifier[3].in_features, num_classes)self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)self.precision = torchmetrics.Precision(task="multiclass", average="macro", num_classes=num_classes)self.recall = torchmetrics.Recall(task="multiclass", average="macro", num_classes=num_classes)self.f1score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)def forward(self, x):x = self.model(x)return xdef configure_optimizers(self):optimizer = optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5)return optimizerdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)self.log_dict({"train_acc": self.accuracy(y_hat, y),"train_prec": self.precision(y_hat, y),"train_recall": self.recall(y_hat, y),"train_f1score": self.f1score(y_hat, y),},on_step=True,on_epoch=False,logger=True,)return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)self.log_dict({"val_acc": self.accuracy(y_hat, y),"val_prec": self.precision(y_hat, y),"val_recall": self.recall(y_hat, y),"val_f1score": self.f1score(y_hat, y),},on_step=False,on_epoch=True,logger=True,)def test_step(self, batch, batch_idx):x, y = batchy_hat = self(x)self.log_dict({"test_acc": self.accuracy(y_hat, y),"test_prec": self.precision(y_hat, y),"test_recall": self.recall(y_hat, y),"test_f1score": self.f1score(y_hat, y),})def predict_step(self, batch, batch_idx, dataloader_idx=None):x, y = batchy_hat = self(x)preds = torch.argmax(y_hat, dim=1)return preds

定义模型。

num_classes = len(class_names)
model = LitModel(num_classes=num_classes)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
)
trainer = pl.Trainer(max_epochs=20,enable_progress_bar=True,logger=logger,callbacks=[early_stop_callback],deterministic=True,
)
trainer.fit(model, train_loader, val_loader)

训练模型。

trainer.test(model, val_loader)

测试模型。


预测

pred = trainer.predict(model, test_loader)
pred = torch.cat(pred, dim=0)
pred = pd.DataFrame(pred.numpy(), columns=["Class"])
pred["Class"] = pred["Class"].apply(lambda x: class_names[x])plt.figure(figsize=(12, 20), dpi=100)
sns.countplot(y="Class", data=pred)
plt.tight_layout()
plt.show()

绘制预测结果的类别分布。

预测结果的类别分布


Loss 与评价指标

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="train_loss", data=metrics, label="Train Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_loss", data=metrics, label="Valid Loss", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()plt.figure(figsize=(14, 12), dpi=100)plt.subplot(2,2,1)
sns.lineplot(x=x_name, y="train_acc", data=metrics, label="Train Accuracy", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_acc", data=metrics, label="Valid Accuracy", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")plt.subplot(2,2,2)
sns.lineplot(x=x_name, y="train_prec", data=metrics, label="Train Precision", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_prec", data=metrics, label="Valid Precision", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Precision")plt.subplot(2,2,3)
sns.lineplot(x=x_name, y="train_recall", data=metrics, label="Train Recall", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_recall", data=metrics, label="Valid Recall", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Recall")plt.subplot(2,2,4)
sns.lineplot(x=x_name, y="train_f1score", data=metrics, label="Train F1-Score", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_f1score", data=metrics, label="Valid F1-Score", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("F1-Score")plt.tight_layout()
plt.show()

绘制 Loss 与评价指标的变化。

Loss

评价指标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/156587.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每日一题 2216. 美化数组的最少删除数(中等,贪心)

贪心,一开始可能会觉得如果删除前面一个相等的元素时,会导致后面的元素前移,造成产生更多的相等的元素对的情况但是在遍历过程中至少要在相等元素对中删除一个,也可以同时删除两个使得后面的元素奇偶关系不变,但是显然…

【C++上层应用】5. 文件和流

文章目录 【 1. 打开文件 】1.1 open 函数1.2 open 多种模式的结合使用 【 2. 关闭文件 】【 3. 写入 & 读取文件 】【 4. 文件位置指针 】 和 iostream 库中的 cin 标准输入流和 cout 标准输出流类似,C中另一个库 fstream 也存在文件的读取流和标准写入流。fst…

可视化大屏时代的到来:智慧城市管理的新思路

随着科技的不断发展,智能芯片作为一种新型的电子元件,被广泛应用于各个领域,其中智慧芯片可视化大屏是一种重要的应用形式。 一、智慧芯片可视化大屏的优势 智慧芯片可视化大屏是一种将智能芯片与大屏幕显示技术相结合的产品,山海…

从算法到应用:直播美颜滤镜SDK的全面解读与评测

直播美颜滤镜SDK技术逐渐成为直播平台不可或缺的一环。本文将对直播美颜滤镜SDK进行全面解读,深入探讨其算法原理和应用效果,并通过评测分析展现其在直播领域的实际价值。 一、算法原理解读 直播美颜滤镜的背后是复杂而精密的算法,旨在提升…

React结合antd5实现整个表格编辑

通过react hooks 结合antd的table实现整个表格新增编辑。 引入组件依赖 import React, { useState } from react; import { Table, InputNumber, Button, Space, Input } from antd;定义数据 const originData [{ key: 1, name: 白银会员, value: 0, equity: 0, reward: 0…

头歌 MySQL数据库 - 初识MySQL

本章内容是为了完成老师布置的作业,同时也是为了以后考试的时候方便复习。 数据库部分一条一条的写,可鼠标手动粘贴,除特定命令外未分大小写。 第1关:创建数据库 在操作数据库之前,需要连接它,输入命令&a…

怎么让NetCore接口支持Json参数

项目:NetCore Web API 接口支持Json参数需要安装Newtonsoft.Json.Linq和Microsoft.AspNetCore.Mvc.NewtonsoftJson Program代码 //支持json需要安装Microsoft.AspNetCore.Mvc.NewtonsoftJson using Newtonsoft.Json.Serialization;var builder WebApplication.Cr…

【C/PTA】函数专项练习(一)

本文结合PTA专项练习带领读者掌握函数,刷题为主注释为辅,在代码中理解思路,其它不做过多叙述。 目录 6-1 输出星期名6-2 三整数最大值6-3 数据排序6-4 多项式求值 6-1 输出星期名 请编写函数,根据星期数输出对应的星期名。 函数原…

【LeetCode刷题】--12.整数转罗马数字

12.整数转罗马数字 方法:模拟 分析罗马数字的规则是:对于罗马数字从左到右的每一位,选择尽可能大的符号值 根据罗马数字的唯一表示法,为了表示一个给定的整数num,寻找不超过num的最大符号值,将num减去该符…

CyNix

CyNix 一、主机发现和端口扫描 主机发现,靶机地址192.168.80.146 arp-scan -l端口扫描,只开放了80和6688端口 nmap -A -p- -sV 192.168.80.146二、信息收集 访问80端口 路径扫描 gobuster dir -u http://192.168.80.146/ -w /usr/share/wordlists/dir…

C++之内建函数对象

C之内建函数对象 算术仿函数 #include<iostream> using namespace std; #include<functional>//内建函数对象头文件 //内建函数对象 算术仿函数void test() {// negate 一元仿函数 取反仿函数negate<int>n;cout << n(100) << endl;//plus 二元仿…

软件测试/人工智能丨互联网大厂内的人工智能测试

互联网公司在人工智能&#xff08;AI&#xff09;测试方面一直处于不断发展和演变的状态。互联网公司人工智能测试目前趋势&#xff1a; 自动化测试的重要性增加&#xff1a; 随着人工智能应用的不断增多&#xff0c;互联网公司越来越意识到自动化测试的重要性。自动化测试框架…

可用于短期风速预测及光伏预测的LSTM/ELM预测程序

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 程序内容&#xff1a; 该程序是预测类的基础性代码&#xff0c;程序对河北某地区的气象数据进行详细统计&#xff0c;程序最终得到pm2.5的预测结果&#xff0c;通过更改数据很容易得到风速预测结果。程序主要…

高精度人像背景分割SDK技术解决方案

图像处理技术已经成为企业和个人生活中不可或缺的一部分&#xff0c;特别是在人像处理方面&#xff0c;如何准确、高效地将人物与背景分离&#xff0c;一直是一个技术难题。然而&#xff0c;美摄科技凭借其在AI深度学习领域的深厚积累&#xff0c;推出了一款高精度的人像背景分…

报错:HikariPool-1 - Exception during pool initialization.

问题发现&#xff1a; 原本可以运行的springboot2项目突然无法运行且报错&#xff0c;HikariPool-1 - Exception during pool initialization。 问题分析&#xff1a; 观察报错信息发现是JDBC连接失败&#xff0c;进而搜索HikariPool-1&#xff0c;搜索得知应该是applicatio…

01-论文阅读-Deep learning for anomaly detection in log data: a survey

01-论文阅读-Deep learning for anomaly detection in log data: a survey 文章目录 01-论文阅读-Deep learning for anomaly detection in log data: a survey摘要I 介绍II 背景A 初步定义B 挑战 III 调查方法A 搜索策略B 审查的功能 IV 调查结果A 文献计量学B 深度学习技术C …

Springboot+vue的社区医院管理系统(有报告),Javaee项目,springboot vue前后端分离项目

演示视频&#xff1a; Springbootvue的社区医院管理系统(有报告)&#xff0c;Javaee项目&#xff0c;springboot vue前后端分离项目 项目介绍&#xff1a; 本文设计了一个基于Springbootvue的前后端分离的应急物资管理系统&#xff0c;采用M&#xff08;model&#xff09;V&am…

el-form动态表单动态验证(先验证不为空,再验证长度在20以内,最后向后台发送请求验证账号是否重复)

data(){var checkSno (rule, value, callback) > {if (!value) {callback(new Error("请输入账号"));} else if (value.length > 20) {callback(new Error("长度为1-20"));} else {if (this.form.id) {// 修改时检查账号是否重复selectLoginId({ sn…

美国国家安全实验室员工详细数据在网上泄露

一个从事出于政治动机的攻击的网络犯罪组织破坏了爱达荷国家实验室&#xff08;INL&#xff09;的人力资源应用程序&#xff0c;该组织周日在电报上发帖称&#xff0c;已获得该核研究实验室员工的详细信息。 黑客组织 SiegedSec 表示&#xff0c;它已经访问了“数十万用户、员…

JMeter —— 接口自动化测试(数据驱动)

前言 之前我们的用例数据都是配置在HTTP请求中&#xff0c;每次需要增加&#xff0c;修改用例都需要打开JMeter重新编辑&#xff0c;当用例越来越多的时候&#xff0c;用例维护起来就越来越麻烦&#xff0c;有没有好的方法来解决这种情况呢&#xff1f;我们可以将用例的数据存…