8.5.tensorRT高级(3)封装系列-基于生产者消费者实现的yolov5封装

目录

    • 前言
    • 1. yolov5封装
    • 总结

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习 tensorRT 高级-基于生产者消费者实现的yolov5封装

课程大纲可看下面的思维导图

在这里插入图片描述

1. yolov5封装

这节我们学习使用封装好的组件,配合生产者消费者,实现一个完整的 yolov5 推理

对于 yolov5 的封装,主要考虑以下:

1. 希望调用者是线程安全的,可以随意进行 commit,而不用考虑是否冲突

2. 希望结果是懒加载的,也就是需要的时候才等待,不需要的时候可以不等待

  • 由 promise 与 future 配合实现
  • 这样的灵活度和效率性能都是最好的

3. 希望最大化利用 GPU,如何利用呢?需要尽可能的使得计算密集

  • 实际体现就是抓取一个批次,一次性进行推理
  • 这在内部的消费者模型里面,一次抓一批

我们来看代码:

yolov5.hpp

#ifndef YOLOV5_HPP
#define YOLOV5_HPP#include <string>
#include <future>
#include <memory>
#include <opencv2/opencv.hpp>/
// 封装接口类
namespace YoloV5{struct Box{float left, top, right, bottom, confidence;int class_label;Box() = default;Box(float left, float top, float right, float bottom, float confidence, int class_label):left(left), top(top), right(right), bottom(bottom), confidence(confidence), class_label(class_label){}};typedef std::vector<Box> BoxArray;class Infer{public:virtual std::shared_future<BoxArray> commit(const cv::Mat& input) = 0;};std::shared_ptr<Infer> create_infer(const std::string& file,int gpuid=0, float confidence_threshold=0.25, float nms_threshold=0.45);
};#endif // YOLOV5_HPP

yolov5.cpp


#include "yolov5.hpp"
#include <thread>
#include <vector>
#include <condition_variable>
#include <mutex>
#include <string>
#include <future>
#include <queue>
#include <functional>
#include "trt-infer.hpp"
#include "cuda-tools.hpp"
#include "simple-logger.hpp"/
// 封装接口类
using namespace std;namespace YoloV5{struct Job{shared_ptr<promise<BoxArray>> pro;cv::Mat input;float d2i[6];};class InferImpl : public Infer{public:virtual ~InferImpl(){stop();}void stop(){if(running_){running_ = false;cv_.notify_one();}if(worker_thread_.joinable())worker_thread_.join();}bool startup(const string& file, int gpuid, float confidence_threshold, float nms_threshold){file_    = file;running_ = true;gpuid_   = gpuid;confidence_threshold_ = confidence_threshold;nms_threshold_        = nms_threshold;promise<bool> pro;worker_thread_ = thread(&InferImpl::worker, this, std::ref(pro));return pro.get_future().get();}virtual shared_future<BoxArray> commit(const cv::Mat& image) override{if(image.empty()){INFOE("Image is empty");return shared_future<BoxArray>();}Job job;job.pro.reset(new promise<BoxArray>());float scale_x = input_width_ / (float)image.cols;float scale_y = input_height_ / (float)image.rows;float scale   = std::min(scale_x, scale_y);float i2d[6];i2d[0] = scale;  i2d[1] = 0;  i2d[2] = (-scale * image.cols + input_width_ + scale  - 1) * 0.5;i2d[3] = 0;  i2d[4] = scale;  i2d[5] = (-scale * image.rows + input_height_ + scale - 1) * 0.5;cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);cv::Mat m2x3_d2i(2, 3, CV_32F, job.d2i);cv::invertAffineTransform(m2x3_i2d, m2x3_d2i);job.input.create(input_height_, input_width_, CV_8UC3);cv::warpAffine(image, job.input, m2x3_i2d, job.input.size(), cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar::all(114));job.input.convertTo(job.input, CV_32F, 1 / 255.0f);shared_future<BoxArray> fut = job.pro->get_future();{lock_guard<mutex> l(lock_);jobs_.emplace(std::move(job));}cv_.notify_one();return fut;}vector<Box> cpu_decode(float* predict, int rows, int cols, float* d2i,float confidence_threshold = 0.25f, float nms_threshold = 0.45f){vector<Box> boxes;int num_classes = cols - 5;for(int i = 0; i < rows; ++i){float* pitem = predict + i * cols;float objness = pitem[4];if(objness < confidence_threshold)continue;float* pclass = pitem + 5;int label     = std::max_element(pclass, pclass + num_classes) - pclass;float prob    = pclass[label];float confidence = prob * objness;if(confidence < confidence_threshold)continue;float cx     = pitem[0];float cy     = pitem[1];float width  = pitem[2];float height = pitem[3];// 通过反变换恢复到图像尺度float left   = (cx - width * 0.5) * d2i[0] + d2i[2];float top    = (cy - height * 0.5) * d2i[0] + d2i[5];float right  = (cx + width * 0.5) * d2i[0] + d2i[2];float bottom = (cy + height * 0.5) * d2i[0] + d2i[5];boxes.emplace_back(left, top, right, bottom, confidence, (float)label);}std::sort(boxes.begin(), boxes.end(), [](Box& a, Box& b){return a.confidence > b.confidence;});std::vector<bool> remove_flags(boxes.size());std::vector<Box> box_result;box_result.reserve(boxes.size());auto iou = [](const Box& a, const Box& b){float cross_left   = std::max(a.left, b.left);float cross_top    = std::max(a.top, b.top);float cross_right  = std::min(a.right, b.right);float cross_bottom = std::min(a.bottom, b.bottom);float cross_area = std::max(0.0f, cross_right - cross_left) * std::max(0.0f, cross_bottom - cross_top);float union_area = std::max(0.0f, a.right - a.left) * std::max(0.0f, a.bottom - a.top) + std::max(0.0f, b.right - b.left) * std::max(0.0f, b.bottom - b.top) - cross_area;if(cross_area == 0 || union_area == 0) return 0.0f;return cross_area / union_area;};for(int i = 0; i < boxes.size(); ++i){if(remove_flags[i]) continue;auto& ibox = boxes[i];box_result.emplace_back(ibox);for(int j = i + 1; j < boxes.size(); ++j){if(remove_flags[j]) continue;auto& jbox = boxes[j];if(ibox.class_label == jbox.class_label){// class matchedif(iou(ibox, jbox) >= nms_threshold)remove_flags[j] = true;}}}return box_result;}void worker(promise<bool>& pro){// load modelcheckRuntime(cudaSetDevice(gpuid_));auto model = TRT::load_infer(file_);if(model == nullptr){// failedpro.set_value(false);INFOE("Load model failed: %s", file_.c_str());return;}auto input    = model->input();auto output   = model->output();input_width_  = input->size(3);input_height_ = input->size(2);// load successpro.set_value(true);int max_batch_size = model->get_max_batch_size();vector<Job> fetched_jobs;while(running_){{unique_lock<mutex> l(lock_);cv_.wait(l, [&](){return !running_ || !jobs_.empty();});if(!running_) break;for(int i = 0; i < max_batch_size && !jobs_.empty(); ++i){fetched_jobs.emplace_back(std::move(jobs_.front()));jobs_.pop();}}for(int ibatch = 0; ibatch < fetched_jobs.size(); ++ibatch){auto& job = fetched_jobs[ibatch];auto& image = job.input;cv::Mat channel_based[3];for(int i = 0; i < 3; ++i){// 这里实现bgr -> rgb// 做的是内存引用,效率最高channel_based[i] = cv::Mat(input_height_, input_width_, CV_32F, input->cpu<float>(ibatch, 2-i));}cv::split(image, channel_based);}// 一次加载一批,并进行批处理// forward(fetched_jobs)model->forward();for(int ibatch = 0; ibatch < fetched_jobs.size(); ++ibatch){auto& job = fetched_jobs[ibatch];float* predict_batch = output->cpu<float>(ibatch);auto boxes = cpu_decode(predict_batch, output->size(1), output->size(2), job.d2i, confidence_threshold_, nms_threshold_);job.pro->set_value(boxes);}fetched_jobs.clear();}// 避免外面等待unique_lock<mutex> l(lock_);while(!jobs_.empty()){jobs_.back().pro->set_value({});jobs_.pop();}INFO("Infer worker done.");}private:atomic<bool> running_{false};int gpuid_;float confidence_threshold_;float nms_threshold_;int input_width_;int input_height_;string file_;thread worker_thread_;queue<Job> jobs_;mutex lock_;condition_variable cv_;};shared_ptr<Infer> create_infer(const string& file, int gpuid, float confidence_threshold, float nms_threshold){shared_ptr<InferImpl> instance(new InferImpl());if(!instance->startup(file, gpuid, confidence_threshold, nms_threshold)){instance.reset();}return instance;}
};

头文件中定义了一个 Infer 类,该类只有一个 commit 纯虚函数,接收图像数据,然后 shared_future 对象,create_infer 函数用于 RAII,它创建并返回一个 Infer 接口类的实现,通过 startup 方法初始化实例

在 startup 函数中我们创建了一个 bool 类型的 promise 变量,用于判断资源是否获取成功,通过引用的方式传递到了消费者线程 worker 中,在 worker 线程里面会处理模型加载的过程,加载成功或者失败都会把结果反馈到对应的 promise 变量上,而通过 future 对象的 get 方法我们可以获取到是否成功,而成功之后消费者线程也启动了,会继续往下走

在 worker 消费者线程中,有个条件变量在等待,如果条件为 true 则退出等待,也就是当队列不为空时会退出等待,然后将队列中的数据移动到 vector 容器中,循环 vector 容器中的每个图像进行 brg2rgb 同时 split,然后将这一批数据送到网络进行推理拿到结果,然后对拿到的 box 进行decode,最好将通过 promise 的 set_value 方法将 boxes 返回回去

我们再来看下 commit 生产者线程,它会将接收的图像数据进行预处理后放入到队列中,然后利用条件变量通知消费者线程可以处理了,其中预处理使用的是 warpAffine 仿射变换,

总的来说,上述代码提供了 YOLOv5 的推理功能。它实现了一个生产者-消费者模式,其中生产者可以异步地提交推理任务,并使用 future-promise 获取结果。worker 线程作为消费者执行这些任务。以下是 yolov5 封装的关键点:

1. 封装的接口:定义了 Infer 接口,提供异步推理的方法

2. 生产者-消费者模式:该实现采用生产者-消费者模式,其中生产者提交推理任务,worker 线程作为消费者去执行

3. 异步处理:使用 future-promise 机制,生产者可以异步地提交任务并等待结果

4. 预处理:预处理采用 warpAffine 仿射变换

5. decode:采用 IM 逆矩阵进行解码恢复成框

6. 多线程安全:使用互斥锁和条件变量确保多线程安全

7. 资源管理:使用 RAII 原则确保资源的正确管理

最好我们来看下 main.cpp 中的内容变化:


// tensorRT include
// 编译用的头文件
#include <NvInfer.h>// onnx解析器的头文件
#include <onnx-tensorrt/NvOnnxParser.h>// 推理用的运行时头文件
#include <NvInferRuntime.h>// cuda include
#include <cuda_runtime.h>// system include
#include <stdio.h>
#include <math.h>#include <iostream>
#include <fstream>
#include <vector>
#include <memory>
#include <functional>
#include <unistd.h>
#include <opencv2/opencv.hpp>#include "trt-builder.hpp"
#include "simple-logger.hpp"
#include "yolov5.hpp"using namespace std;static const char* cocolabels[] = {"person", "bicycle", "car", "motorcycle", "airplane","bus", "train", "truck", "boat", "traffic light", "fire hydrant","stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse","sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack","umbrella", "handbag", "tie", "suitcase", "frisbee", "skis","snowboard", "sports ball", "kite", "baseball bat", "baseball glove","skateboard", "surfboard", "tennis racket", "bottle", "wine glass","cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich","orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake","chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv","laptop", "mouse", "remote", "keyboard", "cell phone", "microwave","oven", "toaster", "sink", "refrigerator", "book", "clock", "vase","scissors", "teddy bear", "hair drier", "toothbrush"
};static bool exists(const string& path){#ifdef _WIN32return ::PathFileExistsA(path.c_str());
#elsereturn access(path.c_str(), R_OK) == 0;
#endif
}// 上一节的代码
static bool build_model(){if(exists("yolov5s.trtmodel")){printf("yolov5s.trtmodel has exists.\n");return true;}//SimpleLogger::set_log_level(SimpleLogger::LogLevel::Verbose);TRT::compile(TRT::Mode::FP32,10,"yolov5s.onnx","yolov5s.trtmodel",1 << 28);INFO("Done.");return true;
}static void inference(){auto image = cv::imread("rq.jpg");auto yolov5 = YoloV5::create_infer("yolov5s.trtmodel");auto boxes = yolov5->commit(image).get();for(auto& box : boxes){cv::Scalar color(0, 255, 0);cv::rectangle(image, cv::Point(box.left, box.top), cv::Point(box.right, box.bottom), color, 3);auto name      = cocolabels[box.class_label];auto caption   = cv::format("%s %.2f", name, box.confidence);int text_width = cv::getTextSize(caption, 0, 1, 2, nullptr).width + 10;cv::rectangle(image, cv::Point(box.left-3, box.top-33), cv::Point(box.left + text_width, box.top), color, -1);cv::putText(image, caption, cv::Point(box.left, box.top-5), 0, 1, cv::Scalar::all(0), 2, 16);}cv::imwrite("image-draw.jpg", image);
}int main_old();int main(){// 旧的实现,请参照main-old.cppmain_old();// 新的实现if(!build_model()){return -1;}inference();return 0;
}

模型构建部分非常简单,一行代码解决模型编译问题,非常方便。推理部分通过 create_infer 拿到一个 shared_ptr 对象,通过调用 commit 方法把图像加入到生成者队列中,然后通过 get 方法等待消费者线程拿到推理结果,然后直接绘制目标框即可,相比于之前的纯裸的,无封装的 yolov5 来说非常简便了

总结

本次课程学习了基于生产者和消费者实现的 yolov5 封装,commit 生成者线程不断往队列中抛数据,通过条件变量 cv_ 通知消费者进行消费,worker 消费者线程会一直等待,直到队列不为空,它会把一批次数据全部扔到模型中进行推理,然后解码,通过 promise 将结果传递回来。create_infer 是 RAII 和接口模式的体现,通过 startup 获取资源并初始话,如果资源获取失败则直接退出,同时实例化的是实现类 InferImpl,而返回的是接口类 Infer,只对使用者暴露 commit 接口。这都是我们之前在生产者消费者课程中讲到过的知识,这边是直接拿过来用了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45483.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

postgresql中基础sql查询

postgresql中基础sql查询 创建表插入数据创建索引删除表postgresql命令速查简单查询计算查询结果 利用查询条件过滤数据模糊查询 创建表 -- 部门信息表 CREATE TABLE departments( department_id INTEGER NOT NULL -- 部门编号&#xff0c;主键, department_name CHARACTE…

CentOS6.8图形界面安装Oracle11.2.0.1.0

Oracle11下载地址 https://edelivery.oracle.com/osdc/faces/SoftwareDelivery 一、环境 CentOS release 6.8 (Final)&#xff0c;测试环境&#xff1a;内存2G&#xff0c;硬盘20G&#xff0c;SWAP空间4G Oracle版本&#xff1a;Release 11.2.0.1.0 安装包&#xff1a;V175…

Lookup Singularity

1. 引言 Lookup Singularity概念 由Barry WhiteHat在2022年11月在zkResearch论坛 Lookup Singularity中首次提出&#xff1a; 其主要目的是&#xff1a;让SNARK前端生成仅需做lookup的电路。Barry预测这样有很多好处&#xff0c;特别是对于可审计性 以及 形式化验证&#xff…

【学习FreeRTOS】第8章——FreeRTOS列表和列表项

1.列表和列表项的简介 列表是 FreeRTOS 中的一个数据结构&#xff0c;概念上和链表有点类似&#xff0c;列表被用来跟踪 FreeRTOS中的任务。列表项就是存放在列表中的项目。 列表相当于链表&#xff0c;列表项相当于节点&#xff0c;FreeRTOS 中的列表是一个双向环形链表列表的…

微软Win11 Dev预览版Build23526发布

近日&#xff0c;微软Win11 Dev预览版Build23526发布&#xff0c;修复了不少问题。牛比如斯Microsoft&#xff0c;也有这么多bug&#xff0c;所以你写再多bug也不作为奇啊。 主要更新问题 [开始菜单&#xff3d; 修复了在高对比度主题下&#xff0c;打开开始菜单中的“所有应…

Spring Boot通过企业邮箱发件被Gmail退回的解决方法

这两天给我们开发的Chrome插件&#xff1a;Youtube中文配音 增加了账户注册和登录功能&#xff0c;其中有一步是邮箱验证&#xff0c;所以这边会在Spring Boot后台给用户的邮箱发个验证信息。如何发邮件在之前的文章教程里就有&#xff0c;这里就不说了&#xff0c;着重说说这两…

通过 kk 创建 k8s 集群和 kubesphere

官方文档&#xff1a;多节点安装 确保从正确的区域下载 KubeKey export KKZONEcn下载 KubeKey curl -sfL https://get-kk.kubesphere.io | VERSIONv3.0.7 sh -为 kk 添加可执行权限&#xff1a; chmod x kk创建 config 文件 KubeSphere 版本&#xff1a;v3.3 支持的 Kuber…

Linux 安全技术和防火墙

目录 1 安全技术 2 防火墙 2.1 防火墙的分类 2.1.1 包过滤防火墙 2.1.2 应用层防火墙 3 Linux 防火墙的基本认识 3.1 iptables & netfilter 3.2 四表五链 4 iptables 4.2 数据包的常见控制类型 4.3 实际操作 4.3.1 加新的防火墙规则 4.3.2 查看规则表 4.3.…

企事业数字培训及知识库平台

前言 随着信息化的进一步推进&#xff0c;目前各行各业都在进行数字化转型&#xff0c;本人从事过医疗、政务等系统的研发&#xff0c;和客户深入交流过日常办公中“知识”的重要性&#xff0c;再加上现在倡导的互联互通、数据安全、无纸化办公等概念&#xff0c;所以无论是企业…

打家劫舍 II——力扣213

动规 int robrange(vector<int>& nums, int start, int end){int first=nums[start]

CountDownLatch和CyclicBarrie

前置提要 什么是闭锁对象 闭锁对象&#xff08;Latch Object&#xff09;是一种同步工具&#xff0c;用于控制线程的等待和执行顺序。闭锁对象可以让一个或多个线程等待&#xff0c;直到特定的条件满足后才能继续执行。 在Java中&#xff0c;CountDownLatch就是一种常见的闭锁对…

STC15单片机PM2.5空气质量检测仪

一、系统方案 本设计采用STC15单片机作为主控制器&#xff0c;PM2.5传感器、按键设置&#xff0c;液晶1602显示&#xff0c;蜂鸣器报警。 二、硬件设计 原理图如下&#xff1a; 三、单片机软件设计 1、首先是系统初始化&#xff1a; void lcd_init()//液晶初始化设置 { de…

SQLite数据库实现数据增删改查

当前文章介绍的设计的主要功能是利用 SQLite 数据库实现宠物投喂器上传数据的存储&#xff0c;并且支持数据的增删改查操作。其中&#xff0c;宠物投喂器上传的数据包括投喂间隔时间、水温、剩余重量等参数。 实现功能&#xff1a; 创建 SQLite 数据库表&#xff0c;用于存储宠…

第一讲:BeanFactory和ApplicationContext接口

BeanFactory和ApplicationContext接口 1. 什么是BeanFactory?2. BeanFactory能做什么&#xff1f;3.ApplicationContext对比BeanFactory的额外功能?3.1 MessageSource3.2 ResourcePatternResolver3.3 EnvironmentCapable3.4 ApplicationEventPublisher 4.总结 1. 什么是BeanF…

解决C#报“MSB3088 未能读取状态文件*.csprojAssemblyReference.cache“问题

今天在使用vscode软件C#插件&#xff0c;编译.cs文件时&#xff0c;发现如下warning: 图(1) C#报cache没有更新 出现该warning的原因&#xff1a;当前.cs文件修改了&#xff0c;但是其缓存文件*.csprojAssemblyReference.cache没有更新&#xff0c;需要重新清理一下工程&#x…

【机器学习实战】朴素贝叶斯:过滤垃圾邮件

【机器学习实战】朴素贝叶斯&#xff1a;过滤垃圾邮件 0.收集数据 这里采用的数据集是《机器学习实战》提供的邮件文件&#xff0c;该文件有ham 和 spam 两个文件夹&#xff0c;每个文件夹中各有25条邮件&#xff0c;分别代表着 正常邮件 和 垃圾邮件。 这里需要注意的是需要…

【校招VIP】java语言考点之List和扩容

考点介绍&#xff1a; List是最基础的考点&#xff0c;但是很多同学拿不到满分。本专题从两种实现子类的比较&#xff0c;到比较复杂的数组扩容进行分析。 『java语言考点之List和扩容』相关题目及解析内容可点击文章末尾链接查看&#xff01;一、考点题目 1、以下关于集合类…

vue技术学习

vue快速入门 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>vue快速入门</title> </head> <body> <!--老师解读 1. div元素不是必须的&#xff0c;也可以是其它元素&#xff0…

操作系统——操作系统内存管理基础

文章目录 1.内存管理介绍2.常见的几种内存管理机制3.快表和多级页表快表多级页表总结 4.分页机制和分段机制的共同点和区别5.逻辑(虚拟)地址和物理地址6.CPU 寻址了解吗?为什么需要虚拟地址空间? 1.内存管理介绍 操作系统的内存管理主要是做什么&#xff1f; 操作系统的内存…

Apache DolphinScheduler 支持使用 OceanBase 作为元数据库啦!

DolphinScheduler是一个开源的分布式任务调度系统&#xff0c;拥有分布式架构、多任务类型、可视化操作、分布式调度和高可用等特性&#xff0c;适用于大规模分布式任务调度的场景。目前DolphinScheduler支持的元数据库有Mysql、PostgreSQL、H2&#xff0c;如果在业务中需要更好…