使用PCReg.PyTorch项目训练自己的数据集进行点云配准

项目地址: https://github.com/zhulf0804/PCReg.PyTorch/tree/main
网络简介: 网络是基于PointNet + Concat + FC的,它没有其它复杂的结构,易于复现。因其简洁性,这里暂且把其称作点云配准的Benchmark。因作者源码中复杂的(四元数, 旋转矩阵, 欧拉角之间)的变换操作和冗余性,且其PyTorch版本的不完整性(缺少评估模型等,最近又更新了),
项目详细介绍: 基于深度学习的点云配准Benchmark

本文方法与常见的图像配准逻辑类似,基于采样与transfrom操作从源点云生成目标点云,然后进行训练与评测。总体看来效果不如open3d自带的fgr方法,可以作为入门级项目进行使用。

1、运行环境安装

1.1 项目下载

打开https://github.com/zhulf0804/PCReg.PyTorch/tree/main,点Download ZIP然后将代码解压到指定目录下即可。
在这里插入图片描述

1.2 依赖项安装

在装有pytorch的环境终端,进入PCReg.PyTorch-main目录,执行以下安装命令:

pip install -r requirements.txt
python -m pip install open3d>=0.9

emd loss编译
如果不做训练使用,可以不用进行emd loss编译。
cd loss/cuda/emd_torch & python setup.py install
在编译过程中,很有可能碰到报错

  File "C:\Users\Administrator\miniconda3\lib\site-packages\torch\utils\cpp_extension.py", line 499, in build_extensions_check_cuda_version(compiler_name, compiler_version)File "C:\Users\Administrator\miniconda3\lib\site-packages\torch\utils\cpp_extension.py", line 387, in _check_cuda_versionraise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
RuntimeError:
The detected CUDA version (12.1) mismatches the version that was used to compile
PyTorch (11.7). Please make sure to use the same CUDA versions.

这是由于PyTorch 的cuda版本与系统自带的cuda版本不同所导致的,可以先使用一下命令卸载过往的torch版本(慎重操作),然后重新安装torch;也可以在conda环境中重新创建一个符合系统cuda版本的torch环境。
pip uninstall torch torchvision torchaudio

在上述信息输出中,博主的cuda版本为12.1,我们可以打开pytorch官网找打符合自己电脑cuda版本的pytorch安装命令.
在这里插入图片描述
如果cuda版本较早,可以在 https://pytorch.org/get-started/previous-versions/ 中找到安装命令。

在正确的安装cuda版本后,重新执行命令即可实现emd-loss的安装
在这里插入图片描述

1.3 模型与数据下载

  • modelnet40数据集 [here, 435M]

  • 可用的预训练模型 [Complete, pwd: c4z7, 16.09 M] or [Paritial, pwd: pcno, 16.09] first) 模型下载好后将其放置到PCReg.PyTorch项目根路径下即可。

2. 关键代码说明

2.1 数据加载器

在data目录下有CustomData.py和ModelNet40.py两个文件,其中ModelNet40文件对应modelnet40数据集的加载,CustomData文件对应自己个人数据集的加载。从两个文件的__getitem__函数中可以发现,模型不是基于数据对进行训练的。其依据ref_cloud随机采样生成ref_cloud,然后对ref_cloud进行transform操作。具体实例如下所示:

    def __getitem__(self, item):file = self.files[item]ref_cloud = readpcd(file, rtype='npy')ref_cloud = random_select_points(ref_cloud, m=self.npts)ref_cloud = pc_normalize(ref_cloud)R, t = generate_random_rotation_matrix(-20, 20), \generate_random_tranlation_vector(-0.5, 0.5)src_cloud = transform(ref_cloud, R, t)if self.train:ref_cloud = jitter_point_cloud(ref_cloud)src_cloud = jitter_point_cloud(src_cloud)return ref_cloud, src_cloud, R, t

在以上代码中需要注意的是,所有的点云都进行了坐标值的归一化处理

2.2 模型结构

在model目录下有benchmark.py、fgr.py、icp.py,分别为模型配准,fgr配准,icp配准方法。其中fgr配准与icp配准方法是使用open3d库实现。

benchmark

benchmark为本文模型,其是基于PointNet所实现的一个孪生网络,核心代码如Benchmark类所示。其基于encoder提取2个点云的特征,然后简单的使用全连接层将两个点云的特征进行交互,然后再输出两个点云的特征。在这里最为重要的是loss的设计,即如何设计优化目标,使模型参数以点云配准为优化方向

class Benchmark(nn.Module):def __init__(self, gn, in_dim1, in_dim2=2048, fcs=[1024, 1024, 512, 512, 256, 7]):super(Benchmark, self).__init__()self.in_dim1 = in_dim1self.encoder = PointNet(in_dim=in_dim1, gn=gn)self.decoder = nn.Sequential()for i, out_dim in enumerate(fcs):self.decoder.add_module(f'fc_{i}', nn.Linear(in_dim2, out_dim))if out_dim != 7:if gn:self.decoder.add_module(f'gn_{i}',nn.GroupNorm(8, out_dim))self.decoder.add_module(f'relu_{i}', nn.ReLU(inplace=True))in_dim2 = out_dimdef forward(self, x, y):x_f, y_f = self.encoder(x), self.encoder(y)concat = torch.cat((x_f, y_f), dim=1)out = self.decoder(concat)batch_t, batch_quat = out[:, :3], out[:, 3:] / torch.norm(out[:, 3:], dim=1, keepdim=True)batch_R = batch_quat2mat(batch_quat)if self.in_dim1 == 3:transformed_x = batch_transform(x.permute(0, 2, 1).contiguous(),batch_R, batch_t)elif self.in_dim1 == 6:transformed_pts = batch_transform(x.permute(0, 2, 1)[:, :, :3].contiguous(),batch_R, batch_t)transformed_nls = batch_transform(x.permute(0, 2, 1)[:, :, 3:].contiguous(),batch_R)transformed_x = torch.cat([transformed_pts, transformed_nls], dim=-1)else:raise ValueErrorreturn batch_R, batch_t, transformed_xclass IterativeBenchmark(nn.Module):def __init__(self, in_dim, niters, gn):super(IterativeBenchmark, self).__init__()self.benckmark = Benchmark(gn=gn, in_dim1=in_dim)self.niters = nitersdef forward(self, x, y):transformed_xs = []device = x.deviceB = x.size()[0]transformed_x = torch.clone(x)batch_R_res = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)batch_t_res = torch.zeros(3, 1).to(device).unsqueeze(0).repeat(B, 1, 1)for i in range(self.niters):batch_R, batch_t, transformed_x = self.benckmark(transformed_x, y)transformed_xs.append(transformed_x)batch_R_res = torch.matmul(batch_R, batch_R_res)batch_t_res = torch.matmul(batch_R, batch_t_res) \+ torch.unsqueeze(batch_t, -1)transformed_x = transformed_x.permute(0, 2, 1).contiguous()batch_t_res = torch.squeeze(batch_t_res, dim=-1)#transformed_x = transformed_x.permute(0, 2, 1).contiguous()return batch_R_res, batch_t_res, transformed_xs
fgr配准方法
import copy
import open3d as o3ddef fpfh(pcd, normals):pcd.normals = o3d.utility.Vector3dVector(normals)pcd_fpfh = o3d.registration.compute_fpfh_feature(pcd,o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=64))return pcd_fpfhdef execute_fast_global_registration(source, target, source_fpfh, target_fpfh):distance_threshold = 0.01result = o3d.registration.registration_fast_based_on_feature_matching(source, target, source_fpfh, target_fpfh,o3d.registration.FastGlobalRegistrationOption(maximum_correspondence_distance=distance_threshold))transformation = result.transformationestimate = copy.deepcopy(source)estimate.transform(transformation)R, t = transformation[:3, :3], transformation[:3, 3]return R, t, estimatedef fgr(source, target, src_normals, tgt_normals):source_fpfh = fpfh(source, src_normals)target_fpfh = fpfh(target, tgt_normals)R, t, estimate = execute_fast_global_registration(source=source,target=target,source_fpfh=source_fpfh,target_fpfh=target_fpfh)return R, t, estimate
ICP配准方法
import copy
import numpy as np
import open3d as o3ddef icp(source, target):max_correspondence_distance = 2 # 0.5 in RPM-Netinit = np.eye(4, dtype=np.float32)estimation_method = o3d.pipelines.registration.TransformationEstimationPointToPoint()reg_p2p = o3d.pipelines.registration.registration_icp(source=source,target=target,init=init,max_correspondence_distance=max_correspondence_distance,estimation_method=estimation_method)transformation = reg_p2p.transformationestimate = copy.deepcopy(source)estimate.transform(transformation)R, t = transformation[:3, :3], transformation[:3, 3]return R, t, estimate

3.基本使用

modelnet40数据集的评测及训练可以使用一下代码实现

    # Iterative Benchmarkpython modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --checkpoint your_ckpt_path/test_min_loss.pth --cuda# Visualization# python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --checkpoint your_ckpt_path/test_min_loss.pth  --show# ICP# python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --method icp# FGR# python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --method fgr --normal
  • train

    CUDA_VISIBLE_DEVICES=0 python modelnet40_train.py --root your_data_path/modelnet40_ply_hdf5_2048
    

这里注意讲述训练与评测自己的数据集,其中自己数据集的路径如下所示,里面都是处理好的pcd点云数据。
在这里插入图片描述
具体格式为:

    |- CustomData(dir)|- train_data(dir)- train1.pcd- train2.pcd- ...|- val_data(dir)- val1.pcd- val2.pcd- ...

3.1 ICP方法性能评测

可以加上 --show 参数来查看每一个配准的数据

python custom_evaluate.py --root cumstom_data --infer_npts 2048  --method icp --normal

如果出现以下报错,则用open3d.pipelines.registration替换open3d.registration,具体可以用本博文的ICP配准方法替换掉models\icp.py中的内容

Traceback (most recent call last):File "custom_evaluate.py", line 142, in <module>evaluate_icp(args, test_loader)File "custom_evaluate.py", line 98, in evaluate_icpR, t, pred_ref_cloud = icp(npy2pcd(src_cloud), npy2pcd(ref_cloud))File "D:\点云AI配准\PCReg.PyTorch-main\models\icp.py", line 9, in icpestimation_method = o3d.registration.TransformationEstimationPointToPoint()
AttributeError: module 'open3d' has no attribute 'registration'

具体执行输出如下所示:
在这里插入图片描述

3.2 模型训练

训练命令:
python custom_train.py --root cumstom_data --train_npts 2048
在这里插入图片描述
训练好的模型保存在work_dirs\models\checkpoints目录中
在这里插入图片描述

评测命令:
python custom_evaluate.py --infer_npts 2048 --root cumstom_data --checkpoint work_dirs\models\checkpoints\test_min_loss.pth --show
其中绿色点云为源点云,红色点云为参考点云,蓝色点云为配准后的源点云,可以看到蓝色点云与红色点云完全没有对齐,这表明训练效果极其不佳。这或许是训练数据太少所导致的,毕竟本次实验只有18个点云数据。
在这里插入图片描述

4、原文效果

下图是作者论文中的配准效果图
在这里插入图片描述
在modelnet40数据集上相关精度信息如下所示,可以确定,本文方法与FGR方法相比没有显著性优势。

  • Point-to-Point Correspondences(R error is large due to EMDLoss, see here)
Methodisotropic Risotropic tanisotropic R(mse, mae)anisotropic t(mse, mae)time(s)
ICP11.440.1617.64(5.48)0.22(0.07)0.07
FGR0.010.000.07(0.00)0.00(0.00)0.19
IBenchmark5.680.079.77(2.69)0.12(0.03)0.022
IBenchmark + ICP3.650.049.22(1.66)0.11(0.02)
  • Noise Data(infer_npts = 1024)
Methodisotropic Risotropic tanisotropic R(mse, mae)anisotropic t(mse, mae)
ICP12.140.1718.32(5.86)0.23(0.08)
FGR4.270.0611.55(2.43)0.09(0.03)
IBenchmark6.250.089.28(2.94)0.12(0.04)
IBenchmark + ICP5.100.0710.51(2.39)0.13(0.03)
  • Partial-to-Complete Registration(infer_npts = 1024)
Methodisotropic Risotropic tanisotropic R(mse, mae)anisotropic t(mse, mae)
ICP21.330.3222.83(10.51)0.31(0.15)
FGR9.490.1219.51(5.58)0.17(0.06)
IBenchmark15.020.2215.78(7.45)0.21(0.10)
IBenchmark + ICP9.210.1314.73(4.43)0.18(0.06)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/192075.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

剑指 Offer(第2版)面试题 14:剪绳子

剑指 Offer&#xff08;第2版&#xff09;面试题 14&#xff1a;剪绳子 剑指 Offer&#xff08;第2版&#xff09;面试题 14&#xff1a;剪绳子解法1&#xff1a;动态规划解法2&#xff1a;数学 剑指 Offer&#xff08;第2版&#xff09;面试题 14&#xff1a;剪绳子 题目来源…

DOM 事件的注册和移除

前端面试大全DOM 事件的注册和移除 &#x1f31f;经典真题 &#x1f31f;DOM 注册事件 HTML 元素中注册事件 DOM0 级方式注册事件 DOM2 级方式注册事件 &#x1f31f;DOM 移除事件 &#x1f31f;真题解答 &#x1f31f;总结 &#x1f31f;经典真题 总结一下 DOM 中如何…

TCP连接为什么是三次握手,而不是两次和四次

答案 阻止重复的历史连接同步初始序列号避免资源浪费 原因 阻止重复的历史连接&#xff08;首要原因&#xff09; 考虑这样一种情况&#xff1a; 客户端现在要给服务端建立连接&#xff0c;向服务端发送了一个SYN报文段&#xff08;第一次握手&#xff09;&#xff0c;以表示请…

固定Microsoft Edge浏览器的位置设置,避免自动回调至中国

问题描述 在使用Copilot等功能时&#xff0c;需要将Microsoft Edge浏览器的位置设置为国外。但每次重新打开浏览器后&#xff0c;位置设置又自动回调至中国&#xff0c;导致每次均需要手动调整。 原因分析 这个问题的出现是因为每次启动Microsoft Edge时&#xff0c;默认打开…

cmake和vscode 下的cmake的使用详解(三)

第七讲&#xff1a;【实战】使用 VSCode 进行完整项目开发 案例&#xff1a;士兵突击 需求&#xff1a; 1. 士兵 许三多 有一把 枪 &#xff0c;叫做 AK47 2. 士兵 可以 开火 3. 士兵 可以 给枪装填子弹 4. 枪 能够 发射 子弹 5. 枪 能够 装填子弹 ——…

2022年9月6日 Go生态洞察:Go的漏洞管理新支持

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

docker-速通

1.命令-镜像操作 docker pull nginx #下载最新版 docker pull nginx:1.20.1 #下载指定版本 镜像名:版本名&#xff08;标签&#xff09; docker images #查看所有镜像 # 如果只写镜像名实际就是redis redis:latest 记住这个不是命令 docker rmi 镜像名:版本号/镜像id…

利用段落检索和生成模型进行开放域问答12.2

利用段落检索和生成模型进行开放域问答 摘要引言2 相关工作3 方法 摘要 事实证明&#xff0c;开放域问答的生成模型具有竞争力&#xff0c;无需借助外部知识。虽然很有希望&#xff0c;但这种方法需要使用具有数十亿个参数的模型&#xff0c;而这些模型的训练和查询成本很高。…

在linux服上部署vue+springboot+nginx项目

一、环境准备 1、安装winscp便于可视化操作linux&#xff1a;winscp安装及关联putty使用_putty.exe没有找到_cherishSpring的博客-CSDN博客 2、安装jdk&#xff1a;linux系统安装jdk-CSDN博客 3、安装mysql&#xff1a;Linux7安装mysql数据库以及navicat远程连接mysql-CSDN博…

Fiddler抓包工具之fiddler设置断点和简单的并发测试

断点有两种方式&#xff1a; 1、全局断点 2、局部断点 全局断点 全局断点的特点是&#xff1a;不能针对一个请求&#xff0c;是给所有抓到的请求打断点 全局断点如何设置&#xff1a; 1、快速设置断点&#xff1a;直接点击底部状态栏断点处 &#xff1b;点击第一下是请求…

【算法专题】二分查找

二分查找 二分查找1. 二分查找2. 在排序数组中查找元素的第一和最后一个位置3. 搜索插入位置4. x 的平方根5. 山脉数组的峰顶索引6. 寻找峰值7. 寻找旋转排序数组中的最小值8. 点名 二分查找 1. 二分查找 题目链接 -> Leetcode -704.二分查找 Leetcode -704.二分查找 题…

【Geoserver】SLD点位样式(PointSymbolizer)设计全通

SLD文件可以控制geoserver的样式管理&#xff0c;这里专门针对点位进行设计&#xff0c;首先点位的设计需要用到这面这个大标签 之前的项目中已经用到了很多关于面的样式管理&#xff0c;这里新学习的是关于点的样式管理 PointSymbolizer 参考资料地址&#xff1a;https://doc…

LeetCode算法题解(动态规划)|LeetCode1143. 最长公共子序列、LeetCode1035. 不相交的线、LeetCode53. 最大子数组和

一、LeetCode1143. 最长公共子序列 题目链接&#xff1a;1143. 最长公共子序列 题目描述&#xff1a; 给定两个字符串 text1 和 text2&#xff0c;返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 &#xff0c;返回 0 。 一个字符串的 子序列 是指这样一…

论文编写软件latex安装教程

目录 1.下载安装包2.安装texlive 本人系统为windows&#xff0c;本教程基于windows系统&#xff0c;如果是其它系统请参考对应教程&#xff0c;注意选择对应系统的安装包&#xff01; 1.下载安装包 有三种集成环境安装包 texlive 是主流的环境&#xff0c;集成了较多的包&…

【数据结构】二叉树---C语言版

二叉树 一、树的概念及结构1.树的概念2.树的相关概念3.树的表示4.树在实际中的应用 二、二叉树的概念及结构1.二叉树的概念2.满二叉树3.完全二叉树4.二叉树的性质5.二叉树的储存结构 三、二叉树的遍历1.前序遍历2.中序遍历3.后序遍历4.层序遍历 四、手撕二叉树&#xff08;务必…

MySQL 临时数据空间不足导致SQL被killed 的问题与扩展

开头还是介绍一下群&#xff0c;如果感兴趣PolarDB ,MongoDB ,MySQL ,PostgreSQL ,Redis, Oceanbase, Sql Server等有问题&#xff0c;有需求都可以加群群内&#xff0c;可以解决你的问题。加群请联系 liuaustin3 &#xff0c;&#xff08;共1730人左右 1 2 3 4 5&#xff0…

Mover Creator--功能简介

Mover Creator是一款AFSIM软件工具&#xff0c;提供方便易用的基于GUI的应用程序&#xff0c;帮助用户创建用于空中运动器的AFSIM输入文件&#xff0c;包括WSF_P6DOF_MOVER和WSF_GUIDED_MOVER。使用自定义定义的基于图形的模型定义&#xff0c;用户可以对飞机、武器和发动机进行…

邮政快递查询,邮政快递单号查询,用表格导出查询好的物流信息

批量查询邮政快递单号的物流信息&#xff0c;并以表格的形式导出查询好的物流信息。 所需工具&#xff1a; 一个【快递批量查询高手】软件 邮政快递单号若干 操作步骤&#xff1a; 步骤1&#xff1a;运行【快递批量查询高手】软件&#xff0c;第一次使用的伙伴记得先注册&am…

linux后端基础---笔记整理(tmux、vim、shell、ssh/scp、git、thrift、docker)

目录 1.Linux常用文件管理命令 2.tmux终端复用器/vim命令式文本编辑器 3.Shell语法 3.1 Shell—版本3.2 新建一个test.sh文件3.3 Shell文件—运行方式3.4 Shell—注释3.5 Shell—变量3.6 Shell—默认变量&#xff0c;文件参数, “$”的用法3.7 Shell—数组3.8 shell—expr命令…

AD7124-4 实测热电偶数据读取,电压精度到稳定到±1uV, 电压波动260nV, 温度精度到±0.01℃

AD7124-4 实测热电偶数据读取&#xff0c;电压精度到稳定到1uV, 电压波动260nV, 温度精度到0.01℃ AD7124_STM32_ADI官网例程使用stm32 和ad7124做温控调试&#xff0c;发现效果还是不错的&#xff0c;至少比ads1256的效果好多啦&#xff01;Chapter1 AD7124-4 实测热电偶数据读…