YOLOv5 分类模型 数据集加载 3

YOLOv5 分类模型 数据集加载 3 自定义类别

flyfish

YOLOv5 分类模型 数据集加载 1 样本处理
YOLOv5 分类模型 数据集加载 2 切片处理
YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize
YOLOv5 分类模型 Top 1和Top 5 指标说明
YOLOv5 分类模型 Top 1和Top 5 指标实现

之前的处理方式是类别名字是文件夹名字,类别ID是按照文件夹名字的字母顺序
现在是类别名字是文件夹名字,按照文件列表名字顺序 例如

classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 
'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']

n02086240 类别ID是0
n02087394 类别ID是1
代码处理是

if classes_name is None or not classes_name:classes, class_to_idx = self.find_classes(self.root)print("not classes_name")else:classes = classes_nameclass_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}print("is classes_name")

完整

import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as npimport torch
from PIL import Image
import torchvision.transforms as transformsimport sysclasses_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']class DatasetFolder:def __init__(self,root: str,) -> None:self.root = rootif classes_name is None or not classes_name:classes, class_to_idx = self.find_classes(self.root)print("not classes_name")else:classes = classes_nameclass_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}print("is classes_name")print("classes:",classes)print("class_to_idx:",class_to_idx)samples = self.make_dataset(self.root, class_to_idx)self.classes = classesself.class_to_idx = class_to_idxself.samples = samplesself.targets = [s[1] for s in samples]@staticmethoddef make_dataset(directory: str,class_to_idx: Optional[Dict[str, int]] = None,) -> List[Tuple[str, int]]:directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = self.find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):path = os.path.join(root, fname)if 1:  # 验证:item = path, class_indexinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "return instancesdef find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef __getitem__(self, index: int) -> Tuple[Any, Any]:path, target = self.samples[index]sample = self.loader(path)return sample, targetdef __len__(self) -> int:return len(self.samples)def loader(self, path):print("path:", path)#img = cv2.imread(path)  # BGR HWCimg=Image.open(path).convert("RGB") # RGB HWCreturn imgdef time_sync():return time.time()#sys.exit() 
dataset = DatasetFolder(root="/media/a/flyfish/source/yolov5/datasets/imagewoof/val")#image, label=dataset[7]#
weights = "/home/a/classes.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()
print(model.names)
print(type(model.names))mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
def preprocess(images):#实现 PyTorch Resizetarget_size =224img_w = images.widthimg_h = images.heightif(img_h >= img_w):# hwresize_img = images.resize((target_size, int(target_size * img_h / img_w)), Image.BILINEAR)else:resize_img = images.resize((int(target_size * img_w  / img_h),target_size), Image.BILINEAR)#实现 PyTorch CenterCropwidth = resize_img.widthheight = resize_img.heightcenter_x,center_y = width//2,height//2left = center_x - (target_size//2)top = center_y- (target_size//2)right =center_x +target_size//2bottom = center_y+target_size//2cropped_img = resize_img.crop((left, top, right, bottom))#实现 PyTorch ToTensor Normalizeimages = np.asarray(cropped_img)print("preprocess:",images.shape)images = images.astype('float32')images = (images/255-mean)/stdimages = images.transpose((2, 0, 1))# HWC to CHWprint("preprocess:",images.shape)images = np.ascontiguousarray(images)images=torch.from_numpy(images)#images = images.unsqueeze(dim=0).float()return imagespred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):print("i:", i)im = preprocess(images)images = im.unsqueeze(0).to("cpu").float()print(images.shape)t1 = time_sync()images = images.to(device, non_blocking=True)t2 = time_sync()# dt[0] += t2 - t1y = model(images)y=y.numpy()#print("y:", y)t3 = time_sync()# dt[1] += t3 - t2#batch size >1 图像推理结果是二维的#y: [[     4.0855     -1.1707     -1.4998      -0.935     -1.9979      -2.258     -1.4691     -1.0867     -1.9042    -0.99979]]tmp1=y.argsort()[:,::-1][:, :5]#batch size =1 图像推理结果是一维的, 就要处理下argsort的维度#y: [     3.7441      -1.135     -1.1293     -0.9422     -1.6029     -2.0561      -1.025     -1.5842     -1.3952     -1.1824]#print("tmp1:", tmp1)pred.append(tmp1)#print("labels:", labels)targets.append(labels)#print("for pred:", pred)  # list#print("for targets:", targets)  # list# dt[2] += time_sync() - t3pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/161238.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

scrapy自定义日志

自定义日志系统 首先,在Scrapy的settings.py文件中添加以下代码: LOG_LEVEL DEBUG # 日志级别 LOG_FILE /path/to/logfile.log # 日志文件路径 LOG_ENABLED True # 是否启用日志 LOG_STDOUT False # 是否输出到标准输出这些设置将指定Scrapy日…

【PHP】PHP生成全年日历

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…

5-8输出水仙花数

#include<stdio.h> int main(){int i,j,k;int n;for(n100;n<1000;n){in/100;jn/10-i*10;kn%10;if(ni*i*ij*j*jk*k*k)printf("%d ",n);}printf("\n");return 0; }

Dubbo从入门到上天系列第十八篇:Dubbo引入注册中心简介以及DubboAdmin简要介绍,为后续详解Dubbo各种注册中心做铺垫!

一&#xff1a;Dubbo注册中心引言 1&#xff1a;什么是Dubbo的注册中心&#xff1f; Dubbo注册中心是Dubbo服务治理中极其重要的一个概念。它主要是用于对Rpc集群应用实例进行管理。 对于我们的Dubbo服务来讲&#xff0c;至少有两部分构成&#xff0c;一部分是Provider一部分是…

uniapp开发小程序-如何判断小程序是在手机端还是pc端打开

官方说明 https://developers.weixin.qq.com/miniprogram/dev/devtools/pc-dev.html 小程序如何判断是 PC 平台&#xff1f; 通过 getSystemInfo 官方接口&#xff08;platform 是 windows&#xff09; 通过 UA&#xff08;PC UA 包含 MiniProgramEnv/Windows&#xff09; …

section header

section header table 是一个section header的集合&#xff0c;每个section header是一个描述section的结构体。在同一个ELF文件中&#xff0c;每个section header大小是相同的。 每个section都有一个section header描述它&#xff0c;但是一个section header可能在文件中没有…

云计算实验如何结合AI来提高效率!

随着AI助手的流行&#xff0c;我们现在无论是学习还是工作都会带着一个他/她&#xff0c;如何让AI助手提高我们的工作效率是我们需要进化的方向。下面结合“云计算实验”来分享一下如何让AI帮助我们学得更快学得更好。 一、学习某个软件或复杂命令 比如在学习RockyLinux9.2中…

Android Spannable 使用​注意事项

1、当前示例中间的 "评论"&#xff0c;使用SpannableStringBuilder实现&#xff0c;点击评论会有高亮效果加粗&#xff0c;但再点击其它Bar时无法恢复默认样式。 2、因为SpannableString或SpannableStringBuilder中的效果是叠加的&#xff0c;恢复默认样式需要先移除…

【Qt-25】控件篇

一、comboBox控件 1、获取item数量 ui->comboBox_2->count(); 2、根据索引值获取文本 ui->comboBox->itemText(i); 3、调整当前显示文本内容 ui->comboBox->setCurrentIndex(j); 4、添加item ui->comboBox->addItem("");//添加一个内…

基于SSM的济南旅游网站设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

汽车级全保护型六路半桥驱动器NCV7708FDWR2G 原理、参数及应用

NCV7708FDWR2G 是一款全保护型六路半桥驱动器&#xff0c;特别适用于汽车和工业运动控制应用。六个高压侧和低压侧驱动器可自由配置&#xff0c;也可单独控制。因此可实现高压侧、低压侧和 H 桥控制。H 桥控制提供正向、逆向、制动和高阻抗状态。驱动器通过标准 SPI 接口进行控…

python 基于gdal,richdem,pysheds实现 实现洼填、D8流向,汇流累计量计算,河网连接,分水岭及其水文分析与斜坡单元生成

python gdal实现水文分析算法及其斜坡单元生成 实现洼填、D8流向,汇流累计量计算,河网连接,分水岭 # utf-8 import richdem as rdre from River import * from pysheds.grid import Grid import time from time import time,sleep import numpy as np from osgeo import g…

【Pytorch】Visualization of Fature Maps(2)

学习参考来自 使用CNN在MNIST上实现简单的攻击样本https://github.com/wmn7/ML_Practice/blob/master/2019_06_03/CNN_MNIST%E5%8F%AF%E8%A7%86%E5%8C%96.ipynb 文章目录 在 MNIST 上实现简单的攻击样本1 训练一个数字分类网络2 控制输出的概率, 看输入是什么3 让正确的图片分…

分类预测 | Matlab实现基于DBN-SVM深度置信网络-支持向量机的数据分类预测

分类预测 | Matlab实现基于DBN-SVM深度置信网络-支持向量机的数据分类预测 目录 分类预测 | Matlab实现基于DBN-SVM深度置信网络-支持向量机的数据分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.利用DBN进行特征提取&#xff0c;将提取后的特征放入SVM进行分类…

vue中 多个请求,如果一个请出错,页面继续执行

vue中 多个请求&#xff0c;如果一个请出错&#xff0c;页面继续执行 在Vue中&#xff0c;可以通过Promise.all()方法来处理多个请求&#xff0c;即使其中一个请求出错&#xff0c;页面也可以继续执行其他的逻辑。 下面是一个示例代码&#xff0c;演示了如何在Vue中处理多个请…

Cookie与Session

文章目录 Cookie的介绍Cookie的由来什么是CookieCookie原理Cookie覆盖浏览器查看Cookie 在Django中操作Cookie设置Cookie查询浏览器携带的Cookie删除Cookie Cookie校验登录session Cookie的介绍 Cookie的由来 首先我们都应该明白HTTP协议是无连接的。 无状态的意思是每次请求…

CSS特效016:天窗扬起合上的效果

CSS常用示例100专栏目录 本专栏记录的是经常使用的CSS示例与技巧&#xff0c;主要包含CSS布局&#xff0c;CSS特效&#xff0c;CSS花边信息三部分内容。其中CSS布局主要是列出一些常用的CSS布局信息点&#xff0c;CSS特效主要是一些动画示例&#xff0c;CSS花边是描述了一些CSS…

【c++Leetcode】206. Reverse Linked List

问题入口 time complexity: O(n), space complexity:O(1) ListNode* reverseList(ListNode* head) {ListNode* prev nullptr;ListNode* curr head;while(curr){ListNode* forward curr->next;curr->next prev;prev curr;curr forward;}return prev; } time comp…

虹科Pico汽车示波器 | 汽车免拆检修 | 2017款东风本田XR-V车转向助力左右不一致

一、故障现象 一辆2017款东风本田XR-V车&#xff0c;搭载R18ZA发动机&#xff0c;累计行驶里程约为4万km。车主反映&#xff0c;车辆行驶或静止时&#xff0c;向右侧转向比向左侧转向沉重。 二、故障诊断 接车后试车&#xff0c;起动发动机&#xff0c;组合仪表上无故障灯点亮&…

数据仓库岗面试

1.自我介绍 2.求用户连续登录3天&#xff0c;要讲出多种解法 解法1&#xff08;使用SQL&#xff09;&#xff1a; SELECTuserid FROMloginrecord WHEREDATEDIFF(day, time, LAG(time) OVER (PARTITION BY userid ORDER BY time)) 1AND DATEDIFF(day, LAG(time) OVER (PARTI…