PointNet人工智能深度学习简明图解

PointNet 是一种深度网络架构,它使用点云来实现从对象分类、零件分割到场景语义解析等应用。 它于 2017 年实现,是第一个直接将点云作为 3D 识别任务输入的架构。

本文的想法是使用 Pytorch 实现 PointNet 的分类模型,并可视化其转换以了解模型的工作原理。

如果你不知道点云是什么……它只是对象或场景的 3D 表示,通常从 LiDAR(光检测和测距)传感器收集。 这些传感器发射光脉冲,然后测量它们返回传感器所需的时间。 此信息可用于创建对象或场景的 3D 模型,如上面的模型。 LiDAR 传感器变得越来越流行,你可以在自动驾驶汽车、无人机、测绘飞机甚至某些智能手机中找到它们!

1、Pointnet训练数据集

为了简单起见,我们将使用著名的 MNIST 数据集,我们可以直接使用 Pytorch 下载该数据集。

MNIST 包含 60,000 张手写数字图像,从 0 到 9。

PointNet 处理由三个坐标 (x, y, z) 表示的点,因此我们将把 2D 图像转换为 3D 点云,如下图所示。

MNIST 样本是 28 x 28 像素的灰度图像。 像素值是范围从 0(黑色)到 255(白色)的整数。 我们想要将数字的每个像素转换为一个点。 函数transform_img2pc过滤图像中值高于127的像素并获取它们的索引。

import numpy as npdef transform_img2pc(img):img_array = np.asarray(img)indices = np.argwhere(img_array > 127)return indices.astype(np.float32)

一旦我们将像素转换为点,我们需要所有点云具有相同数量的点,以便我们可以将它们输入到 PointNet 中。 PointNet 的作者使用每个对象 2500 个点,我们将绘制每个数字的点的直方图来确定阈值。

from torchvision.datasets import MNIST
import matplotlib.pyplot as pltdataset = MNIST(root='./data', train=True, download=True)
len_points = []
# loop over samples
for idx in range(len(dataset)):img,label = dataset[idx]pc = transform_img2pc(img)len_points.append(len(pc))h = plt.hist(len_points)
plt.title('Histogram of amount of points per number')

我们将点数固定为 200,因为最大点数为 312,并且大多数点都在 200 以下。我们可能面临两种情况,点云高于 200 点和点云低于此阈值。

  • 当点数超过 200 时,我们将对点进行随机采样。
  • 相反,我们将随机复制现有点。

最后,我们将向所有产生均值为零、标准差为 0.05 的高斯噪声的点添加第三个分量 z。

让我们将数据处理包装在自定义 Dataset 类中。

from torch.utils.data import Datasetclass MNIST3D(Dataset):"""3D MNIST dataset."""NUM_CLASSIFICATION_CLASSES = 10POINT_DIMENSION = 3def __init__(self, dataset, num_points):self.dataset = datasetself.number_of_points = num_pointsdef __len__(self):return len(self.dataset)def __getitem__(self, idx):img,label = dataset[idx]pc = transform_img2pc(img)if self.number_of_points-pc.shape[0]>0:# Duplicate pointssampling_indices = np.random.choice(pc.shape[0], self.number_of_points-pc.shape[0])new_points = pc[sampling_indices, :]pc = np.concatenate((pc, new_points),axis=0)else:# sample pointssampling_indices = np.random.choice(pc.shape[0], self.number_of_points)pc = pc[sampling_indices, :]pc = pc.astype(np.float32)# add znoise = np.random.normal(0,0.05,len(pc))noise = np.expand_dims(noise, 1)pc = np.hstack([pc, noise]).astype(np.float32)pc = torch.tensor(pc)return pc, label

Dataset存储预处理后的样本及其相应的标签,现在我们需要定义一个DataLoader来迭代训练循环中的数据。

下载 MNIST 数据后,我们将连接默认分区(训练和测试)并将数据输入到我们的自定义 MNIST3D 数据集中。 然后,我们将数据集分为训练(80%)、验证(10%)和测试(10%),并为每个分区生成一个 DataLoader,批量大小为 128。

train_dataset = MNIST(root='./data/MNIST', download=True, train=True)
test_dataset = MNIST(root='./data/MNIST', download=True, train=False)
dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])dataset_3d = MNIST3D(dataset, number_of_points)
l_data = len(dataset_3d)
train_dataset, val_dataset, test_dataset = random_split(dataset_3d,[round(0.8*l_data), round(0.1*l_data), round(0.1*l_data)],generator=torch.Generator().manual_seed(1))train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

最后,我们绘制一些样本来检查点云是否正确生成。 你还可以使用我们笔记本的实现来生成类似上面的很酷的 3D gif。

pc = train_dataset[5][0].numpy()
label = train_dataset[5][1]
fig = plt.figure(figsize=[7,7])
ax = plt.axes(projection='3d')
sc = ax.scatter(pc[:,0], pc[:,1], pc[:,2], c=pc[:,0] ,s=80, marker='o', cmap="viridis", alpha=0.7)
ax.set_zlim3d(-1, 1)
plt.title(f'Label: {label}')

现在数据已经准备好了,我们可以专注于模型了!

2、Pointnet的体系结构和属性

PointNet由分类网络和分割网络组成。 分类网络以n个点(x,y,z)作为输入,使用T-Net应用输入和特征变换,然后通过最大池化聚合点特征。 输出是 k 个类别中每个类别的分类分数。 分割网络是分类网络的扩展。 它连接全局和局部特征并输出每点分数。

pointNet 的架构受到点集属性的启发,它们是一些设计选择的关键……让我们来检查一下!

1、无序。 与图像中的像素阵列不同,点云是一组没有特定顺序的点。

  • 要求:模型需要对点的排列保持不变。
  • 解决方案:使用最大池化层作为对称函数来聚合所有点的信息。 最大池化,如 * 和 +,是对称函数,因为输入的顺序不会改变结果。

2、点之间的交互。 这些点来自具有距离度量的空间。 这意味着点不是孤立的,相邻点形成一个有意义的子集。

  • 要求:模型需要能够捕获附近点的局部结构。
  • 解决方案:结合局部和全局特征进行分割。

3、变换下的不变性。 学习到的点集表示对于某些变换应该是不变的。

  • 要求:同时旋转和平移点不应修改全局点云类别或点的分割。
  • 解决方案:使用空间转换器网络,尝试在 PointNet 处理数据之前将数据转换为规范形式。 T-Net 是一种用于对齐输入点和点特征的神经网络。

可以在下面的代码中看到 T-Net(输入变换和 feature_transform)、最大池化(MaxPool1d)和特征生成(局部和全局)的使用。 ClassificationPointNet 返回每个点云的对数概率、损失正则化所需的特征变换以及用于绘图目的的最后两个元素(tnet_out、ix_maxpool)。

在下一节中,我们将更详细地介绍 T-Net 的实施、它的工作原理以及提供的好处。

class BasePointNet(nn.Module):def __init__(self, point_dimension):...def forward(self, x, plot=False):num_points = x.shape[1]input_transform = self.input_transform(x) # T-Net tensor [batch, 3, 3]x = torch.bmm(x, input_transform) # Batch matrix-matrix product x = x.transpose(2, 1) tnet_out=x.cpu().detach().numpy()x = F.relu(self.bn_1(self.conv_1(x)))x = F.relu(self.bn_2(self.conv_2(x)))x = x.transpose(2, 1)feature_transform = self.feature_transform(x)  # T-Net tensor [batch, 64, 64]x = torch.bmm(x, feature_transform)  # local point features [batch, 200, 64]x = x.transpose(2, 1)x = F.relu(self.bn_3(self.conv_3(x)))x = F.relu(self.bn_4(self.conv_4(x)))x = F.relu(self.bn_5(self.conv_5(x)))x, ix = nn.MaxPool1d(num_points, return_indices=True)(x)  # max-poolingx = x.view(-1, 1024)  # global feature vector [batch, 1024]return x, feature_transform, tnet_out, ixclass ClassificationPointNet(nn.Module):def __init__(self, num_classes, dropout=0.3, point_dimension=3):...def forward(self, x):x, feature_transform, tnet_out, ix_maxpool = self.base_pointnet(x)x = F.relu(self.bn_1(self.fc_1(x)))x = F.relu(self.bn_2(self.fc_2(x)))x = self.dropout_1(x)return F.log_softmax(self.fc_3(x), dim=1), feature_transform, tnet_out, ix_maxpool

出于空间原因,init 函数已被省略,但您可以在笔记本中查看它们。

3、训练Pointnet

我们使用经典的 Pytorch 训练循环来训练我们的模型。 我们将学习率设置为 0.001,最大 epoch 数设置为 80。您可以在上面的链接中找到 PointNet 的更轻版本(在 Google Colab 中实现)来使用它。 PointNet 包含多个 MLP,因此它具有大量可训练参数 (3.472.339)。 PointNet 的轻量级版本是通过减少每层神经元数量来减少训练时间来实现的,从而产生 910.611 个可训练参数。

该模型通过负对数似然损失 (NLL) 和正则化项进行优化,使其更加稳定。 NLL 是训练具有多个类别的分类问题时的典型损失。

一旦我们看到损失已经收敛,验证损失不会减少,我们就可以停止训练并测试我们的模型。

Test Accuracy
0.967
Alert⚠️ 如果模型没有经过完全训练,它可能无法保证排列的不变性。

3、可视化 T-Net 的输入和输出

T-Net 在特征提取之前将所有输入集对齐到规范空间。 它是如何做到的? 它预测将应用于输入点 (x, y, z) 坐标的 3x3 仿射变换矩阵。

这个想法可以进一步扩展到特征空间的对齐。 在PointNet架构图中可以看到,第二个T-Net预测了64x64的特征转换矩阵,用于对齐来自不同输入点云的特征。

正如你在下面的代码块中看到的,T-Net 由用于点无关特征提取的一维卷积层、最大池化和全连接层组成。 结果是一个变换矩阵,我们直接将其应用于输入点的坐标。

class TransformationNet(nn.Module):def __init__(self, input_dim, output_dim):super(TransformationNet, self).__init__()self.output_dim = output_dimself.conv_1 = nn.Conv1d(input_dim, 64, 1)self.conv_2 = nn.Conv1d(64, 128, 1)self.conv_3 = nn.Conv1d(128, 1024, 1)self.bn_1 = nn.BatchNorm1d(64)self.bn_2 = nn.BatchNorm1d(128)self.bn_3 = nn.BatchNorm1d(1024)self.bn_4 = nn.BatchNorm1d(512)self.bn_5 = nn.BatchNorm1d(256)self.fc_1 = nn.Linear(1024, 512)self.fc_2 = nn.Linear(512, 256)self.fc_3 = nn.Linear(256, self.output_dim * self.output_dim)def forward(self, x):num_points = x.shape[1]x = x.transpose(2, 1)x = F.relu(self.bn_1(self.conv_1(x)))x = F.relu(self.bn_2(self.conv_2(x)))x = F.relu(self.bn_3(self.conv_3(x)))x = nn.MaxPool1d(num_points)(x)x = x.view(-1, 1024)x = F.relu(self.bn_4(self.fc_1(x)))x = F.relu(self.bn_5(self.fc_2(x)))x = self.fc_3(x)identity_matrix = torch.eye(self.output_dim)if torch.cuda.is_available():identity_matrix = identity_matrix.cuda()x = x.view(-1, self.output_dim, self.output_dim) + identity_matrixreturn x
注意📝 T-Net 通过学习变换矩阵将所有输入集对齐到规范空间

通过绘制 T-Net 输出乘以输入点的结果,我们可以看到对输入点云执行的规范变换。

PointNet 的特性之一是它对点的排列具有不变性。 我们来测试一下! 我们将打乱点并比较转换和预测。 我们将使点大小更小,以更好地识别两种转换之间的差异。

我们可以看到,对于这个例子,使用不同的点顺序,我们得到非常相似的表示和相同的预测。

所有测试样本都会保留它吗? 让我们比较所有测试样本上的打乱点和非打乱点之间的预测。

(results==results_shuffle)
False

我们从 7000 个样本(测试集大小)中得到 6 个样本,在洗牌时得到不同的结果。 我们存储这些样本的索引以比较转换和预测。 在这里你可以看到几个示例:

我们发现转换非常相似,并且通过查看 T-Net 转换来猜测这些数字时我们也可能是错误的。 您认为为什么同一个模型会预测不同的数字? 我们可以绘制对最大池化有贡献的点来获得一个想法。

4、可视化 PointNet 关键点

PointNet 学习通过一组稀疏的关键点(作者称为关键点)来总结输入点云。 关键点是那些对最大池化特征有贡献的点。

我们存储了最大池化层的索引,我们绘制了混洗和非混洗点云的这些点,并获得了下图:

我们看到临界点集对应于数字的骨架,并且在混洗和非混洗点云之间是不同的,这导致模型预测一个或另一个类别!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/587817.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个WebSocket的自定义hook

一个WebSocket的自定义hook 自己封装了一个WebSocket的hook,代码如下&#xff1a; import { useEffect, useRef } from "react";const WS_URL wss://xxx // 服务地址const useSocket () > {const socketRef useRef<WebSocket>()let heartTimer 0; // …

【python】爬取百度热搜排行榜Top50+可视化【附源码】【送数据分析书籍】

一、导入必要的模块&#xff1a; 这篇博客将介绍如何使用Python编写一个爬虫程序&#xff0c;从斗鱼直播网站上获取图片信息并保存到本地。我们将使用requests模块发送HTTP请求和接收响应&#xff0c;以及os模块处理文件和目录操作。 如果出现模块报错 进入控制台输入&#xff…

第82讲:MySQL Binlog日志的滚动

MySQL Binlog日志的滚动 MySQL Binlog日志滚动指的就是产生一个新的Binlog日志&#xff0c;然后进行记录&#xff0c;因为如果都在一个Binlog中记录&#xff0c;查询是非常慢的&#xff0c;检索的效率也很低。 Binlog日志滚动有三种方法&#xff1a; 重启MySQL 数据库一般不重…

android开发调用百度地图api实现加载地图和定位

目录 一.踩的一些坑以及解决方法 1.权限声明不要少 2.地图初始化 3.定位问题 &#xff08;1&#xff09;监听器注册 &#xff08;2&#xff09;定位监听器类MyLocationListener的实现 &#xff08;3&#xff09;定位功能的调用 4.android studio连接真机调试问题 二.…

Spring Boot应用整合Prometheus

Spring Boot Actuator 提供了一组用于监控和管理 Spring Boot 应用程序的端点&#xff0c;而 Prometheus 是一个开源的监控和告警工具。通过将这两者结合起来&#xff0c;您可以实时监控您的应用程序的性能指标&#xff0c;并通过 Prometheus 提供的丰富的查询语言来分析和可视…

MySQL:索引

MySQL官方对索引的定义为: 索引 (Index) 是帮助MySQL高效获取数据的数据结构。 提取句子主干&#xff0c;就可以得到索引的本质:索引是数据结构。 1. 什么是索引&#xff0c;索引的作用 索引是一种用于快速查询和检索数据的数据结构&#xff0c;帮助mysql提高查询效率的数据…

ros2查看launch文件内需要提供的参数(接口):

格式&#xff1a;ros2 launch --show-args 包名称 launch文件名称 例如&#xff1a; ros2 launch --show-args ros_gz_sim gz_sim.python.py

行人重识别优化:Pose-Guided Feature Alignment for Occluded Person Re-Identification

文章记录了ICCV2019的一篇优化遮挡行人重识别论文的知识点&#xff1a;Pose-Guided Feature Alignment for Occluded Person Re-Identification 论文地址&#xff1a; https://yu-wu.net/pdf/ICCV2019_Occluded-reID.pdf Partial Feature Branch分支: PCB结构&#xff0c;将…

精致旅游网ROXANDREA 网页设计 html模板

一、需求分析 旅游网站通常具有多种功能&#xff0c;以下是一些常见的旅游网站功能&#xff1a; 酒店预订&#xff1a;旅游网站可以提供酒店预订服务&#xff0c;让用户搜索并预订符合其需求和预算的酒店房间。 机票预订&#xff1a;用户可以通过旅游网站搜索、比较和预订机票…

JavaScript 工具库 | PrefixFree给CSS自动添加浏览器前缀

新版的CSS拥有多个新属性&#xff0c;而标准有没有统一&#xff0c;有的浏览器厂商为了吸引更多的开发者和用户&#xff0c;已经加入了最新的CSS属性支持&#xff0c;这其中包含了很多炫酷的功能&#xff0c;但是我们在使用的时候&#xff0c;不得不在属性前面添加这些浏览器的…

毕业设计之开题报告

终于轮到我来写开题报告了&#xff0c;呃呃呃呃呃&#xff0c;目前有点难产了。想做的东西是关于区块链的后端设计实现&#xff0c;但是因为是完全原创之前没有类似的项目能去参考&#xff0c;所以其实有点慌的。 框架梳理 这是我们开题报告的要求&#xff1a; 包括题目研究的…

Django框架:入门指南与常用命令

引言&#xff1a; 在当今的Web开发世界中&#xff0c;Django无疑是一个备受瞩目的框架。它以其强大的功能和易用性&#xff0c;吸引着越来越多的开发者。这篇博客将为你提供一个关于Django的概览&#xff0c;以及一些常用的命令&#xff0c;帮助你快速上手。 一、Django简介&…

GPT技术:人工智能的语言革命

在人工智能的领域中&#xff0c;自然语言处理&#xff08;NLP&#xff09;一直是一个极具挑战性的研究领域。随着技术的进步&#xff0c;一个名为GPT&#xff08;Generative Pre-trained Transformer&#xff09;的模型出现在了公众的视野中&#xff0c;它不仅改变了我们与机器…

Java项目:102SSM汽车租赁系统

博主主页&#xff1a;Java旅途 简介&#xff1a;分享计算机知识、学习路线、系统源码及教程 文末获取源码 一、项目介绍 汽车租赁系统基于SpringSpringMVCMybatis开发&#xff0c;系统使用shiro框架做权限安全控制&#xff0c;超级管理员登录系统后可根据自己的实际需求配角色…

uniapp的css样式图片大小截图展示

目录 截取图片前截取图片后第一种方式&#xff1a;代码第二种方式&#xff1a;代码最后 截取图片前 截取图片后 第一种方式&#xff1a;代码 <view class"swiper-box-img"><image class"swiper-box-img-img" :src"item.file_path" mod…

我自己的Mac装机软件推荐!

我自己的Mac装机软件推荐&#xff01; 以下内容是我自己用着挺舒服的&#xff0c;使用频率很高的mac软件&#xff0c;写在这里留个印记。 之前好多mac破解软件网址没了&#xff0c;macbl现在还活着也还用起来不错&#xff5e; 首先还是推荐windows和mac双持用户看看我的这篇文…

Windows系统历史版本简介详细版

学习目标&#xff1a; 目录 学习目标&#xff1a; 学习内容&#xff1a; 学习产出&#xff1a; Windows 11的全新用户界面设计&#xff1a;学习新的任务栏、开始菜单、窗口管理等界面元素的使用与操作。 Windows 11的新功能和特点&#xff1a;学习新的虚拟桌面、Microsoft Team…

SLAM学习入门--什么是回环检测

文章目录 SLAM001 什么是回环检测?002 常用的回环检测方法有哪些?003 介绍一下Gauss-Netwon和LM算法004 介绍一下Ceres优化库,比如你使用过里面哪些内容?005 描述(扩展)卡尔曼滤波与粒子滤波,你自己在用卡尔曼滤波时遇到什么问题没有?006 除了视觉传感,还用过其他传感…

Ubuntu20.04 防火墙配置

ubuntu 系统中配置防火墙 ufw&#xff08;Uncomplicated Firewall&#xff09;是一个简化的、易于使用的Linux防火墙工具&#xff0c;旨在方便用户管理iptables防火墙规则。 特点 简化的防火墙管理&#xff1a;ufw提供了一个简洁的命令行界面&#xff0c;让您能够轻松地添加、…

2022-2023年度广东省职业院校学生专业技能大赛“软件测试”赛项性能测试题目-LoadRunner

性能测试-LR 1、脚本录制: (1)脚本一:脚本名称ProdAdd。 脚本内容:系统管理员登录、进行新增商品操作。 脚本具体要求如下:登录脚本存放在init,新增商品脚本存放在Action。商品名称前4位为固定值SPMC,固定值后面的字符可任意设置。对新增商品保存操作设置事务,事务…