EfficientNet-v2-s图像分类训练(简洁版)

使用torchvision集成的efficientnet-v2-s模型,调用torchvision库中的Oxford IIIT Pet数据集,对模型进行训练。
若有修改要求,可以修改以下部分:

train_dataset = OxfordIIITPet(root='./data', split='trainval', download=True, transform=transform_train)
test_dataset = OxfordIIITPet(root='./data', split='test', download=True, transform=transform_test)
#常见数据集可以直接加载,若是自己的数据集就自己写个dataset/dataloader
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 37)
#37为数据集类别数,修改为自己对应的
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.1, verbose=True)
#学习率处可以自己调整,可玩性较高

训练截图:
在这里插入图片描述

其实十轮左右就稳定在90以上了,跑了三十轮,记得修改保存路径,我这里是用kaggle跑的。
代码如下:

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import OxfordIIITPet
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm# 数据预处理 + 数据增强
transform_train = transforms.Compose([transforms.Resize((256, 256)),  # 增大图片预处理尺寸transforms.RandomCrop((224, 224)),  # 随机裁剪到模型输入尺寸transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集
train_dataset = OxfordIIITPet(root='./data', split='trainval', download=True, transform=transform_train)
test_dataset = OxfordIIITPet(root='./data', split='test', download=True, transform=transform_test)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 模型定义
model = efficientnet_v2_s(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 37)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.1, verbose=True)# 训练模型
def train_model(num_epochs):model.train()best_accuracy = 0for epoch in range(num_epochs):model.train()running_loss = 0.0for inputs, labels in tqdm(train_loader, desc=f'Training Epoch {epoch + 1}'):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 每个 epoch 后测试accuracy = test_model()scheduler.step(accuracy)# 如果当前模型表现更好,保存模型if accuracy > best_accuracy:best_accuracy = accuracytorch.save(model.state_dict(), '/kaggle/working/best_oxford_pets_efficientnetv2.pth')print(f'New best model saved with accuracy: {best_accuracy:.2f}%')def test_model():model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in tqdm(test_loader, desc='Testing'):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 调试输出if total < 50:  # 只打印前50个样本的信息print(f'Predicted: {predicted[:10]}, Labels: {labels[:10]}')accuracy = 100 * correct / totalprint(f'Testing Accuracy: {accuracy:.2f}%')return accuracytrain_model(30)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/877377.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

loguru日志模块:简化Python自动化测试的日志管理!

引言 日志是软件开发中的关键组成部分&#xff0c;为开发和测试人员提供了调试和监控应用程序的重要手段。loguru 是一个第三方的 Python 日志库&#xff0c;以其简洁的 API 和自动化的功能脱颖而出。本文将探讨为什么项目中需要日志&#xff0c;loguru 为何受到青睐&#xff…

【Python机器学习】决策树的构造——递归构建决策树

我们可以采用递归的原则处理数据集&#xff0c;递归结束的条件是&#xff1a;程序遍历完所有划分数据集的属性&#xff0c;或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类&#xff0c;则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶…

年化27.9%,最大回撤-13.6%的可转债因子策略,结合机器学习特征筛选(附python代码)

原创文章第603篇&#xff0c;专注“AI量化投资、世界运行的规律、个人成长与财富自由"。 我们重新更新了可转债的全量数据&#xff0c;包含全量已经退市的转债。 ——这是与股票市场不一样的地方&#xff0c;股票退市相对少&#xff0c;而转债本身就有退出周期。 因此&…

x264 环路滤波原理系列:x264_frame_deblock_row 函数

x264_frame_deblock_row 函数 功能:该函数对视频帧中的一行宏块(Macroblock)进行去块滤波处理。去块滤波是视频编码中常用的一种技术,用于减少宏块之间的边界不连续性,从而提高视频质量。 函数关系与原理图: 函数原理流程梳理: 局部变量初始化;for 循环处理每个宏块:…

如何借助低代码 + BI 实现国央企数智化转型?

概要 在当前的软件开发时代&#xff0c;许多企业面临着核心技术缺失、专业人才短缺以及产品能力单一等问题&#xff0c;迫切需要加强技术实力&#xff0c;补充和扩展原有的业务和行业能力。将技术与业务需求深度结合&#xff0c;构建适应时代需求的技术业务模式&#xff0c;成…

容易发表的医学SCI期刊推荐,附投稿经验

常笑医学整理了适合医学生、医务工作者进行论文投稿的医学SCI期刊&#xff0c;附期刊详细参数与真实投稿经验&#xff0c;供大家参考。 1、ULTRASOUND IN MEDICINE AND BIOLOGY&#xff08;超声医学和生物学&#xff09; &#xff08;详细投稿信息请点击刊物名称查看&#xff…

MATLAB被360误杀的解决方案

前面被误杀&#xff0c;今天又被误杀。 前面误杀结果是缺少文件&#xff0c;重装MATLAB也不行。 结果重装了操作系统。 这次&#xff0c;看到了提示额外小心。 当时备份了“病毒”文件&#xff0c;结果备份的也被杀了。 解铃还须系铃人 在360安全卫士里面恢复&#xff0c;步骤…

数据库管理-第225期 Oracle DB 23.5新特性一览(20240730)

数据库管理225期 2024-07-30 数据库管理-第225期 Oracle DB 23.5新特性一览&#xff08;20240730&#xff09;1 二进制向量维度格式2 RAC上的复制HNSW向量索引3 JSON集合4 JSON_ID SQL函数5 优化的通过网络对NVMe设备的Oracle的原生访问6 DBCA支持PMEM存储7 DBCA支持标准版高可…

PPT图表制作

一、表格的底纹 插入→表格→绘制表格→表设计→选择单元格→底纹 二、把一张图片做成九宫格 1. 把一张图片画成九宫格&#xff08;处理过后还是一张图片&#xff0c;但是有框线&#xff09; 绘制33表格→插入图片→全选表格单元格→右键设置形状格式→填充→图片或纹理填充…

前后端分离开发遵循接口规范-YAPI

目前&#xff0c;网站主流开发方式是前后端分离。因此前后端必须遵循一套统一的规范&#xff0c;才能保证前后端进行正常的数据&#xff08;JSON数据格式&#xff09;请求、影响&#xff0c;这套规范即是 YAPI. 产品经理撰写原型&#xff1b; 前端或后端撰写接口文档。 YAPI…

一座山城如何打造教育“一张网”

教育新基建作为国家新基建的重要组成部分,是实现教育高质量发展的基础支撑。2021年,教育部等六部门印发相关部署意见时明确提出:到2025年,基本形成结构优化、集约高效、安全可靠的教育新型基础设施体系。 在此宏观导向下,山城重庆积极响应,立足本地情况,开启了其特色化的探索之…

K8s对接Ceph-csi配置手册(附带踩坑记录以及解决方法)

目录 Ceph CSI (Container Storage Interface) CSI 的作用&#xff1a; 前提配置 版本信息 获取Ceph认证信息 获取Ceph集群Monitor信息 下载并部署Ceph CSI 如果此时全部显示错误&#xff0c;那就代表镜像拉取错误&#xff0c;此时执行的yaml脚本&#xff0c;通过yaml脚…

进行良好的文献综述能否提高学术研究的可信度

VersaBot一键生成文献综述 进行良好的文献综述 对于从多个方面提高学术研究的可信度至关重要&#xff1b; 1. 展示专业知识&#xff1a; 全面的回顾表明您对您所在领域的现有知识和相关理论有深入的了解。这将使您成为权威&#xff0c;并将您的研究置于更广泛的背景下。 2.…

ValueError: invalid literal for int() with base 10: ‘a‘

ValueError: invalid literal for int() with base 10: ‘a‘ 目录 ValueError: invalid literal for int() with base 10: ‘a‘ 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是博主英杰&#xff…

【CodinGame】趣味算法(教学用) CLASH OF CODE -20240730

文章目录 正文转换单位观察模式数字处理字符串处理 写在最后END 正文 转换单位 import sys import math# Auto-generated code below aims at helping you parse # the standard input according to the problem statement.n int(input()) for i in range(n):e int(input())…

win10 定时任务实战--开机启动 Java 应用

引言 在Windows 10系统中&#xff0c;可以通过结合任务计划程序&#xff08;Task Scheduler&#xff09;和批处理脚本&#xff08;.bat&#xff09;或PowerShell脚本来定期运行Java程序。以下是一个基本的步骤说明&#xff0c;展示如何设置这一过程。 第一步&#xff1a;准备…

爬虫“拥抱大模型”,有没有搞头?

验证码坐标识别 数据采集过程中&#xff0c;可能会碰到各种风控策略。其中&#xff0c;验证码人机验证是较为常见的&#xff0c;点选类验证码需要识别出相应的坐标&#xff0c;碰到这种情况&#xff0c;一般要么自己训练模型&#xff0c;要么对接打码平台。现在也可以将识别工…

多媒体技术:语音音频压缩

语音音频压缩 语音音频基础知识物理世界的声音——语音 语音音频编码方法波形编码波形编码原理常用波形编码技术脉冲编码调制PCM差分脉冲编码调制DPCM自适应差分脉冲编码调制ADPCM子带ADPCM 参数编码感知编码 语音音频编码框架语音编码框架音频编码框架混合编码框架 语音音频编…

JDK8的新特性

目录 接口的默认方法和静态方法 Lambda表达式1、匿名内部类2、函数式接口(FunctionalInterface)2.1 无参函数式接口匿名内部类方式-->Lambda表达式方式 2.2 有参函数式接口匿名内部类方式-->Lambda表达式方式 3、Lambda实战 3.1 循环遍历 3.2 集合排序 3.3 创建线程方…

黑马头条Day12-项目部署_持续集成

一、今日内容介绍 1. 什么是持续集成 持续集成&#xff08;Continuous integration&#xff0c;简称CI&#xff09;&#xff0c;指的是频繁地&#xff08;一天多次&#xff09;将代码集成到主干。 持续集成的组成要素&#xff1a; 一个自动构建过程&#xff0c;从检出代码、…