YOLOv5 分类模型 Top 1和Top 5 指标实现

YOLOv5 分类模型 Top 1和Top 5 指标实现

flyfish

import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as npimport torch
from utils.augmentations import classify_transformsclass DatasetFolder:def __init__(self,root: str,) -> None:self.root = rootclasses, class_to_idx = self.find_classes(self.root)samples = self.make_dataset(self.root, class_to_idx)self.classes = classesself.class_to_idx = class_to_idxself.samples = samplesself.targets = [s[1] for s in samples]@staticmethoddef make_dataset(directory: str,class_to_idx: Optional[Dict[str, int]] = None,) -> List[Tuple[str, int]]:directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = self.find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):path = os.path.join(root, fname)if 1:  # 验证:item = path, class_indexinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "return instancesdef find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef __getitem__(self, index: int) -> Tuple[Any, Any]:path, target = self.samples[index]sample = self.loader(path)return sample, targetdef __len__(self) -> int:return len(self.samples)def loader(self, path):print("path:", path)img = cv2.imread(path)  # BGR HWCreturn imgdef time_sync():return time.time()dataset = DatasetFolder(root="/media/flyfish/test/val")# image, label=dataset[7]
# print(image.shape)
#
weights = "/media/flyfish/yolov5-6.2/classes10.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()transforms = classify_transforms(224)pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):print("i:", i)print(images.shape, labels)im = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)im = transforms(im)images = im.unsqueeze(0).to("cpu")print(images.shape)t1 = time_sync()images = images.to(device, non_blocking=True)t2 = time_sync()# dt[0] += t2 - t1y = model(images)y=y.numpy()print("y:", y)t3 = time_sync()# dt[1] += t3 - t2tmp1=y.argsort()[:,::-1][:, :5]print("tmp1:", tmp1)pred.append(tmp1)print("labels:", labels)targets.append(labels)print("for pred:", pred)  # listprint("for targets:", targets)  # list# dt[2] += time_sync() - t3pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

输出

pred: [[7 4 0 5 9][9 2 4 6 7][8 9 6 2 1][8 9 6 2 7][9 2 4 6 3][6 7 1 2 9][4 2 1 8 9][6 8 9 5 2][8 7 4 2 6][9 8 2 6 4][2 9 8 0 6][7 4 8 6 3]]
pred: (12, 5)
targets: [0 0 0 0 1 1 1 1 2 2 2 2]
targets: (12,)
correct: (12, 5)
correct: [[          0           0           1           0           0][          0           0           0           0           0][          0           0           0           0           0][          0           0           0           0           0][          0           0           0           0           0][          0           0           1           0           0][          0           0           1           0           0][          0           0           0           0           0][          0           0           0           1           0][          0           0           1           0           0][          1           0           0           0           0][          0           0           0           0           0]]
acc: (12, 2)
acc: [[          0           1][          0           0][          0           0][          0           0][          0           0][          0           1][          0           1][          0           0][          0           1][          0           1][          1           1][          0           0]]
top1: 0.083333336
top5: 0.5

Yolov5 6.2 原版输出

pred: tensor([[6, 7, 1, 2, 9],[9, 2, 4, 6, 3],[7, 4, 0, 5, 9],[9, 8, 2, 6, 4],[6, 8, 9, 5, 2],[8, 7, 4, 2, 6],[9, 2, 4, 6, 7],[2, 9, 8, 0, 6],[8, 9, 6, 2, 7],[7, 4, 8, 6, 3],[4, 2, 1, 8, 9],[8, 9, 6, 2, 1]])
pred: torch.Size([12, 5])
targets: tensor([1, 1, 0, 2, 1, 2, 0, 2, 0, 2, 1, 0])
targets: torch.Size([12])
correct: torch.Size([12, 5])
acc: torch.Size([12, 2])
top1: 0.0833333358168602
top5: 0.5

文本代码是按照标签,即文件夹名字排序的,pred和target都是一一对应的,与Yolov5 6.2 原版相同

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/145530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

集合对象的几种初始化方式

最简单的方式是通过new构建一个对象&#xff0c;构建对象后再进行赋值&#xff1a; // 通过new创建并赋值 List<Integer> list new ArrayList<>(); list.add(1); list.add(2); list.add(3);Set<Integer> set new HashSet<>(); set.add(1); set.add(…

小米路由器4A千兆版刷入OpenWRT并远程访问

小米路由器4A千兆版刷入OpenWRT并远程访问 文章目录 小米路由器4A千兆版刷入OpenWRT并远程访问前言1. 安装Python和需要的库2. 使用 OpenWRTInvasion 破解路由器3. 备份当前分区并刷入新的Breed4. 安装cpolar内网穿透4.1 注册账号4.2 下载cpolar客户端4.3 登录cpolar web ui管理…

SpringBoot-配置文件properties/yml分析+tomcat最大连接数及最大并发数

SpringBoot配置文件 yaml 中的数据是有序的&#xff0c;properties 中的数据是无序的&#xff0c;在一些需要路径匹配的配置中&#xff0c;顺序就显得尤为重要&#xff08;例如在 Spring Cloud Zuul 中的配置&#xff09;&#xff0c;此时一般采用 yaml。 Properties ①、位…

0基础如何学习软件测试?10分钟给你安排明白

先上一张学习路线&#xff1a; 在测试行业已经呆了5年多了&#xff0c;也算得上行业经验资深了吧&#xff0c;基本上也是摸清了这个行业的发展。 所以今天也想对有转行想法的朋友分享一下经验&#xff0c;能够让你对这个行业有个大致的了解和对以后的发展有所规划&#xff0c;…

基础工具类

IDate import java.text.SimpleDateFormat import java.util.{Calendar, Date}trait IDate extends Serializable {def onDate(dt: String): Unitdef invoke(dt: String, dt1: String) {if (dt1 < dt) {throw new IllegalArgumentException(s"dt1:${dt1}小于dt:${dt}…

Java设计模式-创建型模式-原型模式

原型模式 原型模式浅拷贝深拷贝 原型模式 要求&#xff1a;以一个已经创建的对象为原型&#xff0c;复制一个新的对象 使用场景&#xff1a; 创建对象的成本比较大的时候&#xff08;如从耗时较长的计算或者从查询耗时长的RPC接口获取数据&#xff09;&#xff0c;直接拷贝已…

双向链表的知识点+例题

1.链表的种类 题中常考查以下两种&#xff1a; 上一讲我们学了无头单向非循环链表&#xff0c;这节&#xff0c;让我们看一下双向链表的操作吧~ 2基本操作 1&#xff0c;定义双向链表 2&#xff0c;创建一个节点 3&#xff0c;初始化双链表 4&#xff0c;尾插一个节点 5打印…

全球温度数据下载

1.全球年平均温度下载https://www.ncei.noaa.gov/data/global-summary-of-the-year/archive/ 2.全球月平均气温下载https://www.ncei.noaa.gov/data/global-summary-of-the-month/archive/ 3.全球日平均气温下载https://www.ncei.noaa.gov/data/global-summary-of-the-day/ar…

一、认识STM32

目录 一、初识STM32 1.1 STM32的命名规则介绍 1.2 STM32F103ZET6资源配置介绍 二、如何识别芯片管脚 2.1 如何寻找 IO 的功能说明 三、构成最小系统的要素 一、初识STM32 1.1 STM32的命名规则介绍 以 STM32F103ZET6 来讲解下 STM32 的命名方法&#xff1a; &…

.Net8 Blazor 尝鲜

全栈 Web UI 随着 .NET 8 的发布&#xff0c;Blazor 已成为全堆栈 Web UI 框架&#xff0c;可用于开发在组件或页面级别呈现内容的应用&#xff0c;其中包含&#xff1a; 用于生成静态 HTML 的静态服务器呈现。使用 Blazor Server 托管模型的交互式服务器呈现。使用 Blazor W…

忆联消费级SSD AH660:将用户体验推向新高度

自1989年IBM推出世界上第一款固态硬盘&#xff08;SSD&#xff09;以来&#xff0c;SSD在三十多年的时间中历经了多次技术革新和市场变革&#xff0c;早已成为个人电脑、汽车电子、数据中心、物联网终端等领域的主流存储产品&#xff0c;并广泛应用于各行各业&#xff0c;在202…

node 第十八天 中间件express-session实现会话密钥

express-session 文档 express-session 一个简单的express会话中间件 使用场景 在一个系统中&#xff0c; 需要维持一个临时的与登录态无关的会话密钥 比如登录系统后&#xff0c; 请求某一个接口&#xff0c; 接口的行为与登录态无关&#xff0c; 也就是说任何人对接口的访问…

【JAVA-排列组合】一个套路速解排列组合题

说明 在初遇排列组合题目时&#xff0c;总让人摸不着头脑&#xff0c;但是做多了题目后&#xff0c;发现几乎能用同一个模板做完所有这种类型的题目&#xff0c;大大提高了解题效率。本文简要介绍这种方法。 题目列表 所有题目均从leetcode查找&#xff0c;便于在线验证 46.…

C语言判断素数(ZZULIOJ1057:素数判定)

题目描述 输入一个正整数n&#xff0c;判断n是否是素数&#xff0c;若n是素数&#xff0c;输出”Yes”,否则输出”No”。 注意&#xff1a;1不是素数。 输入&#xff1a;输入一个正整数n(n<1000) 输出&#xff1a;如果n是素数输出"Yes"&#xff0c;否则输出"…

spark性能调优 | 默认并行度

Spark Sql默认并行度 看官网&#xff0c;默认并行度200 https://spark.apache.org/docs/2.4.5/sql-performance-tuning.html#other-configuration-options 优化 在数仓中 task最好是cpu的两倍或者3倍(最好是倍数&#xff0c;不要使基数) 拓展 在本地 task需要自己设置&a…

如何使用Matplotlib模块的text()函数给柱形图添加美丽的标签数据?

如何使用Matplotlib模块的text函数给柱形图添加美丽的标签数据&#xff1f; 1 简单引入2 关于text()函数2.1 Matplotlib安装2.2 text()引入2.3 text()源码2.4 text()参数说明2.5 text()两个简单示例 3 柱形图绘制并添加标签3.1 目标数据3.2 读取excel数据3.3 设置窗口大小和xy轴…

如何提升软件测试效率?本文为你揭示秘密

在软件开发中&#xff0c;测试是至关重要的一个环节。它能帮助我们发现并修复问题&#xff0c;从而确保我们提供的软件具有高质量。然而&#xff0c;测试过程往往费时费力。那么&#xff0c;有没有方法可以提升我们的软件测试效率呢&#xff1f;答案是肯定的。下面&#xff0c;…

骨传导耳机品牌排名前十,盘点最受欢迎的五款TOP级骨传导耳机

骨传导耳机品牌排名前十&#xff0c;最受欢迎的五款TOP级骨传导耳机是什么&#xff1f; 耳机市场上有很多品牌和型号的骨传导耳机&#xff0c;每个人对耳机的需求和使用场景也不尽相同。因此&#xff0c;在选择耳机时&#xff0c;确实不能盲目跟风或者仅仅看重品牌。为了帮助大…

spring cloud之配置中心

Config 统一配置中心(*) 1.简介 # 统一配置中心 - 官网:https://cloud.spring.io/spring-cloud-static/spring-cloud-config/2.2.3.RELEASE/reference/html/#_spring_cloud_config_server- config 分为 config server 和 config client。用来统一管理所有微服务的配置统一配置…