point transformer v3复现及核心代码详解

point transformer v3复现及核心代码详解

  • 1. 复现
    • 1.1 复现
    • 1.2 数据预处理
    • 1.3 跑通
  • 2. 核心代码详解
    • 2.1 读取数据
    • 2.2 dataloder
    • 2.3 模型读取数据的逻辑
    • 2.4 forward
      • 2.4.1 Point
      • 2.4.2 backbone
        • 2.4.2.1 point.serialization
        • 2.4.2.2 稀疏化
        • 2.4.2.3 embedding
        • 2.4.2.4 encoder

1. 复现

1.1 复现

根据源码的Github地址,下载源代码

git clone https://gitcode.com/gh_mirrors/po/PointTransformerV3.git

配置环境:

conda create -n pointcept python=3.8 -y
conda activate pointcept
conda install ninja -y

安装PyTorch,这里我是安装的torch1.11.0

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

安装环境依赖:

pip install h5py pyyaml haredarray tensorboard tensorboardx yapf addict einops scipy plyfile termcolor timm -i https://pypi.tuna.tsinghua.edu.cn/simple

其中,haredarray这个包在Windows下可能会报错,改为:

pip install git+https://github.com/imaginary-friend94/SharedNumpyArray

安装成功后,修改Pointcept/pointcept/utils/cache.py,将SharedArray改为:

# import SharedArray
import numpysharedarray

将代码中所有的sharedarray.attach改为:

if os.path.exists(f"/dev/shm/{name}"):# return SharedArray.attach(f"shm://{name}")return numpysharedarray.attach_mem_sh(f"shm://{name}")

接着,继续安装其它的依赖:

conda install pytorch-cluster pytorch-scatter pytorch-sparse -c pyg -y
pip install torch-geometric

这里Windows一就会报错,选择本地whl安装,链接,根据自己的torch以及cuda版本选择对应的包:
在这里插入图片描述

pip install torch_cluster-1.6.0-cp38-cp38-win_amd64.whl
pip install torch_cluster-1.6.2+pt21cu118-cp38-cp38-win_amd64.whl
pip install torch_scatter-2.0.9-cp38-cp38-win_amd64.whl
# 上述whl安装好,正常安装torch-geometric
pip install torch-geometric

然后,安装pointops,也是window中很容易报错的:

cd Pointcept/libs/pointops
python setup.py install

如果装不上去,报错:AttributeError: ‘NoneType’ object has no attribute ‘split’:

Traceback (most recent call last):
File “setup.py”, line 8, in
flag for flag in opt.split() if flag != “-Wstrict-prototypes”
AttributeError: ‘NoneType’ object has no attribute ‘split’

Pointcept/libs/pointops/setup.py中将这块代码进行注释
在这里插入图片描述

(opt,) = get_config_vars("OPT")
# os.environ["OPT"] = " ".join(
#     flag for flag in opt.split() if flag != "-Wstrict-prototypes"
# )

最后,根据cuda的版本安装稀疏卷积spconv

pip install spconv-cu113

1.2 数据预处理

这里以S3DIS场景点云数据集为例,进行数据处理,可以在链接中下载该数据集。
我这里下载的是Stanford3dDataset_v1.2。
在这里插入图片描述

# S3DIS without aligned angle
python pointcept/datasets/preprocessing/s3dis/preprocess_s3dis.py --dataset_root ${S3DIS_DIR} --output_root ${PROCESSED_S3DIS_DIR}
# S3DIS with aligned angle
python pointcept/datasets/preprocessing/s3dis/preprocess_s3dis.py --dataset_root ${S3DIS_DIR} --output_root ${PROCESSED_S3DIS_DIR} --align_angle
# S3DIS with normal vector (recommended, normal is helpful)
python pointcept/datasets/preprocessing/s3dis/preprocess_s3dis.py --dataset_root ${S3DIS_DIR} --output_root ${PROCESSED_S3DIS_DIR} --raw_root ${RAW_S3DIS_DIR} --parse_normal
python pointcept/datasets/preprocessing/s3dis/preprocess_s3dis.py --dataset_root ${S3DIS_DIR} --output_root ${PROCESSED_S3DIS_DIR} --raw_root ${RAW_S3DIS_DIR} --align_angle --parse_normal

–dataset_root指定下载好的数据集路径,–output_root指定数据预处理后存放的路径。

1.3 跑通

这里,我选择Pointcept/configs/s3dis/semseg-pt-v3m1-1-rpe.py作为模型的配置文件。
训练脚本文件位于Pointcept/tools/train.py
将配置文件中的数据集路径更改为上述预处理好的路径:
在这里插入图片描述
开始训练:

cd Pointcept/tools
python train.py --config-file D:\PointTransformerV3\Pointcept\configs\s3dis\semseg-pt-v3m1-1-rpe.py

开始训练的时候,可能会报错:AssertionError: channel size mismatch
在这里插入图片描述
将配置文件中的,模型的backbone输入通道数改为3,搞定!
在这里插入图片描述

到此,成功跑通!!!

2. 核心代码详解

整个框架全部以Pointcept/pointcept/engines/train.pyTrainer类为主,类初始化方法中有核心部分:
构建实例化模型:

self.model = self.build_model()

构建日志保存:

self.writer = self.build_writer()

构建dataloder:

self.train_loader = self.build_train_loader()

构建优化器:

self.optimizer = self.build_optimizer()
self.scheduler = self.build_scheduler()

训练脚本train,也作为该类的成员方法。
在这里插入图片描述

2.1 读取数据

build_train_loader()方法的代码如下:
在这里插入图片描述
train_data = build_dataset(self.cfg.data.train)就是构建dataset。
构造dataset类的起始文件是Pointcept/pointcept/datasets/builder.py,这里cfg就是对应的上述配文件。
在这里插入图片描述
随后,dataset类在Pointcept/pointcept/datasets/defaults.py中,
在这里插入图片描述
这块类的初始化方法中,主要是self.get_data_list()方法。会将数据集路径的地址全部读取,并存放在list中。
在这里插入图片描述

2.2 dataloder

dataloder,就是正常的步骤,将构建好的dadaset传到Pytorch官方提供的dataloder:

train_loader = torch.utils.data.DataLoader(train_data,batch_size=self.cfg.batch_size_per_gpu,shuffle=(train_sampler is None),num_workers=self.cfg.num_worker_per_gpu,sampler=train_sampler,collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),pin_memory=True,worker_init_fn=init_fn,drop_last=True,persistent_workers=True,)

2.3 模型读取数据的逻辑

实例化的datasrt类继承了Pointcept/pointcept/datasets/defaults.py中的DefaultDataset类,所有方法均在该父类中。
下面,首先看__getitem__方法:

def __getitem__(self, idx):if self.test_mode:return self.prepare_test_data(idx)else:return self.prepare_train_data(idx)

接下里是,prepare_train_data方法:
在这里插入图片描述
该方法嵌套了两个类方法。
1.get_data方法,就是将数据集中的.npy文件用numpy读取出来,并转float32类型:
在这里插入图片描述

在这里插入图片描述
get_data方法最终将数据集中的4个npy文件全部读出来:
在这里插入图片描述
2.transform方法,就是将数据集中的读取的.npy文件进行预处理:
在这里插入图片描述
最终,预处理好的如下,offset就是数据的第一个维度。
在这里插入图片描述

2.4 forward

前向传播的入口在Pointcept/pointcept/engines/train.py中的run_step函数,主要就是将数据加载到推理设备上,并进行前向传播。
在这里插入图片描述
具体的前向传播,在Pointcept/pointcept/models/default.py中的forward函数中。
在这里插入图片描述

2.4.1 Point

首先,先是Point(input_dict),Point这个类。
在这里插入图片描述
将offset转换为batch。
在这里插入图片描述
转换的具体代码如下:
在这里插入图片描述
在这里插入图片描述

2.4.2 backbone

接下来就是point transformer v3主干特征提取了。文件路径为Pointcept/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py
首先,Point(data_dict),和上述一样,上述已经生成,这里就直接return出来了。
在这里插入图片描述

2.4.2.1 point.serialization

下面,就简要介绍一下配置参数中order=[“z”, “z-trans”, “hilbert”, “hilbert-trans”]的4种空间填充曲线
在这里插入图片描述
PTv3使用空间填充曲线,如Z-order和Hilbert曲线,来遍历三维空间中的点。这些曲线能够在保持点之间空间邻近性的同时,将点映射到一个高维离散空间中。数学上,空间填充曲线可以定义为一个双射函数φ: Z^n → Z^m,其中n是空间的维度(对于点云通常是3),m是映射到的高维空间的维度。
在这里插入图片描述
具体细节代码位于Pointcept/pointcept/models/utils/structure.py,如下:

def serialization(self, order="z", depth=None, shuffle_orders=False):"""Point Cloud Serializationrelay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]"""assert "batch" in self.keys()if "grid_coord" not in self.keys():# if you don't want to operate GridSampling in data augmentation,# please add the following augmentation into your pipline:# dict(type="Copy", keys_dict={"grid_size": 0.01}),# (adjust `grid_size` to what your want)assert {"grid_size", "coord"}.issubset(self.keys())self["grid_coord"] = torch.div(self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc").int()if depth is None:# Adaptive measure the depth of serialization cube (length = 2 ^ depth)depth = int(self.grid_coord.max()).bit_length()self["serialized_depth"] = depth# Maximum bit length for serialization code is 63 (int64)assert depth * 3 + len(self.offset).bit_length() <= 63# Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.# Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3# cube with a grid size of 0.01 meter. We consider it is enough for the current stage.# We can unlock the limitation by optimizing the z-order encoding function if necessary.assert depth <= 16# The serialization codes are arranged as following structures:# [Order1 ([n]),#  Order2 ([n]),#   ...#  OrderN ([n])] (k, n)code = [encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order]code = torch.stack(code)order = torch.argsort(code)inverse = torch.zeros_like(order).scatter_(dim=1,index=order,src=torch.arange(0, code.shape[1], device=order.device).repeat(code.shape[0], 1),)if shuffle_orders:perm = torch.randperm(code.shape[0])code = code[perm]order = order[perm]inverse = inverse[perm]self["serialized_code"] = codeself["serialized_order"] = orderself["serialized_inverse"] = inverse
2.4.2.2 稀疏化

具体细节代码位于Pointcept/pointcept/models/utils/structure.py,主要就是利用spconv系数卷积:

def sparsify(self, pad=96):"""Point Cloud SerializationPoint cloud is sparse, here we use "sparsify" to specifically refer topreparing "spconv.SparseConvTensor" for SpConv.relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]pad: padding sparse for sparse shape."""assert {"feat", "batch"}.issubset(self.keys())if "grid_coord" not in self.keys():# if you don't want to operate GridSampling in data augmentation,# please add the following augmentation into your pipline:# dict(type="Copy", keys_dict={"grid_size": 0.01}),# (adjust `grid_size` to what your want)assert {"grid_size", "coord"}.issubset(self.keys())self["grid_coord"] = torch.div(self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc").int()if "sparse_shape" in self.keys():sparse_shape = self.sparse_shapeelse:sparse_shape = torch.add(torch.max(self.grid_coord, dim=0).values, pad).tolist()sparse_conv_feat = spconv.SparseConvTensor(features=self.feat,indices=torch.cat([self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1).contiguous(),spatial_shape=sparse_shape,batch_size=self.batch[-1].tolist() + 1,)self["sparse_shape"] = sparse_shapeself["sparse_conv_feat"] = sparse_conv_feat
2.4.2.3 embedding

embedding进去后,Pointcept/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py中的代码如下:
在这里插入图片描述
self.stem就是卷积+BN+Gelu激活函数
在这里插入图片描述

2.4.2.4 encoder

主要的特征提取位于Pointcept/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py中。
在这里插入图片描述
self.cpe(point)主要是由卷积+FC+层归一化组成
在这里插入图片描述
核心就self.atten,代码位于Pointcept/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py中。
其中,多头注意力机制的计算QKV的核心还是和vision transformer的很类似。

def forward(self, point):if not self.enable_flash:self.patch_size = min(   #128offset2bincount(point.offset).min().tolist(), self.patch_size_max)H = self.num_heads   #2K = self.patch_size  #128C = self.channels    #32pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)order = point.serialized_order[self.order_index][pad]inverse = unpad[point.serialized_inverse[self.order_index]]# padding and reshape feat and batch for serialized point patch  #[158424]qkv = self.qkv(point.feat)[order]if not self.enable_flash:# encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')q, k, v = (qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0))# attnif self.upcast_attention:q = q.float()k = k.float()attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)if self.enable_rpe:attn = attn + self.rpe(self.get_rel_pos(point, order))if self.upcast_softmax:attn = attn.float()attn = self.softmax(attn)attn = self.attn_drop(attn).to(qkv.dtype)feat = (attn @ v).transpose(1, 2).reshape(-1, C)else:feat = flash_attn.flash_attn_varlen_qkvpacked_func(qkv.half().reshape(-1, 3, H, C // H),cu_seqlens,max_seqlen=self.patch_size,dropout_p=self.attn_drop if self.training else 0,softmax_scale=self.scale,).reshape(-1, C)feat = feat.to(qkv.dtype)feat = feat[inverse]# ffnfeat = self.proj(feat)feat = self.proj_drop(feat)point.feat = featreturn point

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53618.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Emlog程序屏蔽用户IP拉黑名单插件

插件介绍 在很多时候我们需要得到用户的真实IP地址&#xff0c;例如&#xff0c;日志记录&#xff0c;地理定位&#xff0c;将用户信息&#xff0c;网站数据分析等,其实获取IP地址很简单&#xff0c;感兴趣的可以参考一下。 今天给大家带来舍力写的emlog插件&#xff1a;屏蔽…

wakenet尾迹

1、数据集介绍SWIM_Dataset_1.0.0 1.1标注文件介绍 标注文件介绍&#xff0c; 第一种&#xff1a;角度和框的坐标 <annotation><folder>Positive</folder><filename>00001</filename>文件名字<format>jpg</format>图片后缀<s…

自掘坟墓?开源正在卷爆程序员!

前端训练营&#xff1a;1v1私教&#xff0c;终身辅导计划&#xff0c;帮你拿到满意的 offer。 已帮助数百位同学拿到了中大厂 offer Hello&#xff0c;大家好&#xff0c;我是 Sunday。 今天这篇文章其实我想了好久&#xff0c;因为这并不是一个 和光同尘 的话题&#xff0c;它…

第143天:内网安全-权限维持自启动映像劫持粘滞键辅助屏保后门WinLogon

案例一&#xff1a; 权限维持-域环境&单机版-自启动 自启动路径加载 路径地址 C:\Users\Administrator\AppData\Roaming\Microsoft\Windows\StartMenu\Programs\Startup\ ##英文C:\Users\Administrator\AppData\Roaming\Microsoft\Windows\开始菜单\程序\启动\ ##中文…

OpenHarmony鸿蒙( Beta5.0)智能窗户通风设备开发详解

鸿蒙开发往期必看&#xff1a; 一分钟了解”纯血版&#xff01;鸿蒙HarmonyOS Next应用开发&#xff01; “非常详细的” 鸿蒙HarmonyOS Next应用开发学习路线&#xff01;&#xff08;从零基础入门到精通&#xff09; “一杯冰美式的时间” 了解鸿蒙HarmonyOS Next应用开发路…

如何逆转Instagram账号流量减少?实用技巧分享

Instagram作为全球十大社媒之一&#xff0c;不仅是个人分享生活的平台&#xff0c;还是跨境卖家进行宣传推广和客户开发的关键工具。在运营Instagram的过程中&#xff0c;稍有不慎就容易出现账号被限流的情况&#xff0c;对于账号状态和运营工作的进行都十分不利。 一、如何判断…

isis与ospf高级属性

文章目录 前言一、基础配置(配置各设备的IP地址)二、配置各设备的ospf与isis三、检查ospf与isis邻居是否建立成功1.实现快速重路由2.流量过滤方法3.引入默认路由4.配置等价路由 前言 在下面实验中&#xff0c;蓝色区域运行ospf&#xff0c;为了控制ospf的lsdb数量&#xff0c;…

vue页面使用自定义字体

一、准备好字体文件 一般字体问价格式为 .tff&#xff0c;可以去包图网等等网站去下载&#xff0c;好看的太多了&#xff01;&#xff01;&#xff01; 下载下来就是单个的 .tff文件&#xff0c;下载下来后可以进行重命名&#xff0c;但是不要改变他的后缀名&#xff0c;我把他…

【c++】类和对象详解

✅博客主页:爆打维c-CSDN博客​​​​​​ &#x1f43e; &#x1f539;分享c语言知识及代码 来都来了! 点个赞给博主个支持再走吧~&#xff01; 一.类的定义 &#xff08;1&#xff09;类定义格式 class为类定义的关键字&#xff0c;定义一个类格式如下: class 类名{//代码…

turtle.circle() 函数绘制弧形规律助记图 ← Python

【Python 之 turtle.circle() 函数定义】 定义&#xff1a;turtle.circle(radius, extent)作用&#xff1a;根据半径 radius 绘制 extent 角度的弧形参数&#xff1a;radius &#xff1a;弧形半径当 radius 值为正数时&#xff0c;圆心在当前位置/小海龟左侧。当 radius 值为负…

9月美联储决策前哨战——美国CPI数据来袭

随着本周关键CPI数据的即将发布&#xff0c;市场正翘首以待&#xff0c;这将是美联储在9月17日至18日议息会议前获取的最后一块重要经济拼图。鉴于美联储官员已进入传统的政策静默期&#xff0c;8月份的CPI报告无疑将成为交易员们评估未来货币政策走向的重要标尺。 欧洲央行降…

[000-01-002].第03节:Git基础命令

我的博客大纲 我的GIT学习大纲 1、Git的常用命令 2、Git操作步骤&#xff1a; 2.1.操作Git第一步&#xff1a;设置全局的用户签名 1.设置用户名&#xff1a; 格式&#xff1a;git config --global user.name 用户名命令&#xff1a;git config --global user.name root 2.设置…

Taro + Vue 的 CSS Module 解决方案

一、开启模块化配置 Taro 中内置了 CSS Modules 的支持&#xff0c;但默认是关闭的。如果需要开启使用&#xff0c;请先在编译配置中添加如下配置&#xff1a; weapp: {module: {postcss: {// css modules 功能开关与相关配置cssModules: {enable: true, // 默认为 false&…

如何解决户用光伏项目管理难题?

户用光伏作为分布式能源的重要组成部分&#xff0c;正迎来前所未有的发展机遇。户用光伏项目的复杂性和多样性也给项目管理带来了诸多挑战&#xff0c;包括客户分散、安装周期长、运维难度大、数据监控不及时等问题。为解决这些难题&#xff0c;构建一套高效、智能的户用光伏业…

SpringMVC基于注解使用:国际化

01-国际化介绍 首先在bootstrap下载个页面 下载后把登录页面的代码粘上去 然后再登录页面代码上有些超链接需要再spring-mvc.xml里面配置下&#xff0c;登录页面才能正常显示 配置静态资源 国际化-根据浏览器语言国际化 现在是中文的情况&#xff0c;要改为英文 1.配置下属…

OFDM信号PARP的CCDF图

文章目录 引言代码代码疑难解答参考文献 引言 本书主要参考了文献1&#xff0c;但实际上该书中符号和表述的错误非常多&#xff08;只能说棒子是这样的&#xff09;&#xff1b;同时因为发表时间的关系&#xff0c;很多MATLAB代码进行了更新&#xff0c;原书提供的代码已经无法…

Flutter中自定义气泡框效果的实现

在用户界面的设计中&#xff0c;气泡框&#xff08;Bubble&#xff09;是一种非常有效的视觉工具&#xff0c;它可以用来突出显示信息或提示用户。气泡框广泛应用于聊天应用、通知提示等场景。在 Flutter 中&#xff0c;虽然有很多现成的气泡框组件&#xff0c;但如果你想要更多…

使用豆包MarsCode 编写 Node.js 全栈应用开发实践

以下是「豆包MarsCode 体验官」优秀文章&#xff0c;作者狼叔。 欢迎更多用户使用豆包MarsCode 并分享您的产品使用心得及反馈、创意项目开发等&#xff0c;【有奖征集&#xff5c;人人都是豆包MarsCode 测评官&#xff01;】活动正在火热进行中&#xff0c;欢迎大家投稿参加&a…

跨部门SOP与统一知识库:打破信息孤岛,促进团队协作

引言&#xff1a; 在当今这个快速变化且高度竞争的商业环境中&#xff0c;企业面临着前所未有的挑战&#xff0c;其中之一便是如何高效地跨越部门界限&#xff0c;实现无缝协作。传统的组织结构往往导致信息孤岛的出现&#xff0c;不同部门间流程不一致、信息不共享&#xff0…

【车载开发系列】ParaSoft单元测试环境配置(一)

【车载开发系列】ParaSoft单元测试环境配置(一) ParaSoft单元测试环境配置 【车载开发系列】ParaSoft单元测试环境配置(一)一. 什么是bdf文件二. bdf文件构成三. 新规做成bdf文件四. 导入bdf文件创建测试工程五. 获取编译器信息六. 新增自定义编译器Step1:打开向导Step2:…