DBNet详解及训练ICDAR2015数据集

论文地址:https://arxiv.org/pdf/1911.08947.pdf

开源代码pytorch版本:GitHub - WenmuZhou/DBNet.pytorch: A pytorch re-implementation of Real-time Scene Text Detection with Differentiable Binarization

前言

在这篇论文之前,文字检测算法主要分为两类:基于回归的方法和基于分割的方法。基于分割的方法通常涉及以下流程,如下图蓝色箭头所示:首先,通过网络输出图像的文本分割结果,即概率图,其中每个像素表示是否属于正样本的概率。然后,通过使用预设的阈值将分割结果图转换为二值图。最后,通过一些聚合操作,例如连通域分析,将像素级的结果转换为最终的文本检测结果。然而,由于涉及使用阈值来判定前景和背景的不可微分操作,因此这一部分流程无法被直接放入网络中进行训练。所以本文引入了一种新的方法。具体而言,通过学习阈值映射(threshmap)并采用可微分的操作,将阈值的转换过程嵌入到网络中进行训练。这一创新的流程如下图中红色箭头所示,通过可微分的操作来处理阈值的学习,使得整个流程可以在神经网络的训练中进行端到端的优化。通过这种方式,文本检测模型能够自适应地学习阈值,更有效地捕捉文本的分割信息,提高了检测性能。这一方法有助于简化原有基于分割方法的后处理流程,同时使整个模型更具可训练性。

网络结构

其实从下图的网络结构中不难看出,相比较于PSENet,多了一条threshold map分支罢了,该分支的主要目的是和分割图联合得到更接近二值化的二值图,属于辅助分支。

整个网络结构流程:

图像输入特征提取主干: 使用图像输入,经过一个特征提取的主干网络,该网络负责从输入图像中提取高层次的语义特征。这可以是一个卷积神经网络(CNN)的主要部分,如ResNet或其他先进的架构。

特征金字塔上采样和级联: 从特征提取主干获得的特征被送入特征金字塔。在特征金字塔中,通过上采样将不同尺寸的特征图调整到相同的尺寸,并将它们级联在一起,形成一个具有丰富多尺度信息的特征F。这有助于模型对不同大小和尺度的目标进行有效的检测和分割。

预测概率图和阈值图: 利用级联的特征F,进行概率图(probability map P)和阈值图(threshold map T)的预测。概率图通常表示每个像素属于某个类别(在这里可能是目标文本与非文本的概率),而阈值图则用于指导后续的二值化操作。这一步的目的是产生用于后续计算的中间结果。

计算近似二值图: 利用概率图P和阈值图T,通过一定的计算过程(可能是使用阈值或其他运算),得到一个近似的二值图B。这个近似二值图用于最终的文本检测,其中文本区域被二值化为前景,而非文本区域为背景。

在训练过程中,该模型通过使用相同的监督信号对概率图 P 和近似二值图 B 进行监督训练,其中概率图表示文本区域的概率,而近似二值图是文本二值化结果。在推理阶段,只需使用概率图 P 或者近似二值图B 中的任一即可获取文本检测结果,无需依赖额外的阈值图。这种设计简化了推理流程,提高了模型的实际应用效率。

模型的输出

Probability Map(概率图): 这是一个大小为w×h×1 的张量,其中 w 和 ℎ分别表示图像的宽度和高度。概率图的每个像素表示相应位置是否为文本的概率。对于二进制文本检测任务,概率图的值通常在 0 到 1 之间,表示每个像素点属于文本的概率,1 表示高置信度是文本,0 表示低置信度是文本。

Threshold Map(阈值图): 阈值图也是一个大小为 w×h×1 的张量,其中每个像素点包含一个阈值。这些阈值用于二值化概率图,将其转换为最终的二值图。阈值图的每个值表示相应位置的二值化操作的阈值。

Binary Map(二值图): 由概率图和阈值图计算得到,也是一个大小为 w×h×1 的张量。它表示最终的文本检测结果,其中每个像素点被二值化为前景(文本)或背景(非文本)。这里提到使用了 "DB 公式" 来计算二值图,而 DB(Differentiable Binarization)通常是一个近似二值化的函数,通过可微分的操作来实现对阈值的学习和调整。

DB公式

标准二值化

一般使用分割网络(segmentation network)产生的概率图(probability map P),将P转化为一个二值图P,当像素为1的时候,认定其为有效的文本区域。i和j代表了坐标点的坐标,t是预定义的阈值

可微二值化(differentiable Binarization)

可微二值化的公式如下,其实就是带一个系数的sigmoid,其中其中T是阈值图,k取50

从图像上不难看出,二值化和标准二值化很相似,且可微分,因此可以和分割网络一起联合优化

从(b)(c)图我们不难看出通过增加参数 K,可以在模型的训练过程中加速对正确预测区域和错误预测区域的学习,以更快地收敛到最优解。这样的调整可以在某些情况下提高模型的训练效率和性能。原图,gt图,threshold map图如下所示

模型训练

自动下载的预训练模型下载地址:/home/xuzhen/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth(我看了代码,他是判断有没有预训练模型没有的话才下载)

这个源代码在配置文件中加载的是train.txt和test.txt,所以我写了一个脚本,根据img文件夹和gt文件夹自动生成这两个文件的脚本

import osdef create_gt_file(img_dir, gt_dir, output_file_path):# 检查文件夹是否存在if not os.path.exists(img_dir) or not os.path.exists(gt_dir):print("Error: One or both folders do not exist.")returnimg_paths = []  gt_paths = []   # 循环读取文件夹1中的文件名for filename in os.listdir(img_dir):img_path = os.path.join(img_dir,filename)img_paths.append(img_path)# 去掉后缀并在前面加上 "gt_"gt_path = os.path.join(gt_dir, "gt_" + os.path.splitext(filename)[0] + ".txt")gt_paths.append(gt_path)# 写入文件with open(output_file_path, 'w') as output_file:# 将 img_paths 和 gt_paths 写入文件for img_path, gt_path in zip(img_paths, gt_paths):output_file.write(f"{img_path}\t{gt_path}\n")print(f"{img_path}\t{gt_path}Strings written to {output_file_path}")# 主函数
def main():img_dir = "/data2/xuzhen8/yzh/datasets/ICDAR2015/test_images"gt_dir = "/data2/xuzhen8/yzh/datasets/ICDAR2015/testing_localization_transcription_gt"output_file_path = "/data2/xuzhen8/yzh/projects/DBNet.pytorch/datasets/test.txt"create_gt_file(img_dir, gt_dir, output_file_path)if __name__ == "__main__":main()

每一轮训练都会打印信息,我想对这个打印信息说明一下,以便后面复习

FPS(Frames Per Second): 99.37

表示每秒处理的图像帧数。在这个上下文中,表示模型在测试阶段的推断速度。这是通过测量模型在测试集上处理图像的速度来得到的,其单位是帧数/秒。

test: recall: 0.031477, precision: 0.596330, f1: 0.059798

提供了模型在测试集上的性能指标。在这里,包括了召回率(recall)、精确度(precision)和 F1 分数(f1-score)。这些指标用于衡量模型在检测任务中的性能,其中:

召回率表示正确检测到的正类别样本占所有实际正类别样本的比例。

精确度表示模型正确检测的正类别样本占所有模型检测为正类别的样本的比例。

F1 分数是召回率和精确度的调和平均数,综合考虑了这两个指标。

current best, recall: 0.101695, precision: 0.726644, hmean: 0.178420, train_loss: 1.706732, best_model_epoch: 5.000000

提供了模型在测试集上的当前最佳性能以及训练期间的一些指标。其中:

recall、precision、hmean 是测试集上的召回率、精确度和 F1 分数。

train_loss 表示模型在训练集上的损失值,用于衡量训练过程中模型的拟合情况。

best_model_epoch 表示在训练过程中取得最佳性能的模型所对应的训练轮次。

Saving checkpoint:DBNet.pytorch/output/DBNet_resnet18_FPN_DBHead/checkpoint/

model_latest.pth

表示当前训练轮次的模型参数被保存到了指定路径下的 model_latest.pth 文件中。这通常发生在模型在测试集上取得了更好性能后,保存了当前状态的模型参数,以备将来使用或继续训练。

 

小辉问:能不能举个例子说明一下召回率、精确度、F1 分数。以便更好的理解

小G答:假设有一个二分类任务,目标是检测患有某种疾病的患者。我们的模型对每个样本都进行预测,可以分为以下四种情况:

True Positive (TP): 模型正确地预测了患有疾病的患者。

True Negative (TN): 模型正确地预测了没有患疾病的健康人。

False Positive (FP): 模型错误地预测了没有患疾病的健康人为患病。

False Negative (FN): 模型错误地预测了患有疾病的患者为健康人。

现在,我们可以使用这些概念来解释这些指标:

召回率(Recall):

召回率衡量了模型在所有实际患有疾病的样本中,有多少被成功地检测到。计算公式:

例如,如果总共有 100 名患有疾病的患者,而模型成功地检测到其中的 80 人,则召回率为 80/80+20 =0.8 或 80%。

精确度(Precision):

精确度衡量了模型在所有预测为患有疾病的样本中,有多少实际上是真正患有疾病的人。计算公式:

例如,如果模型预测了 90 个人患有疾病,而其中有 80 人确实是患有疾病的,则精确度为 80/80+10=0.888 或 88.8%。

F1 分数:

F1 分数是召回率和精确度的调和平均数,它综合考虑了两者的性能。计算公式:

F1 分数的取值范围在 [0,1],越接近 1 表示模型在召回率和精确度之间取得了更好的平衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/669185.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序的图片色彩分析,窃取网络图片的主色调

1、安装 Mini App Color Thief 包 包括下载包,简单使用都有,之前写了,这里就不写了 网址:微信小程序的图片色彩分析,窃取主色调,调色板-CSDN博客 2、 问题和解决方案 问题:由于我们的窃取图片的…

基于Langchain-Chatchat + chatGLM3 轻松在本地部署一个知识库

前面三篇分别讲解了如何在本地环境部署大模型,那么今天简单的通过 Langchain-Chatchat 和 chatGLM3结合在本地环境搭建一套属于自己的大模型知识库。 往期llm系列文章 基于MacBook Pro M1芯片运行chatglm2-6b大模型如何在本地部署chatGLM3基于ChatGLM.cpp实现低成…

论文阅读-通过云特征增强的深度学习预测云工作负载转折点

论文名称:Cloud Workload Turning Points Prediction via Cloud Feature-Enhanced Deep Learning 摘要 云工作负载转折点要么是代表工作负载压力的局部峰值点,要么是代表资源浪费的局部谷值点。预测这些关键点对于向系统管理者发出警告、采取预防措施以…

企业动态 | UFAPKU“金融科技”沙龙走进同创永益——前沿技术在金融科技领域的应用

金融科技作为金融发展的驱动力量,对金融行业有着深远的影响。金融行业通过技术创新和数字化转型,极大地提高了金融服务和产品的效率和便捷性。1月21日,UFAPKU“金融科技”第二期沙龙在北大校友企业同创永益北京总部举办,数十位来自…

请问CTF是什么?请介绍一下关于隐水印的知识特点技术原理应用领域技术挑战

目录 请问CTF是什么? 请介绍一下关于隐水印的知识 特点 技术原理 应用领域 技术挑战 请问CTF是什么? CTF(Capture The Flag,夺旗比赛)是一种信息安全竞赛,常见于计算机安全领域。这种比赛模拟各种信…

fastjson 导致的OOM

fastjson 导致的OOM 示例代码 public static void main(String[] args) throws Exception {try {List<Integer> list JSONObject.parseArray("[2023,2024", Integer.class);}catch (Exception e){System.err.println("error");}System.out.println…

一文搞懂 springboot 如何融合数据源

1、简介 springboot 支持关系型数据库的相关组件进行配置&#xff0c;包括数据源、连接池、事务管理器等的自动配置。降低了数据库使用的难度&#xff0c;除了 mysql 还支持 Derby、H2等嵌入式数据库的自动配置&#xff0c;MongoDB、Redis、elasticsearch等常用的 NoSQL 的数据…

BGP邻居故障检测

第一种情况:如果AR2和AR4采用直连建立邻居,则排查步骤如下: 1)在AR2和AR4上使用ping x.x.x.x命令检查AR2和AR4用于建立EBGP邻居关系的直连地址连通性是否正常。如果不能ping通。则需要使用二分法从网络层向下层逐层进行排查,首先检查接口地址及路由的可达性,修改完成后,如…

Codeforces Round 914 (Div. 2)(D1/D2)--ST表

Codeforces Round 914 (Div. 2)(D1/D2)–ST表 D1. Set To Max (Easy Version) 题意&#xff1a; 给出长度为n的数组a和b&#xff0c;可以对a进行任意次数操作&#xff0c;操作方式为选择任意区间将区间内值全部变成该区间的最大值&#xff0c; 是否有可能使得数组a等于数组b…

C# CAD界面-自定义窗体(三)

运行环境 vs2022 c# cad2016 调试成功 一、引用 二、开发代码进行详细的说明 初始化与获取AutoCAD核心对象&#xff1a; Database db HostApplicationServices.WorkingDatabase;&#xff1a;这行代码获取当前工作中的AutoCAD数据库对象。在AutoCAD中&#xff0c;所有图形数…

《短链接--阿丹》--技术选型与架构分析

整个短链接专栏会持续更新。有兴趣的可以关注一下我的这个专栏。 《短链接--搭建解析》--立项+需求分析文档-CSDN博客 阿丹: 其实整套项目中的重点,根据上面的简单需求分析来看,整体的项目难题有两点。 1、快速的批量生成短链,并找到对应的存储。 并且要保持唯一性质。…

【Linux驱动】块设备驱动(二)—— 块设备读写(使用请求队列)

块设备的操作函数并没有类似于字符驱动中的read 和write函数&#xff0c;要实现读写操作&#xff0c;只能在请求处理函数中实现。这就分为两种&#xff0c;是否要使用请求队列&#xff0c;请求队列的主要作用是管理和调度IO请求。在以下情况中&#xff0c;一般需要用到请求队队…

跑路页面HTML源码

简单的HTMLJSCSS&#xff0c;记事本修改内容&#xff0c;喜欢的朋友可以下载 https://download.csdn.net/download/huayula/88811984

vivado RTL综合中的多线程

RTL综合中的多线程 在多处理器系统上&#xff0c;RTL合成默认情况下利用多个CPU核心&#xff08;最多四个&#xff09;来加快编译时间。同时运行的线程的最大数量会有所不同&#xff0c;具体取决于处理器的数量可在系统、操作系统和流程阶段使用&#xff08;请参阅Vivado Desi…

HTTP1.1、HTTP2、HTTP3

HTTP1.1 HTTP/1.1 相比 HTTP/1.0 性能上的改进&#xff1a; 使用长连接的方式改善了 HTTP/1.0 短连接造成的性能开销。支持管道&#xff08;pipeline&#xff09;网络传输&#xff0c;只要第一个请求发出去了&#xff0c;不必等其回来&#xff0c;就可以发第二个请求出去&…

在VM虚拟机上搭建MariaDB数据库服务器

例题&#xff1a;搭建MariaDB数据库服务器&#xff0c;并实现主主复制。 1.在二台服务器中分别MariaDB安装。 2.在二台服务器中分别配置my.cnf文件&#xff0c;开启log_bin。 3.在二台服务器中分别创建专用于数据库同步的用户replication_user&#xff0c;并授权SLAVE。&#x…

Matplotlib绘制炫酷柱状图的艺术与技巧【第60篇—python:Matplotlib绘制柱状图】

文章目录 Matplotlib绘制炫酷柱状图的艺术与技巧1. 簇状柱状图2. 堆积柱状图3. 横向柱状图4. 百分比柱状图5. 3D柱状图6. 堆积横向柱状图7. 多系列百分比柱状图8. 3D堆积柱状图9. 带有误差线的柱状图10. 分组百分比柱状图11. 水平堆积柱状图12. 多面板柱状图13. 自定义颜色和样…

c#string方法对比

字符串的截取匹配操作在开发中非常常见&#xff0c;比如下面这个示例&#xff1a;我要匹配查找出来字符串数组中以“abc”开头的字符串并打印&#xff0c;我下面分别用了两种方式实现&#xff0c;代码如下&#xff1a; using System; namespace ConsoleApp23{ class Progra…

aidl复杂流程封装

1 aidl相关困扰点 1 制作步骤复杂&#xff0c;先定义然后编译&#xff0c;然后复制&#xff0c;两边都要一一对应 2 增加回调&#xff0c;自定义对象流程更加麻烦&#xff0c;还要处理对象数据流是 in 还是out。 3 一方异常怎么办&#xff0c;虽然服务端可以用 RemoteCallbackL…

Retrofit源码分析及理解

参考文档&#xff1a; 12W字&#xff1b;2022最新Android11位大厂面试专题&#xff08;一&#xff09; - 掘金 Retrofit 版本号&#xff1a;2.9.0 Retrofit简单来说&#xff0c;就是对OkHttp上层进行了封装&#xff0c;已达到用户方便使用管理网络请求的目的。 Retrofit内部有…