python基于DETR(DEtection TRansformer)开发构建钢铁产业产品智能自动化检测识别系统

在前文中我们基于经典的YOLOv5开发构建了钢铁产业产品智能自动化检测识别系统,这里本文的主要目的是想要实践应用DETR这一端到端的检测模型来开发构建钢铁产业产品智能自动化检测识别系统。

DETR (DEtection TRansformer) 是一种基于Transformer架构的端到端目标检测模型。与传统的基于区域提议的目标检测方法(如Faster R-CNN)不同,DETR采用了全新的思路,将目标检测问题转化为一个序列到序列的问题,通过Transformer模型实现目标检测和目标分类的联合训练。

DETR的工作流程如下:

输入图像通过卷积神经网络(CNN)提取特征图。
特征图作为编码器输入,经过一系列的编码器层得到图像特征的表示。
目标检测问题被建模为一个序列到序列的转换任务,其中编码器的输出作为解码器的输入。
解码器使用自注意力机制(self-attention)对编码器的输出进行处理,以获取目标的位置和类别信息。
最终,DETR通过一个线性层和softmax函数对解码器的输出进行分类,并通过一个线性层预测目标框的坐标。
DETR的优点包括:

端到端训练:DETR模型能够直接从原始图像到目标检测结果进行端到端训练,避免了传统目标检测方法中复杂的区域提议生成和特征对齐的过程,简化了模型的设计和训练流程。
不受固定数量的目标限制:DETR可以处理变长的输入序列,因此不受固定数量目标的限制。这使得DETR能够同时检测图像中的多个目标,并且不需要设置预先确定的目标数量。
全局上下文信息:DETR通过Transformer的自注意力机制,能够捕捉到图像中不同位置的目标之间的关系,提供了更大范围的上下文信息。这有助于提高目标检测的准确性和鲁棒性。
然而,DETR也存在一些缺点:

计算复杂度高:由于DETR采用了Transformer模型,它在处理大尺寸图像时需要大量的计算资源,导致其训练和推理速度相对较慢。
对小目标的检测性能较差:DETR模型在处理小目标时容易出现性能下降的情况。这是因为Transformer模型在处理小尺寸目标时可能会丢失细节信息,导致难以准确地定位和分类小目标。

首先看下实例效果:
 

简单看下数据集:

PyTorch训练代码和DETR(DEDetection-TRansformer)的预训练模型。我们用Transformer替换了完全复杂的手工制作的对象检测管道,并将Faster R-CNN与ResNet-50匹配,使用一半的计算能力(FLOP)和相同数量的参数在COCO上获得42个AP。

官方项目地址在这里,如下所示:

可以看到目前已经收获了超过1.2w的star量,还是很不错的了。

DETR整体数据流程示意图如下所示:

官方也提供了对应的预训练模型,可以自行使用:

本文选择的预训练官方权重是detr-r50-e632da11.pth,首先需要基于官方的预训练权重开发能够用于自己的 个性化数据集的权重,如下所示:

pretrained_weights = torch.load("./weights/detr-r50-e632da11.pth")
num_class = 10 + 1
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights,'./weights/detr_r50_%d.pth'%num_class)

因为这里我的类别数量为10,所以num_class修改为:10+1,根据自己的实际情况修改即可。生成后如下所示:

之后按照官方说明准备好数据集即可,启动训练模型命令如下所示:

python main.py --dataset_file "coco" --coco_path "/0000" --epoch 100 --lr=1e-4 --batch_size=32 --num_workers=0 --output_dir="outputs" --resume="weights/detr_r50_11.pth"

借助于plot_util.py模块可以实现对模型的评估和可视化,如下:

def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):'''Function to plot specific fields from training log(s). Plots both training and test results.:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file- fields = which results to plot from each log file - plots both training and test for each field.- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots- log_name = optional, name of log file if different than default 'log.txt'.:: Outputs - matplotlib plots of results in fields, color coded for each log file.- solid lines are training results, dashed lines are test results.'''func_name = "plot_utils.py::plot_logs"# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,# convert single Path to list to avoid 'not iterable' errorif not isinstance(logs, list):if isinstance(logs, PurePath):logs = [logs]print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")else:raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \Expect list[Path] or single Path obj, received {type(logs)}")# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dirfor i, dir in enumerate(logs):if not isinstance(dir, PurePath):raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")if not dir.exists():raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")# verify log_name existsfn = Path(dir / log_name)if not fn.exists():print(f"-> missing {log_name}.  Have you gotten to Epoch 1 in training?")print(f"--> full path of missing log file: {fn}")return# load log file(s) and plotdfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):for j, field in enumerate(fields):if field == 'mAP':coco_eval = pd.DataFrame(np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]).ewm(com=ewm_col).mean()axs[j].plot(coco_eval, c=color)else:df.interpolate().ewm(com=ewm_col).mean().plot(y=[f'train_{field}', f'test_{field}'],ax=axs[j],color=[color] * 2,style=['-', '--'])for ax, field in zip(axs, fields):ax.legend([Path(p).name for p in logs])ax.set_title(field)def plot_precision_recall(files, naming_scheme='iter'):if naming_scheme == 'exp_id':# name becomes exp_idnames = [f.parts[-3] for f in files]elif naming_scheme == 'iter':names = [f.stem for f in files]else:raise ValueError(f'not supported {naming_scheme}')fig, axs = plt.subplots(ncols=2, figsize=(16, 5))for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):data = torch.load(f)# precision is n_iou, n_points, n_cat, n_area, max_detprecision = data['precision']recall = data['params'].recThrsscores = data['scores']# take precision for all classes, all areas and 100 detectionsprecision = precision[0, :, :, 0, -1].mean(1)scores = scores[0, :, :, 0, -1].mean(1)prec = precision.mean()rec = data['recall'][0, :, 0, -1].mean()print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +f'score={scores.mean():0.3f}, ' +f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}')axs[0].plot(recall, precision, c=color)axs[1].plot(recall, scores, c=color)axs[0].set_title('Precision / Recall')axs[0].legend(names)axs[1].set_title('Scores / Recall')axs[1].legend(names)return fig, axs

结果如下所示:

iter 000: mAP@50= 24.0, score=0.317, f1=0.341
iter 050: mAP@50= 27.7, score=0.339, f1=0.400
iter latest: mAP@50= 26.4, score=0.348, f1=0.393
iter 000: mAP@50= 24.0, score=0.317, f1=0.341
iter 050: mAP@50= 27.7, score=0.339, f1=0.400
iter latest: mAP@50= 26.4, score=0.348, f1=0.393

可视化如下所示:

【Precision曲线】
精确率曲线(Precision-Recall Curve)是一种用于评估二分类模型在不同阈值下的精确率性能的可视化工具。它通过绘制不同阈值下的精确率和召回率之间的关系图来帮助我们了解模型在不同阈值下的表现。
精确率(Precision)是指被正确预测为正例的样本数占所有预测为正例的样本数的比例。召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。
绘制精确率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率和召回率。
将每个阈值下的精确率和召回率绘制在同一个图表上,形成精确率曲线。
根据精确率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察精确率曲线,我们可以根据需求确定最佳的阈值,以平衡精确率和召回率。较高的精确率意味着较少的误报,而较高的召回率则表示较少的漏报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
精确率曲线通常与召回率曲线(Recall Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。
【Recall曲线】
召回率曲线(Recall Curve)是一种用于评估二分类模型在不同阈值下的召回率性能的可视化工具。它通过绘制不同阈值下的召回率和对应的精确率之间的关系图来帮助我们了解模型在不同阈值下的表现。
召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。召回率也被称为灵敏度(Sensitivity)或真正例率(True Positive Rate)。
绘制召回率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的召回率和对应的精确率。
将每个阈值下的召回率和精确率绘制在同一个图表上,形成召回率曲线。
根据召回率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察召回率曲线,我们可以根据需求确定最佳的阈值,以平衡召回率和精确率。较高的召回率表示较少的漏报,而较高的精确率意味着较少的误报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
召回率曲线通常与精确率曲线(Precision Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。

【PR曲线】
精确率-召回率曲线(Precision-Recall Curve)是一种用于评估二分类模型性能的可视化工具。它通过绘制不同阈值下的精确率(Precision)和召回率(Recall)之间的关系图来帮助我们了解模型在不同阈值下的表现。
精确率是指被正确预测为正例的样本数占所有预测为正例的样本数的比例。召回率是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。
绘制精确率-召回率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率和召回率。
将每个阈值下的精确率和召回率绘制在同一个图表上,形成精确率-召回率曲线。
根据曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
精确率-召回率曲线提供了更全面的模型性能分析,特别适用于处理不平衡数据集和关注正例预测的场景。曲线下面积(Area Under the Curve, AUC)可以作为评估模型性能的指标,AUC值越高表示模型的性能越好。
通过观察精确率-召回率曲线,我们可以根据需求选择合适的阈值来权衡精确率和召回率之间的平衡点。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。 

感兴趣的话可以自行动手实践尝试下!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/166443.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Django使用】10大章31模块md文档,第5篇:Django模板和数据库使用

当你考虑开发现代化、高效且可扩展的网站和Web应用时,Django是一个强大的选择。Django是一个流行的开源Python Web框架,它提供了一个坚实的基础,帮助开发者快速构建功能丰富且高度定制的Web应用 全套Django笔记直接地址: 请移步这…

外汇天眼:多名投资者账户被恶意清空,远离volofinance!

最近,外汇平台volofinance因有多名投资者投诉,“荣幸”成为外汇天眼黑平台榜单中的一员,那么volofinance到底做了什么导致投资者前来投诉曝光呢? 起底volofinace 在网络搜索中,关于volofinance的信息少之又少&#xf…

成为AI产品经理——模型评估指标

目录 一、模型评估分类 1.在线评估 2.离线评估 二、离线模型评估 1.特征评估 ① 特征自身稳定性 ② 特征来源稳定性 ③ 特征成本 2.模型评估 ① 统计性评估 覆盖度 最大值、最小值 分布形态 ② 模型性能指标 分类问题 回归问题 ③ 模型的稳定性 模型评估指标分…

配置mvn打包参数,不同环境使用不同的配置文件

方法一: 首先在/resource目录下创建各自环境的配置 要在不同的环境中使用不同的配置文件进行Maven打包,可以使用Maven的profiles特性和资源过滤功能。下面是配置Maven打包参数的步骤: 在项目的pom.xml文件中,添加profiles配置…

第一个Mybatis项目

(一)为什么要用Mybatis? (1)Mybatis对比JDBC而言,sql(单独写在xml的配置文件中)和java编码分开,功能边界清晰,一个专注业务,一个专注数据。 (2&…

【C++】:多态

朋友们、伙计们,我们又见面了,本期来给大家解读一下有关多态的知识点,如果看完之后对你有一定的启发,那么请留下你的三连,祝大家心想事成! C 语 言 专 栏:C语言:从入门到精通 数据结…

Linux(CentOS7)上安装mysql

在CentOS中默认安装有MariaDB(MySQL的一个分支),可先移除/卸载MariaDB。 yum remove mariadb // 查看是否存在mariadb rpm -qa|grep -i mariadb // 卸载 mariadb rpm -e --nodeps rpm -qa|grep mariadb yum安装 下载rpm // 5.6版本 wge…

XML映射文件

<?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE mapperPUBLIC "-//mybatis.org//DTD Mapper 3.0//EN""http://mybatis.org/dtd/mybatis-3-mapper.dtd"> <mapper namespace"org.mybatis.example.BlogMapper&q…

conan 入门(三十二):package_info中配置禁用CMakeDeps生成使用项目自己生成的config.cmake

conanfile.py中定义的package_info()方法用于向package的调用者(conumer)提供包库名&#xff0c;编译/连接选项&#xff0c;文件夹等等信息&#xff0c;有了这些信息构建工具的generator就可以根据它们生成对应的文件&#xff0c;用于调用者引用package. 比如基于cmake的CMakeD…

安全地公网访问树莓派等设备的服务 内网穿透--frp 23年11月方法

如果想要树莓派可以被公网访问&#xff0c;可以选择直接网上搜内网穿透提供商&#xff0c;一个月大概10块钱&#xff0c;也有免费的&#xff0c;但是免费的速度就不要希望很好了。 也可以选择接下来介绍的frp&#xff0c;这种方式不需要付费&#xff0c;但是需要你有一台有着公…

vue3自定义拖拽指令

<template><div v-move class"box"></div> </template><script setup lang"ts"> import { Directive } from vue const vMove:Directive (el:HTMLElement) >{const mousedown (e:MouseEvent) >{// 鼠标按下const s…

【Golang】解决使用interface{}解析json数字会变成科学计数法的问题

在使用解析json结构体的时候&#xff0c;使用interface{}接数字会发现变成了科学计数法格式的数字&#xff0c;不符合实际场景的使用要求。 举例代码如下&#xff1a; type JsonUnmStruct struct {Id interface{} json:"id"Name string json:"name"…

Linux 的性能调优的思路

Linux操作系统是一个开源产品&#xff0c;也是一个开源软件的实践和应用平台&#xff0c;在这个平台下有无数的开源软件支撑&#xff0c;我们常见的apache、tomcat、mysql等。 开源软件的最大理念是自由、开放&#xff0c;那么Linux作为一个开源平台&#xff0c;最终要实现的是…

uniApp微信支付实现

后端&#xff1a;小程序下单 - 小程序支付 | 微信支付商户文档中心 服务端需要请求&#xff1a;https://api.mch.weixin.qq.com该地址获取微信支付Api接口需要的参数。 服务端请求接口需要的Body参数&#xff1a; 客户端&#xff08;前端&#xff09;需要调用&#xff1a;wx.…

12V降3.3V100mA稳压芯片WT7133

12V降3.3V100mA稳压芯片WT7133 WT71XX系列是一款采用CMOS工艺实现的三端高输入电压、低压差、小输出电流电压稳压器。 它的输出电流可达到100mA&#xff0c;输入电压可达到18V。其固定输出电压的范围是2.5V&#xff5e;8.0V&#xff0c;用户 也可通过外围应用电路来实现可变电压…

加载minio中存储的静态文件html,不显示样式与js

问题描述:点击链接获取的就是纯静态文件,但是通过浏览器可以看到明明加载了css文件与js文件 原因:仔细看你会发现加载css文件显示的contentType:text/html文件,原来是minio上传文件时将所有文件的contentType设置成了text/html 要在上传时指定文件,根据文章的类型指定的Conten…

win10开机黑屏只有鼠标?这份指南帮你轻松解决!

win10是一个出色的操作系统&#xff0c;但有时用户可能会遇到开机后只有鼠标显示在屏幕上的问题&#xff0c;这种情况可能会让人感到困惑和沮丧。在本文中&#xff0c;我们将介绍三种解决win10开机黑屏只有鼠标的方法&#xff0c;以帮助您快速恢复正常的桌面环境。 方法1&#…

Ubuntu18.4中安装wkhtmltopdf + Odoo16配置【二】

deepin Linux 安装wkhtmltopdf 1、先从官网的链接里下载linux对应的包 wkhtmltopdf/wkhtmltopdf 下载需要的版本&#xff0c;推荐版本&#xff0c;新测有效&#xff1a; wkhtmltox-0.12.4_linux-generic-amd64.tar.xz 2、解压下载的文件 解压后会有一个wkhtmltox文件夹 3…

CTA-GAN:基于生成对抗性网络的主动脉和颈动脉非集中CT血管造影 CT到增强CT的合成技术

Generative Adversarial Network–based Noncontrast CT Angiography for Aorta and Carotid Arteries 基于生成对抗性网络的主动脉和颈动脉非集中CT血管造影背景贡献实验方法损失函数Thinking 基于生成对抗性网络的主动脉和颈动脉非集中CT血管造影 https://github.com/ying-f…

可自行DIY单TYPE-C接口设备实现DRP+OTG功能芯片

随着USB-C接口的普及&#xff0c;欧盟的法律法规强制越来越多的设备开始采用这种接口。由于 USB-C接口的高效性和便携性&#xff0c;使各种设备之间的连接和数据传输变得非常方便快捷&#xff0c;它们不仅提供了强大的功能&#xff0c;还为我们的日常生活和工作带来了极大的便利…