论文快过(图像配准|Coarse_LoFTR_TRT)|适用于移动端的LoFTR算法的改进分析 1060显卡上45fps

项目地址:https://github.com/Kolkir/Coarse_LoFTR_TRT
创建时间:2022年
相关训练数据:BlendedMVS
在这里插入图片描述
LoFTR [19]是一种有效的深度学习方法,可以在图像对上寻找合适的局部特征匹配。本文报道了该方法在低计算性能和有限内存条件下的设备上的优化工作。原来的LoFTR方法是基于一个ResNet [6]backbone和两个基于线性transformer[22]架构的模块。在本研究中,只剩下粗匹配块,参数的数量显著减少,并使用知识蒸馏技术对网络进行了训练。对比结果表明,在粗匹配块中,尽管模型大小显著减少,但该方法仍可以获得适当的特征检测精度。此外,本文还展示了使模型与NVIDIA TensorRT运行时兼容所需的额外步骤,并展示了一种优化针对低端gpu的训练方法的方法。

简化后的算法运行速度在1060显卡上提升了45倍,针对640×480图像,fps可达45

1、改进思考

1.1 算法背景

为了解决高计算复杂度的问题,对transformer架构的各种修改已经被开发出来。LoFTR方法采用线性变换[22]方法,该方法提出通过将注意层中使用的指数核替换为𝜑(·)=𝑒𝑙𝑢(·)+1.的替代核𝑠𝑖𝑚(𝑄,𝐾)=𝜑(𝑄)* 𝜑 ( K ) T (K)^T (K)T从而降低计算复杂度到𝑂(𝑁)该方法对计算机视觉任务具有良好的计算性能的提高和内存消耗的降低。这很重要,因为这种类型的任务的序列长度等于输入图像中的像素数。使用线性transformer允许对640x480张的图像进行特征匹配,在高端gpu上具有可接受的性能。然而,该体系结构中的更改仍然不足使transformer可以运行在低端gpu上。

1.2 优化方向

使复杂模型适应低端器件[5]要求的工程方法主要有:量化、剪枝[9]和知识蒸馏。

量化是用于计算和权重存储的数据类型的位宽降低。通常浮点计算转换为16位浮点或8位整数类型。为了达到与原始模型相媲美的精度,这种方法通常需要一个特殊的训练过程,考虑到缩小或缩小后的附加模型校准。量化通常在消费者级或嵌入式gpu上不可用,而且它的实现只能在高端gpu中可用。然而,对于基于cpu的设备,该方法是可用的,可以提供良好的效果。

剪枝是一种去除网络参数的方法,它对结果的精度没有太大的贡献。通常一个合适的剪枝条件是权重接近于零。由此得到的模型可能需要更少的内存,而且在推理方面可能更有效。有许多剪枝类型,但可以区分以下两种主要类型:结构化剪枝,当对称的权值块被删除时,例如层,和非结构化剪枝,当被删除的块可能是不同的形状时。由于这种方法改变了模型架构,因此通常需要进行手动调整来恢复正常的模型工作。结构化方法可能更可取,因为它对全局架构进行的更改更少,而且恢复模型操作更容易,甚至可能不必要。然而,流行的深度学习框架通常会实现非结构化的方法。在复杂模型中应用非结构化处理后适应网络操作可能是一项重要的任务,需要很多时间来解决,而且由于该方法不能保证一个稳定的结果,因此应用它并不总是合适的。

知识蒸馏是在教师的帮助下训练模型的一种方法。教师可以是具有相同架构但具有更多参数的网络,也可以是具有其他架构的网络。大多数训练是使用复杂的损失函数,转移教师的知识。转移到学生模型中的知识元素可以是教师网络中某些层的输出值,例如,在分类中可能是softmax之前的输出。也可以使用教师网络[2]的内部层输出值。知识蒸馏在保持所需的精度的同时,显示了良好的结果,但没有标准的方法来组织这样的过程。而成功则取决于正确的知识转移技术、精心选择的损失函数和学生模型架构。

如上所示,没有单一的方法来优化低端设备的深度学习模型。因此,通常会针对特定的架构开发专门的解决方案。本文提出了一种针对LoFTR特征匹配方法的优化方法。

1.3 本文方案

该方法的主要思想是显著减少模型参数的数量和从原始模型中的知识转移。
决定只保留一个transformer block用于粗特征匹配,尽管原始模型包含第二个transformer模块用于细匹配同时,在所有模型块中进行了手动迭代选择较少的层网络结构简化。设计了知识蒸馏损失函数,并使用了一个较小的训练数据集设计知识蒸馏方案。然而,地面-真实的特征点的匹配也可以使用深度图来确定。训练过程是开发使用自动混合精度(AMP)技术和梯度积累方法来节省内存和加快计算。

源代码被改编为以NVIDIA TensorRT [13]引擎格式编译。选择工作内存大小为2Gb的NVIDIA Jetson Nano [12]作为目标设备。并选择了基于英特尔i5处理器和Nvidia GTX 1060 6Gb GPU的桌面机作为训练平台。

2、模型改进

2.1 适配性修改

最初的LoFTR模型是用Python编写的,使用PyTorch作为深度学习框架。为了创建TensorTR模型,有两种可能性,一种是使用torch-TensorRT[14]编译器,第二种是将模型转换为ONNX [1]格式,然后使用NVIDIA TensorTR SDK编译它。由于目标平台的资源有限,无法应用第一个选项,因为使用Torch-TensorRT编译成TensorTR格式意味着在目标设备上运行它以进行实时优化。实验发现,编译ONNX需要的资源更少,并且在目标设备上是可能的,因此选择了第二个选项。

然而,einsum操作在onnx中并不支持。
在这里插入图片描述
所有将运算方式修改为以下,使onnx与tensorRT都支持。
在这里插入图片描述

2.2 结构优化

为了目标设备上实现可接受的性能,即选择块中的层的数量和尺寸。为此目的,我们开发了一个在实时网络摄像头图像上搜索特征匹配的演示应用程序。性能是通过呈现相应匹配时的FPS数量来估计的。然后,在此应用程序的帮助下,迭代地选择了表1中所示的模型配置。
在这里插入图片描述

原始模型的作者报告说,完整模型在RTX 2080Ti上处理116 ms处理一对640×480图像,约8 FPS [19]。简化后的算法运行速度在1060显卡上提升了45倍。
在这里插入图片描述
表3显示了参数数量的变化。从表中可以看到,原始模型的尺寸显著减少,以便在目标设备上实现可接受的性能。
在这里插入图片描述

2.3 训练设置

针对低性能硬件的局限性,对知识精馏训练过程进行了优化。为了加速梯度计算和减少内存消耗,我们使用了自动混合精度(AMP)技术,因为它的实现在PyTorch深度学习框架中可用。该技术的本质是,梯度计算所需的一些操作使用浮点32,而另一部分使用浮点16种数据类型。例如,卷积运算和线性层相关的矩阵计算使用float16计算速度更快。而其他操作,如减法,需要使用一个浮动32范围。这项技术使我们能够为模型训练中涉及的所有操作自动选择适当的数据类型。它的使用可以显著减少模型ResNet+FPN头的内存消耗。然而,AMP技术存在较小梯度值的数值计算问题。因此,为了稳定损失函数,增加了放大因子。

尽管使用了AMP,但在GTX 1060上进行训练也只能是支持到batch为4的640x480的图像。因此,为了增加batch的大小,我们采用了梯度积累的方法。这意味着大batchsize被分为𝑛系列的小batchsize。对于每个系列,进行正向和反向循环,不清除产生的梯度值,而是求和。其中,𝑛=𝐵𝑖𝑔𝐵𝑎𝑡𝑐ℎ𝑆𝑖𝑧𝑒、𝑆𝑚𝑎𝑙𝑙𝐵𝑎𝑡𝑐ℎ𝑆𝑖𝑧𝑒。在每次迭代中,损失函数值乘以比例因子1/𝑛。只有经过所有𝑛迭代后才更新网络参数,然后将梯度归零。因此,利用该技术模拟了大批量的训练。在这项工作中,虚拟批处理大小等于32。尽管,在现实中,硬件处理了8批梯度累计,每批的尺寸为4。梯度积累技术并没有实现实际大批量使用的精确对应关系,因此这两种方法的损失和梯度值将是不同的。

此外,我们还注意到,应用学习速率调度器可以显著加快训练过程。本研究采用了具有标准参数的AdamW [10]优化算法。初始学习速率值为10−3,每15个epoch乘以10−3。

每个epoch都从原始数据集中随机选择大小为5000对图像。

2.4 训练效果

图1显示了有教师和没有教师的训练的损失函数值。这张图清楚地表明,当与教师一起进行训练时,绝对损失函数值明显更小,学习过程本身更稳定。
在这里插入图片描述
图2,它显示了平均绝对误差(MAE)与训练持续时间的依赖关系。它显示了预测的特征匹配分数与地面真实值之间的平均差异。我们可以看到,当与老师一起进行训练时,MAE值远远接近于零。我们可以假设,在没有教师的情况下训练一个较小的网络会使它对其结果缺乏信心。然而,与此同时,这个图1显示了所选择的模型架构能够在没有老师的情况下学习,但可能需要更长的时间来获得可比的结果,并且需要更低的阈值来确定最重要的匹配。
在这里插入图片描述
图3显示了在数据集图像上的模型结果的示例。白点表示原模型作为教师使用的粗LoFTR模块的匹配结果。黑点表示较小模型的结果。从实验结果中可以清楚地看出,较小的模型比教师模型更关注图像的不同部分。最可能的原因是头层数量较少,transformer参数不同,使得模型强调更明显的特征点。也可以注意到较小模型的特征匹配中存在错误,尽管通常特征匹配是相当准确的。
在这里插入图片描述
室外数据配准效果
在这里插入图片描述

3、代码运行

打开 https://github.com/Kolkir/Coarse_LoFTR_TRT,即可下载项目
在这里插入图片描述

3.1 前置修改

如果电脑没有摄像头,则需要进行下列额外代码修改

修改一: webcam.py中默认参数camid,类型修改为str,默认值修改为自己准备好的视频文件

def main():parser = argparse.ArgumentParser(description='LoFTR demo.')parser.add_argument('--weights', type=str, default='weights/outdoor_ds.ckpt',help='Path to network weights.')# parser.add_argument('--camid', type=int, default=0,#                     help='OpenCV webcam video capture ID, usually 0 or 1.')parser.add_argument('--camid', type=str, default=r"C:\Users\Administrator\Videos\风景视频素材分享_202477135455.mp4",help='OpenCV webcam video capture ID, usually 0 or 1.')

修改二:camera.py中的代码修改为以下,用于支持读取视频文件

import cv2
from threading import Threadclass Camera(object):def __init__(self, index):self.index=indexif isinstance(self.index,int):#加载摄像头视频流self.cap = cv2.VideoCapture(self.index, cv2.CAP_V4L2)else:#加载视频self.cap = cv2.VideoCapture(self.index)if not self.cap.isOpened():print('Failed to open camera {0}'.format(index))exit(-1)# self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)# self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)self.thread = Thread(target=self.update, args=())self.thread.daemon = Trueself.thread.start()self.status = Falseself.frame = Nonedef update(self):while True:try:if self.cap.isOpened():(self.status, self.frame) = self.cap.read()if not self.status:if isinstance(self.index,int):#加载摄像头视频流self.cap = cv2.VideoCapture(self.index, cv2.CAP_V4L2)else:#加载视频self.cap = cv2.VideoCapture(self.index)else:breakexcept cv2.error as e:print(e)breakdef get_frame(self):return self.frame, self.statusdef close(self):self.cap.release()self.thread.join()

3.2 运行效果

然后运行webcam.py,可以发现fps为25左右,此时硬件环境为win10笔记本、1660显卡,26ms即可处理完一个640*480的图片。但整体fps稳定在16~26左右。
在这里插入图片描述
再次加速,将推理时的图像分辨率修改为320x240 ,即将webcam.py中的 img_size 设置(320, 240),loftr\utils\cvpr_ds_config.py中对应的设置。发现速度没有显著提升,但整体fps稳定在22~28左右。

_CN.INPUT_WIDTH = 320
_CN.INPUT_HEIGHT = 240

在这里插入图片描述
onnx运行效果如下,整体fps稳定在20左右
在这里插入图片描述
将模型配置loftr\utils\cvpr_ds_config.py 中的尺寸修改如下,然后重新运行export_onnx.py,导出模型,再基于webcam.py运行onnx模型,可以发现fps高达40以上。

_CN.INPUT_WIDTH = 320
_CN.INPUT_HEIGHT = 320

在这里插入图片描述

3.3 图像配准

使用Coarse_LoFTR_TRT进行图像配准可以参考
https://blog.csdn.net/a486259/article/details/140241276 中章节5的操作。操作前最好先修改 loftr\utils\cvpr_ds_config.py 的尺寸为320,具体修改如下,然后重新运行export_onnx.py,导出模型。

_CN.INPUT_WIDTH = 320
_CN.INPUT_HEIGHT = 320

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/50632.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PyTorch】基于LSTM网络的气温预测模型实现

假设CSV文件名为temperature_data.csv,其前五行和标题如下: 这里,我们只使用Temperature列进行单步预测。以下是整合的代码示例: import pandas as pd import numpy as np import torch import torch.nn as nn import torch.op…

RocketMQ消息短暂而又精彩的一生(荣耀典藏版)

目录 前言 一、核心概念 二、消息诞生与发送 2.1.路由表 2.2.队列的选择 2.3.其它特殊情况处理 2.3.1.发送异常处理 2.3.2.消息过大的处理 三、消息存储 3.1.如何保证高性能读写 3.1.1.传统IO读写方式 3.2零拷贝 3.2.1.mmap() 3.2.2sendfile() 3.2.3.CommitLog …

Redis 7.x 系列【27】集群原理之通信机制

有道无术,术尚可求,有术无道,止于术。 本系列Redis 版本 7.2.5 源码地址:https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 概述2 节点和节点2.1 集群拓扑2.2 集群总线协议2.3 流言协议2.4 心跳机制2.5 节点握…

OpenGauss和GaussDB有何不同

OpenGauss和GaussDB是两个不同的数据库产品,它们都具有高性能、高可靠性和高可扩展性等优点,但是它们之间也有一些区别和相似之处。了解它们之间的关系、区别、建议、适用场景和如何学习,对于提高技能和保持行业敏感性非常重要。本文将深入探…

蓝桥强化宝典(4)Dijkstra

前言 Dijkstra算法(迪杰斯特拉算法),又称狄克斯特拉算法,是由荷兰计算机科学家Edsger W. Dijkstra于1959年提出的。该算法主要用于在加权图中查找从一个起始节点到所有其他节点的最短路径,解决的是有权图中最短路径问题…

NLP基础知识2【各种大模型的注意力】

注意力 传统Attention存在的问题优化方向变体有哪些现在的主要变体集中在KVMulti-Query AttentionGrouped-query AttentionFlashAttention 传统Attention存在的问题 上下文约束速度慢,显存占用大(因为注意力考虑整体信息,所以每一个位置都要…

Study--Oracle-07-ASM相关参数(四)

一、ASM主要进程 1、ASM主要后台进程 ASM实例除了传统的DBWn、LGWR、CKPT、SMON和PMON等进程还包含如下几个新后台进程: 2、牛人笔记 邦德图文解读ASM架构,超详细 - 墨天轮 二、数据库实例于ASM实例之间的交互关系 数据库实例与ASM实例之间的交互关系涉及多个步骤和过程,…

PHP家政系统自营+多商户独立端口系统源码小程序

家政行业的新篇章 引言:家政行业的数字化转型 近年来,随着科技的飞速发展和人们生活节奏的加快,家政服务行业也迎来了数字化转型的浪潮。为了提升服务效率、优化用户体验,越来越多的家政公司开始探索“家政系统自营多商户小程序…

用yoloV5做一个口罩检测的全流程实现

制作数据集 收集相关图片: 可以使用爬虫在百度爬取。爬虫代码如下: # -*- coding: UTF-8 -*-""" import requests import tqdmdef configs(search, page, number):""":param search::param page::param number::return:…

界面控件Telerik UI for WPF 2024 Q2亮点 - 全新的AIPrompt组件

Telerik UI for WPF拥有超过100个控件来创建美观、高性能的桌面应用程序,同时还能快速构建企业级办公WPF应用程序。UI for WPF支持MVVM、触摸等,创建的应用程序可靠且结构良好,非常容易维护,其直观的API将无缝地集成Visual Studio…

C++:类和对象2

1.类的默认成员函数 默认成员函数就是用户没有显示实现编译器会自动生成的成员函数称为默认成员函数。一个类,我们在不写的情况下编译器会默认生成6个默认成员函数,分别是构造函数,析构函数,拷贝构造函数,拷贝赋值运算…

kitti数据集转为bag

下载原始的数据集后,通过终端来运行: unzip 2011_10_03_calib.zip和 unzip 2011_10_03_drive_0047_sync.zip这样这个文件夹才算准备好: 然后去下载kitti2bag工具: pip install kitti2bag然后去2011_10_03文件夹下执行&#xf…

大疆创新2025校招内推

大疆2025校招-内推 一、我们是谁? 大疆研发软件团队,致力于把大疆的硬件设备和大疆用户紧密连接在一起,我们的使命是“让机器有温度,让数据会说话”。 在消费和手持团队,我们的温度来自于激发用户灵感并助力用户创作…

嵌入式C++、STM32、MySQL、GPS、InfluxDB和MQTT协议数据可视化:智能物流管理系统设计思路流程(附代码示例)

目录 项目概述 系统设计 硬件设计 软件设计 系统架构图 代码实现 1. STM32微控制器与传感器代码 代码讲解 2. MQTT Broker设置 3. 数据接收与处理 代码讲解 4. 数据存储与分析 5. 数据分析与可视化 代码讲解 6. 数据可视化 项目总结 项目概述 随着电子商务的快…

深度学习目标检测入门实战

深度学习目标检测入门实战 一、什么是目标检测二、目标检测常用的数据集(开源)(一)VOC数据集(1)背景知识(2)数据集的下载(3)VOC2007 数据集的标注&#xff08…

C++从入门到起飞之——初始化列表类型转换static成员 全方位剖析!

🌈个人主页:秋风起,再归来~🔥系列专栏:C从入门到起飞 🔖克心守己,律己则安 目录 1、初始化列表 2、 类型转换 3. static成员 4、完结散花 1、初始化列表 • 之前我们实现构造函数…

[Unity] ShaderGraph实现DeBuff污染 溶解叠加效果

本篇是在之前的基础上,继续做的功能衍生。 [Unity] ShaderGraph实现Sprite消散及受击变色 完整连连看如下所示:

简洁高效的设备稼动率采集系统(一)

前言: 在自动化生产行业,每个公司都需要一款高效的生产设备,那我们怎么体现出设备的高效呢? 可以采集设备的状态,经过成熟的算法,得到设备的稼动率。设备稼动率是衡量生产设备在一定时间内真正处于生产状态…

Django学习(二)

get请求 练习: views.py def test_method(request):if request.method GET:print(request.GET)# 如果链接中没有参数a会报错print(request.GET[a])# 使用这个方法,当查询不到参数时,不会报错而是返回你设置的值print(request.GET.get(c,n…

ElasticSearch(三)—文档字段参数设置以及元字段

一、 字段参数设置 analyzer: 指定分词器。elasticsearch 是一款支持全文检索的分布式存储系统,对于 text类型的字段,首先会使用分词器进行分词,然后将分词后的词根一个一个存储在倒排索引中,后续查询主要是针对词根…