YOLOv5 分类模型 预处理 OpenCV实现

YOLOv5 分类模型 预处理 OpenCV实现

flyfish

YOLOv5 分类模型 预处理 PIL 实现
YOLOv5 分类模型 OpenCV和PIL两者实现预处理的差异

YOLOv5 分类模型 数据集加载 1 样本处理
YOLOv5 分类模型 数据集加载 2 切片处理
YOLOv5 分类模型 数据集加载 3 自定义类别

YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize

YOLOv5 分类模型 Top 1和Top 5 指标说明
YOLOv5 分类模型 Top 1和Top 5 指标实现

判断图像是否是np.ndarray类型和维度

OpenCV读取一张图像时,类型类型就是<class 'numpy.ndarray'>,这里判断图像是否是np.ndarray类型
dim是dimension维度的缩写,shape属性的长度也是它的ndim
灰度图的shape为HW,二个维度
RGB图的shape为HWC,三个维度
在这里插入图片描述

def _is_numpy_image(img):return isinstance(img, np.ndarray) and (img.ndim in {2, 3})

实现ToTensor和Normalize

def totensor_normalize(img):print("preprocess:",img.shape)images = (img/255-mean)/stdimages = images.transpose((2, 0, 1))# HWC to CHWimages = np.ascontiguousarray(images)return images

实现Resize

插值可以是以下参数

# 'nearest': cv2.INTER_NEAREST,
# 'bilinear': cv2.INTER_LINEAR,
# 'area': cv2.INTER_AREA,
# 'bicubic': cv2.INTER_CUBIC,
# 'lanczos': cv2.INTER_LANCZOS4
def resize(img, size, interpolation=cv2.INTER_LINEAR):r"""Resize the input numpy ndarray to the given size.Args:img (numpy ndarray): Image to be resized.size: like pytroch about size interpretation flyfish.interpolation (int, optional): Desired interpolation. Default is``cv2.INTER_LINEAR``  Returns:numpy Image: Resized image.like opencv"""if not _is_numpy_image(img):raise TypeError('img should be numpy image. Got {}'.format(type(img)))if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):raise TypeError('Got inappropriate size arg: {}'.format(size))h, w = img.shape[0], img.shape[1]if isinstance(size, int):if (w <= h and w == size) or (h <= w and h == size):return imgif w < h:ow = sizeoh = int(size * h / w)else:oh = sizeow = int(size * w / h)else:ow, oh = size[1], size[0]output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)if img.shape[2] == 1:return output[:, :, np.newaxis]else:return output

实现CenterCrop

def crop(img, i, j, h, w):"""Crop the given Image flyfish.Args:img (numpy ndarray): Image to be cropped.i: Upper pixel coordinate.j: Left pixel coordinate.h: Height of the cropped image.w: Width of the cropped image.Returns:numpy ndarray: Cropped image."""if not _is_numpy_image(img):raise TypeError('img should be numpy image. Got {}'.format(type(img)))return img[i:i + h, j:j + w, :]def center_crop(img, output_size):if isinstance(output_size, numbers.Number):output_size = (int(output_size), int(output_size))h, w = img.shape[0:2]th, tw = output_sizei = int(round((h - th) / 2.))j = int(round((w - tw) / 2.))return crop(img, i, j, th, tw)

完整

import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import collections
import torch
import numbersclasses_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]def _is_numpy_image(img):return isinstance(img, np.ndarray) and (img.ndim in {2, 3})def totensor_normalize(img):print("preprocess:",img.shape)images = (img/255-mean)/stdimages = images.transpose((2, 0, 1))# HWC to CHWimages = np.ascontiguousarray(images)return imagesdef resize(img, size, interpolation=cv2.INTER_LINEAR):r"""Resize the input numpy ndarray to the given size.Args:img (numpy ndarray): Image to be resized.size: like pytroch about size interpretation flyfish.interpolation (int, optional): Desired interpolation. Default is``cv2.INTER_LINEAR``  Returns:numpy Image: Resized image.like opencv"""if not _is_numpy_image(img):raise TypeError('img should be numpy image. Got {}'.format(type(img)))if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):raise TypeError('Got inappropriate size arg: {}'.format(size))h, w = img.shape[0], img.shape[1]if isinstance(size, int):if (w <= h and w == size) or (h <= w and h == size):return imgif w < h:ow = sizeoh = int(size * h / w)else:oh = sizeow = int(size * w / h)else:ow, oh = size[1], size[0]output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)if img.shape[2] == 1:return output[:, :, np.newaxis]else:return outputdef crop(img, i, j, h, w):"""Crop the given Image flyfish.Args:img (numpy ndarray): Image to be cropped.i: Upper pixel coordinate.j: Left pixel coordinate.h: Height of the cropped image.w: Width of the cropped image.Returns:numpy ndarray: Cropped image."""if not _is_numpy_image(img):raise TypeError('img should be numpy image. Got {}'.format(type(img)))return img[i:i + h, j:j + w, :]def center_crop(img, output_size):if isinstance(output_size, numbers.Number):output_size = (int(output_size), int(output_size))h, w = img.shape[0:2]th, tw = output_sizei = int(round((h - th) / 2.))j = int(round((w - tw) / 2.))return crop(img, i, j, th, tw)class DatasetFolder:def __init__(self,root: str,) -> None:self.root = rootif classes_name is None or not classes_name:classes, class_to_idx = self.find_classes(self.root)print("not classes_name")else:classes = classes_nameclass_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}print("is classes_name")print("classes:",classes)print("class_to_idx:",class_to_idx)samples = self.make_dataset(self.root, class_to_idx)self.classes = classesself.class_to_idx = class_to_idxself.samples = samplesself.targets = [s[1] for s in samples]@staticmethoddef make_dataset(directory: str,class_to_idx: Optional[Dict[str, int]] = None,) -> List[Tuple[str, int]]:directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = self.find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):path = os.path.join(root, fname)if 1:  # 验证:item = path, class_indexinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "return instancesdef find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef __getitem__(self, index: int) -> Tuple[Any, Any]:path, target = self.samples[index]sample = self.loader(path)return sample, targetdef __len__(self) -> int:return len(self.samples)def loader(self, path):print("path:", path)img = cv2.imread(path)  # BGR HWCimg=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)#RGBprint("type:",type(img))return imgdef time_sync():return time.time()dataset = DatasetFolder(root="/media/flyfish/datasets/imagewoof/val")
weights = "/home/classes.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()def classify_transforms(img):img=resize(img,224)img=center_crop(img,224)img=totensor_normalize(img)return img;pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):print("i:", i)print(images.shape, labels)im = classify_transforms(images)images=torch.from_numpy(im).to(torch.float32) # numpy to tensorimages = images.unsqueeze(0).to("cpu")print(images.shape)t1 = time_sync()images = images.to(device, non_blocking=True)t2 = time_sync()# dt[0] += t2 - t1y = model(images)y=y.numpy()print("y:", y)t3 = time_sync()# dt[1] += t3 - t2tmp1=y.argsort()[:,::-1][:, :5]print("tmp1:", tmp1)pred.append(tmp1)print("labels:", labels)targets.append(labels)print("for pred:", pred)  # listprint("for targets:", targets)  # list# dt[2] += time_sync() - t3pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

结果

pred: [[0 3 6 2 1][0 7 2 9 3][0 5 6 2 9]...[9 8 7 6 1][9 3 6 7 0][9 5 0 2 7]]
pred: (3929, 5)
targets: [0 0 0 ... 9 9 9]
targets: (3929,)
correct: (3929, 5)
correct: [[          1           0           0           0           0][          1           0           0           0           0][          1           0           0           0           0]...[          1           0           0           0           0][          1           0           0           0           0][          1           0           0           0           0]]
acc: (3929, 2)
acc: [[          1           1][          1           1][          1           1]...[          1           1][          1           1][          1           1]]
top1: 0.86230594
top5: 0.98167473

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/166454.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于python 语音转字幕,字幕转语音大杂烩

文字转语音 Python语音合成之第三方库gTTs/pyttsx3/speech横评(内附使用方法)_python_脚本之家 代码示例 from gtts import gTTStts gTTS(你好你在哪儿&#xff01;,langzh-CN)tts.save(hello.mp3)import pyttsx3engine pyttsx3.init() #创建对象"""语速"…

目前比较好用的护眼台灯,小学生适合的护眼台灯推荐

随着技术的发展&#xff0c;灯光早已成为每家每户都需要的东西。但是灯光不好可能会对眼睛造成伤害是很多人没有注意到的。现在随着护眼灯产品越来越多&#xff0c;市场上台灯的选择越来越多样化&#xff0c;如何选择一个对眼睛无伤害、无辐射的台灯成为许多家长首先要考虑的问…

【C++初阶】四、类和对象(构造函数、析构函数、拷贝构造函数、赋值运算符重载函数)

相关代码gitee自取&#xff1a; C语言学习日记: 加油努力 (gitee.com) 接上期&#xff1a; 【C初阶】三、类和对象 &#xff08;面向过程、class类、类的访问限定符和封装、类的实例化、类对象模型、this指针&#xff09; -CSDN博客 引入&#xff1a;类的六个默认成员函数…

如何使用springboot服务端接口公网远程调试——实现HTTP服务监听

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;网络奇遇记、Cpolar杂谈 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. 本地环境搭建1.1 环境参数1.2 搭建springboot服务项目 二. 内网穿透2.1 安装…

ATA-2042高压放大器在细胞的剪切应力传感器研究中的应用

微流控技术是一种通过微小的通道和微型装置对流体进行精确操控和分析的技术。它是现代医学技术发展过程中的一种重要的生物医学工程技术&#xff0c;具有广泛的应用前景和重要性。它在高通量分析、个性化医疗、细胞筛选等方面有着巨大的潜力&#xff0c;Aigtek安泰电子今天就将…

HR8833 双通道H桥电机驱动芯片

HR8833为玩具、打印机和其它电机一T化应用提供一种双通道电机驱动方案。HR8833提供两种封装&#xff0c;一种是带有L露焊盘的TSSOP-16封装&#xff0c;能改进散热性能&#xff0c;且是无铅产品&#xff0c;引脚框采用100&#xff05;无锡电镀。另一种封装为SOP16&#xff0c;不…

智驾芯片全矩阵「曝光」,这家企业的车载品牌正式官宣

随着汽车智能化加速&#xff0c;智能驾驶芯片格局逐渐清晰。 针对L0-L2&#xff0c;业内基本采用智能前视一体机方案&#xff1b;要实现高速NOA、城市NOA等更为高阶的智驾功能等&#xff0c;则基本采用域控制器方案。从前视一体机至域控&#xff0c;再逐步演进到舱驾一体、中央…

python基于DETR(DEtection TRansformer)开发构建钢铁产业产品智能自动化检测识别系统

在前文中我们基于经典的YOLOv5开发构建了钢铁产业产品智能自动化检测识别系统&#xff0c;这里本文的主要目的是想要实践应用DETR这一端到端的检测模型来开发构建钢铁产业产品智能自动化检测识别系统。 DETR (DEtection TRansformer) 是一种基于Transformer架构的端到端目标检…

【Django使用】10大章31模块md文档,第5篇:Django模板和数据库使用

当你考虑开发现代化、高效且可扩展的网站和Web应用时&#xff0c;Django是一个强大的选择。Django是一个流行的开源Python Web框架&#xff0c;它提供了一个坚实的基础&#xff0c;帮助开发者快速构建功能丰富且高度定制的Web应用 全套Django笔记直接地址&#xff1a; 请移步这…

外汇天眼:多名投资者账户被恶意清空,远离volofinance!

最近&#xff0c;外汇平台volofinance因有多名投资者投诉&#xff0c;“荣幸”成为外汇天眼黑平台榜单中的一员&#xff0c;那么volofinance到底做了什么导致投资者前来投诉曝光呢&#xff1f; 起底volofinace 在网络搜索中&#xff0c;关于volofinance的信息少之又少&#xf…

成为AI产品经理——模型评估指标

目录 一、模型评估分类 1.在线评估 2.离线评估 二、离线模型评估 1.特征评估 ① 特征自身稳定性 ② 特征来源稳定性 ③ 特征成本 2.模型评估 ① 统计性评估 覆盖度 最大值、最小值 分布形态 ② 模型性能指标 分类问题 回归问题 ③ 模型的稳定性 模型评估指标分…

配置mvn打包参数,不同环境使用不同的配置文件

方法一&#xff1a; 首先在/resource目录下创建各自环境的配置 要在不同的环境中使用不同的配置文件进行Maven打包&#xff0c;可以使用Maven的profiles特性和资源过滤功能。下面是配置Maven打包参数的步骤&#xff1a; 在项目的pom.xml文件中&#xff0c;添加profiles配置…

第一个Mybatis项目

&#xff08;一&#xff09;为什么要用Mybatis? &#xff08;1&#xff09;Mybatis对比JDBC而言&#xff0c;sql&#xff08;单独写在xml的配置文件中&#xff09;和java编码分开&#xff0c;功能边界清晰&#xff0c;一个专注业务&#xff0c;一个专注数据。 &#xff08;2&…

【C++】:多态

朋友们、伙计们&#xff0c;我们又见面了&#xff0c;本期来给大家解读一下有关多态的知识点&#xff0c;如果看完之后对你有一定的启发&#xff0c;那么请留下你的三连&#xff0c;祝大家心想事成&#xff01; C 语 言 专 栏&#xff1a;C语言&#xff1a;从入门到精通 数据结…

Linux(CentOS7)上安装mysql

在CentOS中默认安装有MariaDB&#xff08;MySQL的一个分支&#xff09;&#xff0c;可先移除/卸载MariaDB。 yum remove mariadb // 查看是否存在mariadb rpm -qa|grep -i mariadb // 卸载 mariadb rpm -e --nodeps rpm -qa|grep mariadb yum安装 下载rpm // 5.6版本 wge…

XML映射文件

<?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE mapperPUBLIC "-//mybatis.org//DTD Mapper 3.0//EN""http://mybatis.org/dtd/mybatis-3-mapper.dtd"> <mapper namespace"org.mybatis.example.BlogMapper&q…

conan 入门(三十二):package_info中配置禁用CMakeDeps生成使用项目自己生成的config.cmake

conanfile.py中定义的package_info()方法用于向package的调用者(conumer)提供包库名&#xff0c;编译/连接选项&#xff0c;文件夹等等信息&#xff0c;有了这些信息构建工具的generator就可以根据它们生成对应的文件&#xff0c;用于调用者引用package. 比如基于cmake的CMakeD…

安全地公网访问树莓派等设备的服务 内网穿透--frp 23年11月方法

如果想要树莓派可以被公网访问&#xff0c;可以选择直接网上搜内网穿透提供商&#xff0c;一个月大概10块钱&#xff0c;也有免费的&#xff0c;但是免费的速度就不要希望很好了。 也可以选择接下来介绍的frp&#xff0c;这种方式不需要付费&#xff0c;但是需要你有一台有着公…

vue3自定义拖拽指令

<template><div v-move class"box"></div> </template><script setup lang"ts"> import { Directive } from vue const vMove:Directive (el:HTMLElement) >{const mousedown (e:MouseEvent) >{// 鼠标按下const s…

【Golang】解决使用interface{}解析json数字会变成科学计数法的问题

在使用解析json结构体的时候&#xff0c;使用interface{}接数字会发现变成了科学计数法格式的数字&#xff0c;不符合实际场景的使用要求。 举例代码如下&#xff1a; type JsonUnmStruct struct {Id interface{} json:"id"Name string json:"name"…