yolov8实战第四天——yolov8图像分类 ResNet50图像分类(保姆式教程)

yolov8实战第一天——yolov8部署并训练自己的数据集(保姆式教程)_yolov8训练自己的数据集-CSDN博客在前几天,我们使用yolov8进行了部署,并在目标检测方向上进行自己数据集的训练与测试,今天我们训练下yolov8的图像分类,看看效果如何,同时使用resnet50也训练一个分类模型,看看哪个效果好!

图像分类是指将输入的图像自动分类为不同的类别。它是计算机视觉领域的一个重要应用,可以用于人脸识别、物体识别、场景分类等任务。

通常情况下,图像分类的流程如下:

  1. 收集和准备数据集:收集与任务相关的图像数据,并将其打上标签。
  2. 定义模型:选择一种适合于你的任务的深度学习模型,例如卷积神经网络(CNN)。
  3. 训练模型:使用收集到的数据集对模型进行训练,通过反向传播算法来更新模型参数,使其可以根据输入图像进行正确的分类。
  4. 评估模型性能:使用测试集对已经训练好的模型进行评估,比较模型预测结果与真实标签之间的差异,从而评估模型的性能。
  5. 使用模型进行预测:使用已经训练好的模型对新的图像进行分类预测。

在实际应用中,可以使用各种深度学习框架(例如 TensorFlow、PyTorch、Keras 等)来构建图像分类模型,并使用各种数据增强技术(例如旋转、缩放、裁剪等)来增加数据集的多样性和数量。

如果你想学习如何使用深度学习框架来构建图像分类模型,可以参考一些在线教程、书籍或者 MOOC。

一、yolov8图像分类

1.模型选型

下载yolov8分类模型。

分别使用模型进行测试:

yolov8n-cls效果:

yolov8m-cls效果:

总结:n效果不咋地,还是得使用m进行后续训练工作。 

2.数据集准备

皮肤癌检测_数据集-飞桨AI Studio星河社区

同目标检测,还是放在datasets下。

直接改成这个,省去分数据集操作。 

 3.训练

yolo classify train data=./datasets/skin-cancer-detection model=yolov8n-cls.pt epochs=100

测试:

yolo classify predict model=runs/classify/train4/weights/best.pt source='./datasets/skin-cancer-detection/train/nevus'

  

label: 

 pred:

总结:数据集比较小,yolov8效果不太好。

、resnet50图像分类

Resnet50 网络中包含了 49 个卷积层、一个全连接层。如图下图所示,Resnet50网络结构可以分成七个部分,第一部分不包含残差块,主要对输入进行卷积、正则化、激活函数、最大池化的计算。第二、三、四、五部分结构都包含了残差块,图 中的绿色图块不会改变残差块的尺寸,只用于改变残差块的维度。在 Resnet50 网 络 结 构 中 , 残 差 块 都 有 三 层 卷 积 , 那 网 络 总 共 有1+3×(3+4+6+3)=49个卷积层,加上最后的全连接层总共是 50 层,这也是Resnet50 名称的由来。网络的输入为 224×224×3,经过前五部分的卷积计算,输出为 7×7×2048,池化层会将其转化成一个特征向量,最后分类器会对这个特征向量进行计算并输出类别概率。

运行train.py即可。

train.py

import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timeimport numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm# 一、建立数据集
# animals-6
#   --train
#       |--dog
#       |--cat
#       ...
#   --valid
#       |--dog
#       |--cat
#       ...
#   --test
#       |--dog
#       |--cat
#       ...
# 我的数据集中 train 中每个类别60张图片,valid 中每个类别 10 张图片,test 中每个类别几张到几十张不等,一共 6 个类别。# 二、数据增强
# 建好的数据集在输入网络之前先进行数据增强,包括随机 resize 裁剪到 256 x 256,随机旋转,随机水平翻转,中心裁剪到 224 x 224,转化成 Tensor,正规化等。
image_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),transforms.RandomRotation(degrees=15),transforms.RandomHorizontalFlip(),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
}# 三、加载数据
# torchvision.transforms包DataLoader是 Pytorch 重要的特性,它们使得数据增加和加载数据变得非常简单。
# 使用 DataLoader 加载数据的时候就会将之前定义的数据 transform 就会应用的数据上了。
dataset = 'skin-cancer-detection'
train_directory = './skin-cancer-detection/train'
valid_directory = './skin-cancer-detection/val'batch_size = 32
num_classes = 9 #分类种类数
print(train_directory)
data = {'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
}
print("训练集图片类别及其对应编号(种类名:编号):",data['train'].class_to_idx)
print("测试集图片类别及其对应编号:",data['valid'].class_to_idx)train_data_size = len(data['train'])
valid_data_size = len(data['valid'])train_data = DataLoader(data['train'], batch_size=batch_size, shuffle=True, num_workers=0)
valid_data = DataLoader(data['valid'], batch_size=batch_size, shuffle=True, num_workers=0)print("训练集图片数量:",train_data_size, "测试集图片数量:",valid_data_size)# 四、迁移学习
# 这里使用ResNet-50的预训练模型。
#resnet50 = models.resnet50(pretrained=True)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 在PyTorch中加载模型时,所有参数的‘requires_grad’字段默认设置为true。这意味着对参数值的每一次更改都将被存储,以便在用于训练的反向传播图中使用。
# 这增加了内存需求。由于预训练的模型中的大多数参数已经训练好了,因此将requires_grad字段重置为false。
for param in resnet50.parameters():param.requires_grad = False# 为了适应自己的数据集,将ResNet-50的最后一层替换为,将原来最后一个全连接层的输入喂给一个有256个输出单元的线性层,接着再连接ReLU层和Dropout层,然后是256 x 6的线性层,输出为6通道的softmax层。
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)# 用GPU进行训练。
resnet50 = resnet50.to('cuda:0')# 定义损失函数和优化器。
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())# 五、训练
def train_and_valid(model, loss_function, optimizer, epochs=25):device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")history = []best_acc = 0.0best_epoch = 0for epoch in range(epochs):epoch_start = time.time()print("Epoch: {}/{}".format(epoch+1, epochs))model.train()train_loss = 0.0train_acc = 0.0valid_loss = 0.0valid_acc = 0.0for i, (inputs, labels) in enumerate(tqdm(train_data)):inputs = inputs.to(device)labels = labels.to(device)#因为这里梯度是累加的,所以每次记得清零optimizer.zero_grad()outputs = model(inputs)loss = loss_function(outputs, labels)print("标签值:",labels)print("输出值:",outputs)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))train_acc += acc.item() * inputs.size(0)with torch.no_grad():model.eval()for j, (inputs, labels) in enumerate(tqdm(valid_data)):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_function(outputs, labels)valid_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))valid_acc += acc.item() * inputs.size(0)avg_train_loss = train_loss/train_data_sizeavg_train_acc = train_acc/train_data_sizeavg_valid_loss = valid_loss/valid_data_sizeavg_valid_acc = valid_acc/valid_data_sizehistory.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])if best_acc < avg_valid_acc:best_acc = avg_valid_accbest_epoch = epoch + 1epoch_end = time.time()print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_valid_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))torch.save(model, 'models/'+dataset+'_model_'+str(epoch+1)+'.pt')return model, historynum_epochs = 100 #训练周期数
trained_model, history = train_and_valid(resnet50, loss_func, optimizer, num_epochs)
torch.save(history, 'models/'+dataset+'_history.pt')history = np.array(history)
plt.plot(history[:, 0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0, 1)
plt.savefig(dataset+'_loss_curve.png')
plt.show()plt.plot(history[:, 2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.savefig(dataset+'_accuracy_curve.png')
plt.show()

测试:图片名改下即可。

import torch
from torchvision import  models, transforms
import torch.nn as nn
import cv2
classes = ["1","2","3","4","5","6","7","8","9"] #识别种类名称(顺序要与训练时的数据导入编号顺序对应,可以使用datasets.ImageFolder().class_to_idx来查看)transf = transforms.ToTensor()
device = torch.device('cuda:0')
num_classes = 2
model_path = "models/skin-cancer-detection_model_3.pt"
image_input = cv2.imread("ISIC_0000019.jpg")
image_input = transf(image_input)
image_input = torch.unsqueeze(image_input,dim=0).cuda()
#搭建模型
resnet50 = models.resnet50(pretrained=True)
for param in resnet50.parameters():param.requires_grad = Falsefc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)
resnet50 = torch.load(model_path)outputs = resnet50(image_input)
value,id =torch.max(outputs,1)
print(outputs,"\n","结果是:",classes[id])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/585899.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

环形锻件全自动尺寸测量法兰三维检测自动化设备-CASAIM自动化蓝光检测系统

锻造是一种利用锻压机械对金属坯料施加压力&#xff0c;使其产生塑性变形以获得具有一定机械性能、一定形状和尺寸锻件的加工方法&#xff0c;锻压&#xff08;锻造与冲压&#xff09;的两大组成部分之一。 目前客户使用专用直径千分尺、塞规、塞尺等对锻件形状尺寸误差进行测…

数模学习day05-插值算法

插值算法有什么作用呢&#xff1f; 答&#xff1a;数模比赛中&#xff0c;常常需要根据已知的函数点进行数据、模型的处理和分析&#xff0c;而有时候现有的数据是极少的&#xff0c;不足以支撑分析的进行&#xff0c;这时就需要使用一些数学的方法&#xff0c;“模拟产生”一些…

SOLIDWORKS Flow Simulation热环境分析

关于室内通风的问题&#xff0c;其实室内通风方面的与我们之前聊到的数据中心通风散热问题相类似&#xff0c;只不过本次会引入一个新的模块——人体舒适度问题&#xff0c;在Flow Simulation中有一个HVAC模块就是专门用于研究人体舒适度的&#xff0c;它可以预测人们在热环境中…

力扣题目学习笔记(OC + Swift)206. 反转链表

206. 反转链表 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 方法一、迭代 在遍历链表时&#xff0c;将当前节点的 next\textit{next}next 指针改为指向前一个节点。由于节点没有引用其前一个节点&#xff0c;因此必须事先存储其…

【Proteus仿真】【51单片机】自动除湿器系统

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器&#xff0c;使用按键、LCD1602液晶、DHT11温湿度、继电器除湿模块等。 主要功能&#xff1a; 系统运行后&#xff0c;LCD1602显示DHT11传感器检测的湿度值阈值&am…

开源radishes高仿网易云音乐完整源码,可试听和下载“灰色”歌曲,跨平台的无版权音乐平台

源码介绍 Radishes是项目名称&#xff0c;是由萝卜翻译而来。可以在这里试听和下载“灰色”歌曲&#xff0c;是一个可以跨平台的无版权音乐平台。 萝卜音乐界面和功能参考 windows 网易云音乐界面和 ios 的网易云音乐 安装依赖 cd radishes/ yarn bootstrap 运行项目 web:…

docker-compose Install TeamCity

前言 TeamCity 是一个通用的 CI/CD 软件平台,可实现灵活的工作流程、协作和开发实践。允许在您的 DevOps 流程中成功实现持续集成、持续交付和持续部署。 系统支持 docker download TeamCity TeamCity 文档参考项目离线包百度网盘获取

PPT录制视频的方法,轻松提升演示效果!

在现代工作和学习中&#xff0c;ppt是一种常见的演示工具&#xff0c;而将ppt转化为视频形式更能方便分享和传播。本文将介绍两种ppt录制视频的方法&#xff0c;每一种方法都将有详细的步骤和简要的介绍&#xff0c;通过这些方法&#xff0c;你可以轻松将ppt制作成视频&#xf…

Servlet入门

目录 1.Servlet介绍 1.1什么是Servlet 1.2Servlet的使用方法 1.3Servlet接口的继承结构 2.Servlet快速入门 2.1创建javaweb项目 2.1.1创建maven工程 2.1.2添加webapp目录 2.2添加依赖 2.3创建servlet实例 2.4配置servlet 2.5设置打包方式 2.6部署web项目 3.servl…

SSH -L:安全、便捷、无边界的网络通行证

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 SSH -L&#xff1a;安全、便捷、无边界的网络通行证 前言1. SSH -L基础概念SSH -L 的基本语法&#xff1a;端口转发的原理和作用&#xff1a; 2. SSH -L的基本用法远程访问本地示例&#xff1a;访问本…

Linux 账号的权限管理

文章目录 Linux 账号的权限管理一、用户账号和组账号的概述1、Linux基于用户身份对资源访问进行控制1.1 用户账号1.2 组账号1.3 UID和GID系统如何区别用户身份 二、用户账号文件1、用户账号文件/etc/passwd保存用户名称&#xff0c;宿主目录、登录shell等基本信息 2、用户账号文…

【北亚服务器数据恢复】san环境下LUN Mapping出错导致文件系统一致性出错的数据恢复案例

服务器数据恢复环境&#xff1a; san环境下的存储上一组由6块硬盘组建的RAID6&#xff0c;划分为若干LUN&#xff0c;MAP到跑不同业务的服务器上&#xff0c;服务器上层是SOLARIS操作系统UFS文件系统。 服务器故障&#xff1a; 业务需求需要增加一台服务器跑新增的应用&#xf…

java注解和反射

java注解和反射 内置注解 Override 重写生命 Deprecated 已过时的方法&#xff0c;不推荐使用&#xff0c;可以使用 SuppressWarning 镇压警告&#xff0c;懂的都懂 元注解 作用&#xff1a;负责注解其他的注解 Target 描述注解的使用范围 Retention 描述注解的生命周期 Docu…

蚂蚁实习一面面经

蚂蚁实习一面面经 希望可以帮助到大家 tcp建立连接为什么要三次握手&#xff1f; 三次握手的过程 注意&#xff1a;三次握手的最主要目的是保证连接是双工的&#xff0c;可靠更多的是通过重传机制来保证的 所谓三次握手&#xff0c;即建立TCP连接&#xff0c;需要客户端和…

2023年高级软考系统架构师考题参考

对于一些有实践经验的同学来说&#xff0c;感觉不难&#xff0c;但是落笔到纸面上&#xff0c;就差强人意了&#xff0c;平时这方面要多练习&#xff0c;所想所思要落到纸面上&#xff0c;或者表达清晰让别人听懂&#xff0c;不仅是工作中的一个基本素质&#xff0c;也是个非常…

【数学建模美赛M奖速成系列】Matplotlib绘图技巧(三)

Matplotlib绘图技巧&#xff08;三&#xff09; 写在前面7. 雷达图7.1 圆形雷达图7.2 多边形雷达图 8. 极坐标图 subplot9. 折线图 plot10. 灰度图 meshgrid11. 热力图11.1 自定义colormap 12. 箱线图 boxplot 写在前面 终于更新完Matplotlib绘图技巧的全部内容&#xff0c;有…

web漏洞与修复

一、web漏洞 检测到目标X-Content-Type-Options响应头缺失 详细描述X-Content-Type-Options HTTP 消息头相当于一个提示标志&#xff0c;被服务器用来提示客户端一定要遵循在 Content-Type 首部中对 MIME 类型 的设定&#xff0c;而不能对其进行修改。这就禁用了客户端的 MIM…

跨境电商:让中国制造走向世界

跨境电商的崛起 跨境电商是指不同国家和地区之间的商业交易&#xff0c;通过互联网和物流等方式完成。随着全球化和互联网的普及&#xff0c;跨境电商迅速崛起&#xff0c;成为全球贸易的重要组成部分。中国作为全球最大的制造业国家&#xff0c;拥有着丰富的商品资源和供应链…

链表总结(2)

theme: fancy 又是链表专题啦&#xff0c;老样子&#xff0c;标题就是leetcode链接&#xff0c;在这里只放我的代码答案和注释 141环形链表 public class Solution {public boolean hasCycle(ListNode head) {if(head null || head.next null) return false;if(head.nex…

详解FreeRTOS:FreeRTOSConfig.h系统配置文件(拓展篇—1)

目录 1、“INCLUDE_”宏 2、“config”宏 实际使用FreeRTOS的时候,时常需要根据自己需求来配置 FreeRTOS,不同架构的MCU,配置也不同。 FreeRTOS的系统配置文件为FreeRTOSConfig.h,在配置文件中可以完成FreeRTOS的裁剪和配置,这是非常重要的一个文件,本篇博文就来讲解这…