基于ROPNet项目训练modelnet40数据集进行3d点云的配置

项目地址: https://github.com/zhulf0804/ROPNet 在 MVP Registration Challenge (ICCV Workshop 2021)(ICCV Workshop 2021)中获得了第二名。项目可以在win10环境下运行。
论文地址: https://arxiv.org/abs/2107.02583

网络简介: 一种新的深度学习模型,该模型利用具有区别特征的代表性重叠点进行配准,将部分到部分配准转换为部分完全配准。基于pointnet输出的特征设计了一个上下文引导模块,使用一个编码器来提取全局特征来预测点重叠得分。为了更好地找到有代表性的重叠点,使用提取的全局特征进行粗对齐。然后,引入一种变压器来丰富点特征,并基于点重叠得分和特征匹配去除非代表性点。在部分到完全的模式下建立相似度矩阵,最后采用加权支持向量差来估计变换矩阵。
在这里插入图片描述
实施效果: 从数据上看ROPNet与RPMNet与保持了断崖式的领先地位
在这里插入图片描述

1、运行环境安装

1.1 项目下载

打开https://github.com/zhulf0804/ROPNet,点Download ZIP然后将代码解压到指定目录下即可。
在这里插入图片描述

1.2 依赖项安装

在装有pytorch的环境终端,进入ROPNet-master/src目录,执行以下安装命令。如果已经安装了torch 环境和open3d包,则不用再进行安装了

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118pip install open3d

1.3 模型与数据下载

modelnet40数据集 here [435M]
数据集下载后存储为以下路径即可。
在这里插入图片描述

官网预训练模型,无。
第三方预训练模型:使用ROPNet项目在modelnet40数据集上训练的模型

2、关键代码

2.1 dataloader

作者所提供的dataloader只能加载https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip 数据集,其所返回的tgt_cloud, src_cloud实质上是基于一个点云采样而来的。 其中的self.label2cat, self.cat2label, self.symmetric_labels等对象代码实际上是没有任何作用的。

import copy
import h5py
import math
import numpy as np
import os
import torchfrom torch.utils.data import Dataset
import sysBASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOR_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOR_DIR)
from utils import  random_select_points, shift_point_cloud, jitter_point_cloud, \generate_random_rotation_matrix, generate_random_tranlation_vector, \transform, random_crop, shuffle_pc, random_scale_point_cloud, flip_pchalf1 = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl','car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser','flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp']
half1_symmetric = ['bottle', 'bowl', 'cone', 'cup', 'flower_pot', 'lamp']half2 = ['laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano','plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool','table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
half2_symmetric = ['tent', 'vase']class ModelNet40(Dataset):def __init__(self, root, split, npts, p_keep, noise, unseen, ao=False,normal=False):super(ModelNet40, self).__init__()self.single = False # for specific-class visualizationassert split in ['train', 'val', 'test']self.split = splitself.npts = nptsself.p_keep = p_keepself.noise = noiseself.unseen = unseenself.ao = ao # Asymmetric Objectsself.normal = normalself.half = half1 if split in 'train' else half2self.symmetric = half1_symmetric + half2_symmetricself.label2cat, self.cat2label = self.label2category(os.path.join(root, 'shape_names.txt'))self.half_labels = [self.cat2label[cat] for cat in self.half]self.symmetric_labels = [self.cat2label[cat] for cat in self.symmetric]files = [os.path.join(root, 'ply_data_train{}.h5'.format(i))for i in range(5)]if split == 'test':files = [os.path.join(root, 'ply_data_test{}.h5'.format(i))for i in range(2)]self.data, self.labels = self.decode_h5(files)print(f'split: {self.split}, unique_ids: {len(np.unique(self.labels))}')if self.split == 'train':self.Rs = [generate_random_rotation_matrix() for _ in range(len(self.data))]self.ts = [generate_random_tranlation_vector() for _ in range(len(self.data))]def label2category(self, file):with open(file, 'r') as f:label2cat = [category.strip() for category in f.readlines()]cat2label = {label2cat[i]: i for i in range(len(label2cat))}return label2cat, cat2labeldef decode_h5(self, files):points, normal, label = [], [], []for file in files:f = h5py.File(file, 'r')cur_points = f['data'][:].astype(np.float32)cur_normal = f['normal'][:].astype(np.float32)cur_label = f['label'][:].flatten().astype(np.int32)if self.unseen:idx = np.isin(cur_label, self.half_labels)cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]if self.ao and self.split in ['val', 'test']:idx = ~np.isin(cur_label, self.symmetric_labels)cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]if self.single:idx = np.isin(cur_label, [8])cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]points.append(cur_points)normal.append(cur_normal)label.append(cur_label)points = np.concatenate(points, axis=0)normal = np.concatenate(normal, axis=0)data = np.concatenate([points, normal], axis=-1).astype(np.float32)label = np.concatenate(label, axis=0)return data, labeldef compose(self, item, p_keep):tgt_cloud = self.data[item, ...]if self.split != 'train':np.random.seed(item)R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()else:tgt_cloud = flip_pc(tgt_cloud)R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()src_cloud = random_crop(copy.deepcopy(tgt_cloud), p_keep=p_keep[0])src_size = math.ceil(self.npts * p_keep[0])tgt_size = self.nptsif len(p_keep) > 1:tgt_cloud = random_crop(copy.deepcopy(tgt_cloud),p_keep=p_keep[1])tgt_size = math.ceil(self.npts * p_keep[1])src_cloud_points = transform(src_cloud[:, :3], R, t)src_cloud_normal = transform(src_cloud[:, 3:], R)src_cloud = np.concatenate([src_cloud_points, src_cloud_normal],axis=-1)src_cloud = random_select_points(src_cloud, m=src_size)tgt_cloud = random_select_points(tgt_cloud, m=tgt_size)if self.split == 'train' or self.noise:src_cloud[:, :3] = jitter_point_cloud(src_cloud[:, :3])tgt_cloud[:, :3] = jitter_point_cloud(tgt_cloud[:, :3])tgt_cloud, src_cloud = shuffle_pc(tgt_cloud), shuffle_pc(src_cloud)return src_cloud, tgt_cloud, R, tdef __getitem__(self, item):src_cloud, tgt_cloud, R, t = self.compose(item=item,p_keep=self.p_keep)if not self.normal:tgt_cloud, src_cloud = tgt_cloud[:, :3], src_cloud[:, :3]return tgt_cloud, src_cloud, R, tdef __len__(self):return len(self.data)

2.2 模型设计

模型设计如下:
在这里插入图片描述

2.3 loss设计

其主要包含Init_loss、Refine_loss和Ol_loss。
其中Init_loss是用于计算 预测点 云 0 预测点云_0 预测点0与目标点云的mse或mae loss,
Refine_loss用于计算 预测点 云 [ 1 : ] 预测点云_{[1:]} 预测点[1:]与目标点云的加权mae loss
Ol_loss用于计算两个输入点云输出的重叠分数,使两个点云对应点的重叠分数是一样的。
在这里插入图片描述

具体实现代码如上:


import math
import torch
import torch.nn as nn
from utils import square_distsdef Init_loss(gt_transformed_src, pred_transformed_src, loss_type='mae'):losses = {}num_iter = 1if loss_type == 'mse':criterion = nn.MSELoss(reduction='mean')for i in range(num_iter):losses['mse_{}'.format(i)] = criterion(pred_transformed_src[i],gt_transformed_src)elif loss_type == 'mae':criterion = nn.L1Loss(reduction='mean')for i in range(num_iter):losses['mae_{}'.format(i)] = criterion(pred_transformed_src[i],gt_transformed_src)else:raise NotImplementedErrortotal_losses = []for k in losses:total_losses.append(losses[k])losses = torch.sum(torch.stack(total_losses), dim=0)return lossesdef Refine_loss(gt_transformed_src, pred_transformed_src, weights=None, loss_type='mae'):losses = {}num_iter = len(pred_transformed_src)for i in range(num_iter):if weights is None:losses['mae_{}'.format(i)] = torch.mean(torch.abs(pred_transformed_src[i] - gt_transformed_src))else:losses['mae_{}'.format(i)] = torch.mean(torch.sum(weights * torch.mean(torch.abs(pred_transformed_src[i] -gt_transformed_src), dim=-1)/ (torch.sum(weights, dim=-1, keepdim=True) + 1e-8), dim=-1))total_losses = []for k in losses:total_losses.append(losses[k])losses = torch.sum(torch.stack(total_losses), dim=0)return lossesdef Ol_loss(x_ol, y_ol, dists):CELoss = nn.CrossEntropyLoss()x_ol_gt = (torch.min(dists, dim=-1)[0] < 0.05 * 0.05).long() # (B, N)y_ol_gt = (torch.min(dists, dim=1)[0] < 0.05 * 0.05).long() # (B, M)x_ol_loss = CELoss(x_ol, x_ol_gt)y_ol_loss = CELoss(y_ol, y_ol_gt)ol_loss = (x_ol_loss + y_ol_loss) / 2return ol_lossdef cal_loss(gt_transformed_src, pred_transformed_src, dists, x_ol, y_ol):losses = {}losses['init'] = Init_loss(gt_transformed_src,pred_transformed_src[0:1])if x_ol is not None:losses['ol'] = Ol_loss(x_ol, y_ol, dists)losses['refine'] = Refine_loss(gt_transformed_src,pred_transformed_src[1:],weights=None)alpha, beta, gamma = 1, 0.1, 1if x_ol is not None:losses['total'] = losses['init'] + beta * losses['ol'] + gamma * losses['refine']else:losses['total'] = losses['init'] + losses['refine']return losses

3、训练与预测

先进入src目录,并将modelnet40_ply_hdf5_2048.zip解压在src目录下
在这里插入图片描述

3.1 训练

训练命令及训练输出如下所示

python train.py --root modelnet40_ply_hdf5_2048/ --noise --unseen

python请添加图片描述
在训练过程中会在work_dirs\models\checkpoints目录下生成两个模型文件
在这里插入图片描述

3.2 验证

训练命令及训练输出如下所示

python eval.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --cuda --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

请添加图片描述

3.3 测试

测试训练数据的命令如下

python vis.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

具体配准效果如下所示,其中绿色点云为输入点云,红色点云为参考点云,蓝色点云为配准后的点云。可以看到蓝色点云基本与红色点云重合,可以确定其配准效果十分完好。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.4 处理自己的数据集

基于该项目训练并处理自己数据的教程后续会给出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/200097.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PID学习笔记 6 】控制系统的性能指标之二

写在前面 上文介绍了控制系统的稳态与动态、过渡过程、阶跃响应以及阶跃信号作用下过渡过程的四种形式。本文紧接上文&#xff0c;首先总结过渡过程的分类&#xff0c;然后介绍控制系统的性能评价&#xff0c;最后重点介绍控制系统性能指标中的单项指标。 一、过渡过程的分类…

java版微信小程序商城免费搭建 java版直播商城平台规划及常见的营销模式有哪些?电商源码/小程序/三级分销

涉及平台 平台管理、商家端&#xff08;PC端、手机端&#xff09;、买家平台&#xff08;H5/公众号、小程序、APP端&#xff08;IOS/Android&#xff09;、微服务平台&#xff08;业务服务&#xff09; 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis …

解密IIS服务器API跨域问题的终极解决方案

在当今数字化时代&#xff0c;API已成为现代应用程序的核心组件。然而&#xff0c;当你使用IIS&#xff08;Internet Information Services&#xff09;服务器提供API时&#xff0c;你可能会遇到一个常见的挑战&#xff1a;API跨域问题。这个问题经常困扰着开发人员&#xff0c…

【Python3】【力扣题】383. 赎金信

【力扣题】题目描述&#xff1a; 题解&#xff1a; 两个字符串ransomNote和magazine&#xff0c;ransomNote中每个字母都在magazine中一一对应&#xff08;顺序可以不同&#xff09;。 即分别统计两个字符串中每个字母出现的次数&#xff0c;ransomNote中每个字母的个数小于等…

聚观早报 |国行PS5轻薄版开售;岚图汽车11月交付7006辆

【聚观365】12月2日消息 国行PS5轻薄版开售 岚图汽车11月交付7006辆 比亚迪推出12月限时优惠 特斯拉正式交付首批Cybertruck 昆仑万维发布「天工 SkyAgents」平台 国行PS5轻薄版开售 索尼最新的PlayStation5主机&#xff08;CFI-2000型号组-轻薄版&#xff09;国行版本正…

springboot详解Mybatis-Plus中分页插件PaginationInterceptor标红

1.问题描述 在springboot项目中&#xff0c;类中引用PaginationInterceptor&#xff0c;标红&#xff0c;如下图所示&#xff1a; 2.问题分析 可能是因为pom.xml中的配置原因&#xff0c;导致不支持PaginationInterceptor 3.解决问题 更换版本后 更换后&#xff0c;记得Rel…

挂耳式蓝牙耳机哪个好、性价比高的挂耳蓝牙耳机

近年来&#xff0c;开放式耳机呈现出火热的势头&#xff0c;相较传统的入耳式耳机&#xff0c;长久佩戴也不会有异物般的不适感&#xff0c;通常采用的耳挂式佩戴设计&#xff0c;不需要把耳机放进耳道里也能听见声音&#xff0c;全新的佩戴方式也更为舒适&#xff0c;能维护耳…

01.项目简介

开源数字货币交易所&#xff0c;基于Java开发的货币交易所 | BTC交易所 | ETH交易所 | 数字货币交易所 | 交易平台 | 撮合交易引擎。本项目基于SpringCloudAlibaba微服务开发&#xff0c;可用来搭建和二次开发数字货币交易所。 项目特色&#xff1a; 基于内存撮合引擎&#xf…

Navicat在分辨率不同的屏幕窗口显示大小不一致问题解决

1.主屏幕为2560*1600分辨率&#xff0c;能够显示较多数据连接 2.在第二屏幕分辨率低&#xff0c;字体变大&#xff0c;显示内容变少 解决办法&#xff1a; 1.右击navicat图标-属性 2.选择【兼容性】-在兼容性页面中选择**“更改高DPI设置”** 3…勾选“高DPI缩放替代”&a…

minio配置监听(对象操作日志)

minio配置监听对象操作 本文档适用于minio2021.3.17版本 有时我们需要查看minio中对象操作的日志&#xff0c;比如像监听minio某一个桶中的删除事件&#xff0c;就需要配置监听。minio支持将监听的结果输出到es、pg、amq等等&#xff0c;下面介绍一下将minio对象操作监听结果输…

【人体解剖学与组织胚胎学】练习一高度相联知识点整理及对应习题

文章目录 [toc]骨性鼻旁窦填空题问答题 关节填空题简答题 胸廓填空题简答题![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/827e7d1db3af42858d8734bb81911fea.jpeg)补充 骨性鼻旁窦 填空题 问答题 关节 填空题 简答题 胸廓 填空题 简答题 补充 第二肋对应胸骨…

Leetcode.2477 到达首都的最少油耗

题目链接 Leetcode.2477 到达首都的最少油耗 rating : 2012 题目描述 给你一棵 n n n 个节点的树&#xff08;一个无向、连通、无环图&#xff09;&#xff0c;每个节点表示一个城市&#xff0c;编号从 0 0 0 到 n − 1 n - 1 n−1 &#xff0c;且恰好有 n − 1 n - 1 n−…

什么是呼叫中心的语音通道?呼叫中心语音线路有几种?

什么是呼叫中心的语音通道&#xff1f; 呼叫中心的语音通道是指在呼叫中心中使用的语音信号传输通道&#xff0c;它是呼叫中心中至关重要的一部分&#xff0c;负责将客户的语音信息传递给客服代表&#xff0c;以及将客服代表的语音信息传递给客户。在呼叫中心的运营中&#xf…

JAVA-JVM 之Class字节码文件的组成 【下篇】

字节码 类元数据接口元数据字段元数据方法元数据属性元数据 主页传送门&#xff1a;&#x1f4c0; 传送 类元数据 此部分元数据主要包含类索引&#xff08;This_Class&#xff09;和父类索引&#xff08;Super_Class&#xff09;。 类索引&#xff1a;指向Class字节码常量池表…

Python----Pandas

目录 Series属性 DataFrame的属性 Pandas的CSV文件 Pandas数据处理 Pandas的主要数据结构是Series&#xff08;一维数据&#xff09;与DataFrame&#xff08;二维数据&#xff09; Series属性 Series的属性如下&#xff1a; 属性描述pandas.Series(data,index,dtype,nam…

mybatis 的快速入门以及基于spring boot整合mybatis

MyBatis基础 MyBatis是一款非常优秀的持久层框架&#xff0c;用于简化JDBC的开发 准备工作&#xff1a; 1&#xff0c;创建sprong boot工程&#xff0c;引入mybatis相关依赖2&#xff0c;准备数据库表User&#xff0c;实体类User3&#xff0c; 配置MyBatis&#xff08;在applic…

2005-2021年地级市绿色发展注意力数据(根据政府报告文本词频统计)

2005-2021年地级市绿色发展注意力数据&#xff08;根据政府报告文本词频统计&#xff09; 1、时间&#xff1a;2005-2021年 2、指标&#xff1a;省、市、年份、一级指标、关键词、关键词词频、总词频 3、范围&#xff1a;270个地级市 4、来源&#xff1a;地级市政府工作报告…

【C++】动态内存管理——new和delete

这篇文章我们讲一下C的动态内存管理&#xff0c;从一个比较陌生的知识说起&#xff0c;我们知道&#xff0c;一个工程可以创建很多.c文件&#xff0c;我们如果定义一个全局变量&#xff0c;只要用extern声明一下&#xff0c;在每个文件都可以用。而用static修饰的全局变量只能在…

【ecology】通过F12抓取页面SQL

1、点击流程监控&#xff0c;打开浏览器的”开发者工具“&#xff08;F12&#xff09;&#xff1b; 2、点击搜索&#xff0c;在开发者工具中找到sessionkey&#xff0c;复制后面的值。 3、http://58.213.83.186:8081/api/ec/dev/table/getxml?dataKey 上面的网址的IP地址修改…

Gee教程6.模板(HTML Template)

这一章节的内容是介绍 Web 框架如何支持服务端渲染的场景 实现静态资源服务(Static Resource)。支持HTML模板渲染。 这一章节很多内容是基于net/http库的&#xff0c;该库已经实现了很多静态文件和HMML模板的相关功能的了。 静态文件 网页的三剑客&#xff0c;JavaScript、C…