yolov8实战第四天——yolov8图像分类 ResNet50图像分类(保姆式教程)

yolov8实战第一天——yolov8部署并训练自己的数据集(保姆式教程)_yolov8训练自己的数据集-CSDN博客在前几天,我们使用yolov8进行了部署,并在目标检测方向上进行自己数据集的训练与测试,今天我们训练下yolov8的图像分类,看看效果如何,同时使用resnet50也训练一个分类模型,看看哪个效果好!

图像分类是指将输入的图像自动分类为不同的类别。它是计算机视觉领域的一个重要应用,可以用于人脸识别、物体识别、场景分类等任务。

通常情况下,图像分类的流程如下:

  1. 收集和准备数据集:收集与任务相关的图像数据,并将其打上标签。
  2. 定义模型:选择一种适合于你的任务的深度学习模型,例如卷积神经网络(CNN)。
  3. 训练模型:使用收集到的数据集对模型进行训练,通过反向传播算法来更新模型参数,使其可以根据输入图像进行正确的分类。
  4. 评估模型性能:使用测试集对已经训练好的模型进行评估,比较模型预测结果与真实标签之间的差异,从而评估模型的性能。
  5. 使用模型进行预测:使用已经训练好的模型对新的图像进行分类预测。

在实际应用中,可以使用各种深度学习框架(例如 TensorFlow、PyTorch、Keras 等)来构建图像分类模型,并使用各种数据增强技术(例如旋转、缩放、裁剪等)来增加数据集的多样性和数量。

如果你想学习如何使用深度学习框架来构建图像分类模型,可以参考一些在线教程、书籍或者 MOOC。

一、yolov8图像分类

1.模型选型

下载yolov8分类模型。

分别使用模型进行测试:

yolov8n-cls效果:

yolov8m-cls效果:

总结:n效果不咋地,还是得使用m进行后续训练工作。 

2.数据集准备

皮肤癌检测_数据集-飞桨AI Studio星河社区

同目标检测,还是放在datasets下。

直接改成这个,省去分数据集操作。 

 3.训练

yolo classify train data=./datasets/skin-cancer-detection model=yolov8n-cls.pt epochs=100

测试:

yolo classify predict model=runs/classify/train4/weights/best.pt source='./datasets/skin-cancer-detection/train/nevus'

  

label: 

 pred:

总结:数据集比较小,yolov8效果不太好。

、resnet50图像分类

Resnet50 网络中包含了 49 个卷积层、一个全连接层。如图下图所示,Resnet50网络结构可以分成七个部分,第一部分不包含残差块,主要对输入进行卷积、正则化、激活函数、最大池化的计算。第二、三、四、五部分结构都包含了残差块,图 中的绿色图块不会改变残差块的尺寸,只用于改变残差块的维度。在 Resnet50 网 络 结 构 中 , 残 差 块 都 有 三 层 卷 积 , 那 网 络 总 共 有1+3×(3+4+6+3)=49个卷积层,加上最后的全连接层总共是 50 层,这也是Resnet50 名称的由来。网络的输入为 224×224×3,经过前五部分的卷积计算,输出为 7×7×2048,池化层会将其转化成一个特征向量,最后分类器会对这个特征向量进行计算并输出类别概率。

运行train.py即可。

train.py

import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timeimport numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm# 一、建立数据集
# animals-6
#   --train
#       |--dog
#       |--cat
#       ...
#   --valid
#       |--dog
#       |--cat
#       ...
#   --test
#       |--dog
#       |--cat
#       ...
# 我的数据集中 train 中每个类别60张图片,valid 中每个类别 10 张图片,test 中每个类别几张到几十张不等,一共 6 个类别。# 二、数据增强
# 建好的数据集在输入网络之前先进行数据增强,包括随机 resize 裁剪到 256 x 256,随机旋转,随机水平翻转,中心裁剪到 224 x 224,转化成 Tensor,正规化等。
image_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),transforms.RandomRotation(degrees=15),transforms.RandomHorizontalFlip(),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
}# 三、加载数据
# torchvision.transforms包DataLoader是 Pytorch 重要的特性,它们使得数据增加和加载数据变得非常简单。
# 使用 DataLoader 加载数据的时候就会将之前定义的数据 transform 就会应用的数据上了。
dataset = 'skin-cancer-detection'
train_directory = './skin-cancer-detection/train'
valid_directory = './skin-cancer-detection/val'batch_size = 32
num_classes = 9 #分类种类数
print(train_directory)
data = {'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
}
print("训练集图片类别及其对应编号(种类名:编号):",data['train'].class_to_idx)
print("测试集图片类别及其对应编号:",data['valid'].class_to_idx)train_data_size = len(data['train'])
valid_data_size = len(data['valid'])train_data = DataLoader(data['train'], batch_size=batch_size, shuffle=True, num_workers=0)
valid_data = DataLoader(data['valid'], batch_size=batch_size, shuffle=True, num_workers=0)print("训练集图片数量:",train_data_size, "测试集图片数量:",valid_data_size)# 四、迁移学习
# 这里使用ResNet-50的预训练模型。
#resnet50 = models.resnet50(pretrained=True)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 在PyTorch中加载模型时,所有参数的‘requires_grad’字段默认设置为true。这意味着对参数值的每一次更改都将被存储,以便在用于训练的反向传播图中使用。
# 这增加了内存需求。由于预训练的模型中的大多数参数已经训练好了,因此将requires_grad字段重置为false。
for param in resnet50.parameters():param.requires_grad = False# 为了适应自己的数据集,将ResNet-50的最后一层替换为,将原来最后一个全连接层的输入喂给一个有256个输出单元的线性层,接着再连接ReLU层和Dropout层,然后是256 x 6的线性层,输出为6通道的softmax层。
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)# 用GPU进行训练。
resnet50 = resnet50.to('cuda:0')# 定义损失函数和优化器。
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())# 五、训练
def train_and_valid(model, loss_function, optimizer, epochs=25):device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")history = []best_acc = 0.0best_epoch = 0for epoch in range(epochs):epoch_start = time.time()print("Epoch: {}/{}".format(epoch+1, epochs))model.train()train_loss = 0.0train_acc = 0.0valid_loss = 0.0valid_acc = 0.0for i, (inputs, labels) in enumerate(tqdm(train_data)):inputs = inputs.to(device)labels = labels.to(device)#因为这里梯度是累加的,所以每次记得清零optimizer.zero_grad()outputs = model(inputs)loss = loss_function(outputs, labels)print("标签值:",labels)print("输出值:",outputs)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))train_acc += acc.item() * inputs.size(0)with torch.no_grad():model.eval()for j, (inputs, labels) in enumerate(tqdm(valid_data)):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_function(outputs, labels)valid_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))valid_acc += acc.item() * inputs.size(0)avg_train_loss = train_loss/train_data_sizeavg_train_acc = train_acc/train_data_sizeavg_valid_loss = valid_loss/valid_data_sizeavg_valid_acc = valid_acc/valid_data_sizehistory.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])if best_acc < avg_valid_acc:best_acc = avg_valid_accbest_epoch = epoch + 1epoch_end = time.time()print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_valid_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))torch.save(model, 'models/'+dataset+'_model_'+str(epoch+1)+'.pt')return model, historynum_epochs = 100 #训练周期数
trained_model, history = train_and_valid(resnet50, loss_func, optimizer, num_epochs)
torch.save(history, 'models/'+dataset+'_history.pt')history = np.array(history)
plt.plot(history[:, 0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0, 1)
plt.savefig(dataset+'_loss_curve.png')
plt.show()plt.plot(history[:, 2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.savefig(dataset+'_accuracy_curve.png')
plt.show()

测试:图片名改下即可。

import torch
from torchvision import  models, transforms
import torch.nn as nn
import cv2
classes = ["1","2","3","4","5","6","7","8","9"] #识别种类名称(顺序要与训练时的数据导入编号顺序对应,可以使用datasets.ImageFolder().class_to_idx来查看)transf = transforms.ToTensor()
device = torch.device('cuda:0')
num_classes = 2
model_path = "models/skin-cancer-detection_model_3.pt"
image_input = cv2.imread("ISIC_0000019.jpg")
image_input = transf(image_input)
image_input = torch.unsqueeze(image_input,dim=0).cuda()
#搭建模型
resnet50 = models.resnet50(pretrained=True)
for param in resnet50.parameters():param.requires_grad = Falsefc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)
resnet50 = torch.load(model_path)outputs = resnet50(image_input)
value,id =torch.max(outputs,1)
print(outputs,"\n","结果是:",classes[id])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/585899.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LangChain.js 实战系列:入门介绍

&#x1f4dd; LangChain.js 是一个快速开发大模型应用的框架&#xff0c;它提供了一系列强大的功能和工具&#xff0c;使得开发者能够更加高效地构建复杂的应用程序。LangChain.js 实战系列文章将介绍在实际项目中使用 LangChain.js 时的一些方法和技巧。 LangChain.js 是一个…

环形锻件全自动尺寸测量法兰三维检测自动化设备-CASAIM自动化蓝光检测系统

锻造是一种利用锻压机械对金属坯料施加压力&#xff0c;使其产生塑性变形以获得具有一定机械性能、一定形状和尺寸锻件的加工方法&#xff0c;锻压&#xff08;锻造与冲压&#xff09;的两大组成部分之一。 目前客户使用专用直径千分尺、塞规、塞尺等对锻件形状尺寸误差进行测…

使用 JWT(JSON Web 令牌)实现登录身份验证和令牌续订

文档链接 文档链接, PDF中包含一部分宣传大字制作不易还望多多支持互相交流 使用 JWT&#xff08;JSON Web 令牌&#xff09;实现登录身份验证和令牌续订。它将 JWT 与基于会话的身份验证进行了比较&#xff0c;并强调了每种方法的差异、优点和缺点。本文档介绍了基于会话的方…

2022年2月-2023年12月总结

时间&#xff1a;2023年12月30日 14:28 地点&#xff1a;博库书城 0、序言 2023年上半年除了阿里16N拆分&#xff0c;我就不记得其他的事情了&#xff0c;就好像我的记忆只从今年9月份自己的折腾开始。是不是只有自己认真对待过的日子&#xff0c;才会让人深刻呢&#xff1f;…

T527 Android13遥控适配

T527 Android13遥控的适配和官方提供的文档有些不一样&#xff0c;按照官方的文档不能够正常适配到自己的遥控器。 首先确保驱动是否有打开CONFIG_AW_IR_RX和CONFIG_RC_DECODERSy 以及CONFIG_IR_NEC_DECODERm&#xff0c;这个可以在longan/out/t527对应的目录下的.config查看是…

算法每日一题:购买两块巧克力 | 两个最小值的遍历

大家好&#xff0c;我是星恒 今天的每日一题是寻找一个数组中的两个最小值&#xff0c;看似简单的一道题&#xff0c;其实有不少门道&#xff01; 话不多说&#xff0c;我们直接来看&#xff1a; 题目&#xff1a;给你一个整数数组 prices &#xff0c;它表示一个商店里若干巧克…

数模学习day05-插值算法

插值算法有什么作用呢&#xff1f; 答&#xff1a;数模比赛中&#xff0c;常常需要根据已知的函数点进行数据、模型的处理和分析&#xff0c;而有时候现有的数据是极少的&#xff0c;不足以支撑分析的进行&#xff0c;这时就需要使用一些数学的方法&#xff0c;“模拟产生”一些…

Windows下Qt使用MSVC编译出现需要转为unicode的提示

参考 Qt5中文编码问题解决办法_qt5设置编码-CSDN博客 致敬 提示&#xff1a;warning: C4819: 该文件包含不能在当前代码页(936)中表示的字符。请将该文件保存为 Unicode 格式以防止数据丢失。 出现此问题&#xff0c;应该是Unix格式下代码的编码格式是UTF-8&#xff0c;注意不…

SOLIDWORKS Flow Simulation热环境分析

关于室内通风的问题&#xff0c;其实室内通风方面的与我们之前聊到的数据中心通风散热问题相类似&#xff0c;只不过本次会引入一个新的模块——人体舒适度问题&#xff0c;在Flow Simulation中有一个HVAC模块就是专门用于研究人体舒适度的&#xff0c;它可以预测人们在热环境中…

LeetCode75| 哈希表/哈希集合

目录 2215 找出两数组的不同 1207 独一无二的出现次数 1657 确定两个字符串是否接近 2352 相等行列对 2215 找出两数组的不同 class Solution { public:vector<vector<int>> findDifference(vector<int>& nums1, vector<int>& nums2) {un…

力扣题目学习笔记(OC + Swift)206. 反转链表

206. 反转链表 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 方法一、迭代 在遍历链表时&#xff0c;将当前节点的 next\textit{next}next 指针改为指向前一个节点。由于节点没有引用其前一个节点&#xff0c;因此必须事先存储其…

【Proteus仿真】【51单片机】自动除湿器系统

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器&#xff0c;使用按键、LCD1602液晶、DHT11温湿度、继电器除湿模块等。 主要功能&#xff1a; 系统运行后&#xff0c;LCD1602显示DHT11传感器检测的湿度值阈值&am…

面试数据库八股文五问五答第三期

面试数据库八股文五问五答第三期 作者&#xff1a;程序员小白条&#xff0c;个人博客 相信看了本文后&#xff0c;对你的面试是有一定帮助的&#xff01; ⭐点赞⭐收藏⭐不迷路&#xff01;⭐ 1&#xff09;MyIsAm 和 InnoDB 的区别 事务支持&#xff1a;MyISAM不支持事务&a…

开源radishes高仿网易云音乐完整源码,可试听和下载“灰色”歌曲,跨平台的无版权音乐平台

源码介绍 Radishes是项目名称&#xff0c;是由萝卜翻译而来。可以在这里试听和下载“灰色”歌曲&#xff0c;是一个可以跨平台的无版权音乐平台。 萝卜音乐界面和功能参考 windows 网易云音乐界面和 ios 的网易云音乐 安装依赖 cd radishes/ yarn bootstrap 运行项目 web:…

docker-compose Install TeamCity

前言 TeamCity 是一个通用的 CI/CD 软件平台,可实现灵活的工作流程、协作和开发实践。允许在您的 DevOps 流程中成功实现持续集成、持续交付和持续部署。 系统支持 docker download TeamCity TeamCity 文档参考项目离线包百度网盘获取

PPT录制视频的方法,轻松提升演示效果!

在现代工作和学习中&#xff0c;ppt是一种常见的演示工具&#xff0c;而将ppt转化为视频形式更能方便分享和传播。本文将介绍两种ppt录制视频的方法&#xff0c;每一种方法都将有详细的步骤和简要的介绍&#xff0c;通过这些方法&#xff0c;你可以轻松将ppt制作成视频&#xf…

Servlet入门

目录 1.Servlet介绍 1.1什么是Servlet 1.2Servlet的使用方法 1.3Servlet接口的继承结构 2.Servlet快速入门 2.1创建javaweb项目 2.1.1创建maven工程 2.1.2添加webapp目录 2.2添加依赖 2.3创建servlet实例 2.4配置servlet 2.5设置打包方式 2.6部署web项目 3.servl…

SSH -L:安全、便捷、无边界的网络通行证

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 SSH -L&#xff1a;安全、便捷、无边界的网络通行证 前言1. SSH -L基础概念SSH -L 的基本语法&#xff1a;端口转发的原理和作用&#xff1a; 2. SSH -L的基本用法远程访问本地示例&#xff1a;访问本…

Linux 账号的权限管理

文章目录 Linux 账号的权限管理一、用户账号和组账号的概述1、Linux基于用户身份对资源访问进行控制1.1 用户账号1.2 组账号1.3 UID和GID系统如何区别用户身份 二、用户账号文件1、用户账号文件/etc/passwd保存用户名称&#xff0c;宿主目录、登录shell等基本信息 2、用户账号文…

【北亚服务器数据恢复】san环境下LUN Mapping出错导致文件系统一致性出错的数据恢复案例

服务器数据恢复环境&#xff1a; san环境下的存储上一组由6块硬盘组建的RAID6&#xff0c;划分为若干LUN&#xff0c;MAP到跑不同业务的服务器上&#xff0c;服务器上层是SOLARIS操作系统UFS文件系统。 服务器故障&#xff1a; 业务需求需要增加一台服务器跑新增的应用&#xf…