之前的一篇博客《动手学无人驾驶(4):基于激光雷达点云数据3D目标检测》里介绍到了如何基于PointRCNN模型来进行3D目标检测,作者使用的主干网是PointNet++,而PointNet++又是基于PointNet来实现的。今天写的这篇博客就是对PointNet网络进行详细介绍。
(2021-1-27日补充):这是PointNet作者2021年分享的报告《3D物体检测发展与未来》,对3D物体检测感兴趣的朋友可以看看。
【PointNet作者亲述】90分钟带你了解3D物体检测算法和未来方向!
补充:下面的视频是PointNet作者分享的报告《点云上的深度学习及其在三维场景理解中的应用》,里面有详细介绍PointNet(https://www.bilibili.com/video/BV1As411377S/?spm_id_from=333.788.videocard.1)。
将门创投 | 斯坦福大学在读博士生祁芮中台:点云上的深度学习及其在三维场景理解中的应用
1.PointNet论文解读
前言
PointNet网络
2.PointNet源码
参考资料
1.PointNet论文解读
前言
随着大数据和深度学习的兴起,涌现了许许多多的3D应用,与此同时需要一种数据驱动的方式去理解和处理三维数据,这就是:3D deep learning。
三维数据本身有一定的复杂性,2D图像可以轻易的表示成矩阵,3D表达形式主要分为以下几种:
- point cloud :深度传感器扫描得到的深度数据,点云。
- Mesh:三角面片在计算机图形学中渲染和建模话会很有用。
- Volumetric:将空间划分成三维网格,栅格化。
- Multi-View:用多个角度的图片表示物体。
Point cloud 是一种非常适合于3D场景理解的数据,原因是:
- 点云是非常接近原始传感器的数据集,激光雷达扫描之后就是点云,原始的数据可以做端到端的深度学习。
- 点云在表达形式上是比较简单的,一组点。相比较来说Mesh需要选择面片类型和如何连接;网格需要选择多大的网格,分辨 率;图像的选择,需要选择拍摄的角度,但是表达是不全面的。
最近才有一些方法研究直接在点云上进行特征学习,之前的大部分工作都是集中在手工设计点云数据的。这些特征都是针对特定任务,有不同的假设,新的任务很难优化特征。
但是点云数据是一种不规则的数据,在空间上和数量上可以任意分布,之前的研究者在点云上会先把它转化成一个规则的数据,比如栅格让其均匀分布,然后再用3D CNN来处理栅格数据。3D CNN复杂度相当的高,是三次方的增长,所以分辨率不高,相比图像是很低的。
但是如果考虑不计复杂度的栅格,会导致大量的栅格都是空白,智能扫描到表面,内部都是空白的。所以栅格并不是对3D点云很好的一种表达方式,也有人考虑过,用3D点云数据投影到2D平面上用2D cnn 进行训练,这样会损失3D的信息。 还要决定的投影的角度。点云中提取手工的特征,再接FC,这么做有很大的局限性
我们能否直接用一种在点云上学习的方法?
PointNet网络
我们的目标是提出一种端到端的点云多任务处理框架,包括目标分类,目标零件分类以及场景语义解析。
点云输入数据处理
点云是数据的表达点的集合,网络模型应对点云的排列方式不敏感,如下图所示,对于N个具有D维特征的点云数据,排列方式可能有N!种,我们希望我们的网络模型能够对于N!排列方式点云数据能够保持同样的学习效果。
神将网络本质上是一个函数,我们希望找到一个对称函数,能够对于点云数据具有置换不变性。如取最大值函数,无论输入怎么变换,最后的结果都是输入的最大值。
虽然是置换不变的,但是这种方式只计算了最远点的边界,损失了很多有意义的几何信息,如何解决呢?与其说直接做对称性可以先把每个点映射到高维空间,在高维空间中做对称性的操作,高维空间可以是一个冗余的,在max操作中通过冗余可以避免信息的丢失,可以保留足够的点云信息,再通过一个网络来进一步消化信息得到点云的特征。这就是函数的组合:每个点都做h低维到高维的映射,G是对称的那么整个结构就都是对称的。下图就是原始的pointnet结构。
在实际执行过程中,可以用MLP多层感知器(Multilayer perceptron) 来描述h和γ,g( max polling) 效果最好。
我们发现,pointnet 可以任意的逼近在集合上的对称函数,只要是对称函数是在hausdorff空间是连续的,那么就可以通过任意的增加神经网络的宽度深度,来逼近这个函数:
视角变换
如何来应对输入点云的几何(视角)变换,比如一辆车在不同的角度点云的xyz都是不同的, 但代表的都是车,我们希望网络也能应对视角的变换。
增加了一个基于数据本身的变换函数模块, T-net 生成变换参数,之后的网络处理变换之后的点,目标是通过整体优化变换网络和后面的网络使得变换函数对齐输入,如果对齐了,不同视角的问题就可以简化。实际中点云的变化很简单,不像图片做变换需要做插值,做矩阵乘法就可以。比如对于一个3*3的矩阵仅仅是一个正交变换,计算容易实现简单。
PointNet分类网络
将以上这些变换的网络和pointnet结合起来,就可以得到PointNet分类网络。
首先输入一个n*3的矩阵,先做一个输入的矩阵变换,T-net 变成一个n*3的矩阵,然后通过MLP把每个点投射到64高维空间,在做一个高维空间的变换,形成一个更加归一化的64维矩阵,继续做MLP将64维映射到1024维,在1024中可以做对称性的操作,就maxpooling,得到globle fearue,1024维度 ,通过全连接网络生成k (分类)。
PointNet分割网络
分割网络如图:
可以定以为对每个点的分类问题,通过全局坐标是没法对每个点进行分割的,简单有效的做法是,将局部单个点的特征和全局的坐标结合起来,实现分割的功能。最后输出m类相当于m个score:(将单个点和总体的特征连接到一起,判定在总体中的位置,来决定是哪个分类)
2.PointNet源码
这里使用的是Pytorch的版本。
class STN3d(nn.Module):'''3x3 transform'''def __init__(self):super(STN3d, self).__init__()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 9)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)def forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, 3, 3)return xclass STNkd(nn.Module):'''64x64 transform'''def __init__(self, k=64):super(STNkd, self).__init__()self.conv1 = torch.nn.Conv1d(k, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k*k)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)self.k = kdef forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, self.k, self.k)return xclass PointNetfeat(nn.Module):'''Output: global feature / local+global feature'''def __init__(self, global_feat = True, feature_transform = False):super(PointNetfeat, self).__init__()self.stn = STN3d()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_featself.feature_transform = feature_transformif self.feature_transform:self.fstn = STNkd(k=64)def forward(self, x):n_pts = x.size()[2]trans = self.stn(x)x = x.transpose(2, 1)x = torch.bmm(x, trans)x = x.transpose(2, 1)x = F.relu(self.bn1(self.conv1(x)))if self.feature_transform:trans_feat = self.fstn(x)x = x.transpose(2,1)x = torch.bmm(x, trans_feat)x = x.transpose(2,1)else:trans_feat = Nonepointfeat = xx = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)if self.global_feat:return x, trans, trans_feat # (B, 1024) (B, 3, 3) (B, 64, 64)else:x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)return torch.cat([x, pointfeat], 1), trans, trans_feat # (B, 1088, 2500) (B,3, 3) (B, 64, 64)class PointNetCls(nn.Module):# 分类网络def __init__(self, k=2, feature_transform=False):super(PointNetCls, self).__init__()self.feature_transform = feature_transformself.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k)self.dropout = nn.Dropout(p=0.3)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.relu = nn.ReLU()def forward(self, x):x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.dropout(self.fc2(x))))x = self.fc3(x)return F.log_softmax(x, dim=1), trans, trans_featclass PointNetDenseCls(nn.Module):# 分割网络def __init__(self, k = 2, feature_transform=False):super(PointNetDenseCls, self).__init__()self.k = kself.feature_transform=feature_transformself.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)self.conv1 = torch.nn.Conv1d(1088, 512, 1)self.conv2 = torch.nn.Conv1d(512, 256, 1)self.conv3 = torch.nn.Conv1d(256, 128, 1)self.conv4 = torch.nn.Conv1d(128, self.k, 1)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.bn3 = nn.BatchNorm1d(128)def forward(self, x):batchsize = x.size()[0]n_pts = x.size()[2]x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = self.conv4(x)x = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x, trans, trans_featif __name__ == '__main__':sim_data = Variable(torch.rand(32,3,2500))trans = STN3d()out = trans(sim_data)print('stn', out.size())print('loss', feature_transform_regularizer(out))sim_data_64d = Variable(torch.rand(32, 64, 2500))trans = STNkd(k=64)out = trans(sim_data_64d)print('stn64d', out.size())print('loss', feature_transform_regularizer(out))pointfeat = PointNetfeat(global_feat=True)out, _, _ = pointfeat(sim_data)print('global feat', out.size())pointfeat = PointNetfeat(global_feat=False)out, _, _ = pointfeat(sim_data)print('point feat', out.size())cls = PointNetCls(k = 5)out, _, _ = cls(sim_data)print('class', out.size())seg = PointNetDenseCls(k = 3)out, _, _ = seg(sim_data)print('seg', out.size())
stn torch.Size([32, 3, 3]) loss tensor(2.5054, grad_fn=<MeanBackward0>)stn64d torch.Size([32, 64, 64]) loss tensor(127.5234, grad_fn=<MeanBackward0>)global feat torch.Size([32, 1024])point feat torch.Size([32, 1088, 2500])class torch.Size([32, 5])seg torch.Size([32, 2500, 3])
参考资料
https://www.cnblogs.com/yibeimingyue/p/12002469.html
https://github.com/fxia22/pointnet.pytorch
http://stanford.edu/~rqi/pointnet/
https://zhuanlan.zhihu.com/p/86331508
https://www.bilibili.com/video/BV1As411377S/?spm_id_from=333.788.videocard.1