paper:Lift, Splat, Shoot: Encoding Images from Arbitrary Camera Rigs by Implicitly Unprojecting to 3D
code:https://github.com/nv-tlabs/lift-splat-shoot
一、完整复现代码(可一键运行)和效果图
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2
import numpy as np# 根据世界坐标范围和一个像素代表的世界坐标距离来计算bev_size
# BEV grid definition. Each bound is [start, end, step] in metres.
# gen_dx_bx turns these into:
#   dx = [0.5, 0.5, 20.0]      -- size of one cell per axis
#   bx = [-49.75, -49.75, 0.0] -- centre of the first cell per axis
#   nx = [200, 200, 1]         -- number of cells per axis
xbound = [-50.0, 50.0, 0.5]   # x (front/back): 100 m at 0.5 m per cell -> 200 cells
ybound = [-50.0, 50.0, 0.5]   # y (left/right): 100 m at 0.5 m per cell -> 200 cells
zbound = [-10.0, 10.0, 20.0]  # z (up/down): 20 m in a single 20 m cell -> 1 cell
dbound = [4.0, 45.0, 1.0]     # depth: 4 m to 45 m at 1 m per bin -> 41 bins
# Number of depth bins along the frustum's d axis: (45 - 4) / 1 = 41.
D_ = int((dbound[1] - dbound[0]) / dbound[2])


def gen_dx_bx(xbound, ybound, zbound):
    """Derive voxel-grid descriptors from three [start, end, step] bounds.

    Returns a tuple ``(dx, bx, nx)`` of ``nn.Parameter`` objects created with
    ``requires_grad=False``, each holding one entry per x/y/z axis:
        dx -- cell size (``step``), float tensor;
        bx -- centre of the first cell (``start + step / 2``), float tensor;
        nx -- cell count (``(end - start) / step``, converted to int64 by
              ``torch.LongTensor``).
    """
    rows = [xbound, ybound, zbound]
    starts = [r[0] for r in rows]
    ends = [r[1] for r in rows]
    steps = [r[2] for r in rows]

    cell_size = torch.Tensor(steps)
    first_centre = torch.Tensor([lo + st / 2.0 for lo, st in zip(starts, steps)])
    cell_count = torch.LongTensor(
        [(hi - lo) / st for lo, hi, st in zip(starts, ends, steps)]
    )
    # Frozen parameters: fixed grid geometry, never learned.
    return tuple(
        nn.Parameter(t, requires_grad=False)
        for t in (cell_size, first_centre, cell_count)
    )


batch_size = 1
# Pick the first GPU when available (the tensors below are created on the CPU).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model input size and the backbone's downsampling factor.
in_H = 128
in_W = 352
scale_downsample = 16
# Feature-map size after downsampling: 352 // 16 = 22 columns, 128 // 16 = 8 rows.
feat_W16 = in_W // scale_downsample
feat_H16 = in_H // scale_downsample
semantic_channels = 64

# Camera parameters for two cameras, batch size 1.
num_cams = 2
# [B, N, 3, 3] camera rotations.
# NOTE(review): the second matrix has several zero entries and is not orthonormal;
# confirm these values are intended.
rots=torch.Tensor([[[[ 8.2076e-01, -3.4144e-04, 5.7128e-01],[-5.7127e-01, 3.2195e-03, 8.2075e-01],[-2.1195e-03, -9.9999e-01, 2.4474e-03]],[[-9.3478e-01, 0, 0],[ 3.5507e-01, 0, -9.3477e-01],[-1.0805e-02, -9.9981e-01, 0]]]])
# [B, N, 3, 3] camera intrinsics.
# NOTE(review): camera 0 has principal point cx = 0 while camera 1 has cx = 807.25;
# confirm this is intended.
intrins = torch.Tensor([[[[1.2726e+03, 0.0, 0],[0.0000e+00, 1.2726e+03, 4.7975e+02],[0.0000e+00, 0.0000e+00, 1.0000e+00]],[[1.2595e+03, 0.0000e+00, 8.0725e+02], [0.0000e+00, 1.2595e+03, 5.0120e+02],[0.0000e+00, 0.0000e+00, 1.0000e+00]]]])
# [B, N, 3, 3] image-space augmentation matrices (a 0.22 scale on u and v).
post_rots = torch.Tensor([[[[0.2200, 0.0000, 0.0000],[0.0000, 0.2200, 0.0000],[0.0000, 0.0000, 1.0000]],[[0.2200, 0.0000, 0.0000],[0.0000, 0.2200, 0.0000],[0.0000, 0.0000, 1.0000]]]])
# [B, N, 3] image-space augmentation offsets (none here). Previously built with
# shape (3, 2, 1), which only worked because its 6 zeros happened to reshape
# to [1, 2, 3].
post_trans = torch.zeros(1, num_cams, 3)
trans = torch.Tensor([[[ 1.5239, 0.4946, 1.5093], [ 1.0149, -0.4806, 1.5624]]])  # [B, N, 3]


def create_uvd_frustum():
    """Build the (u, v, d) frustum grid shared by every camera.

    Returns an ``nn.Parameter`` (``requires_grad=False``) of shape
    ``[D, feat_H16, feat_W16, 3]`` (here ``[41, 8, 22, 3]``). The last axis
    holds ``(u, v, d)``: pixel column and row in the ``in_W x in_H`` input
    image, and a depth value from ``torch.arange(*dbound)``.
    """
    # Depth values 4, 5, ..., 44 (arange excludes the end, 45) -> D = 41,
    # expanded to [41, 8, 22].
    distance = (
        torch.arange(*dbound, dtype=torch.float)
        .view(-1, 1, 1)
        .expand(-1, feat_H16, feat_W16)
    )
    D, _, _ = distance.shape
    # 22 columns evenly spaced over u in [0, in_W - 1] = [0, 351], expanded to [41, 8, 22].
    x_stride = (
        torch.linspace(0, in_W - 1, feat_W16, dtype=torch.float)
        .view(1, 1, feat_W16)
        .expand(D, feat_H16, feat_W16)
    )
    # 8 rows evenly spaced over v in [0, in_H - 1] = [0, 127], expanded to [41, 8, 22].
    y_stride = (
        torch.linspace(0, in_H - 1, feat_H16, dtype=torch.float)
        .view(1, feat_H16, 1)
        .expand(D, feat_H16, feat_W16)
    )
    # Frustum: [41, 8, 22, 3].
    frustum = torch.stack((x_stride, y_stride, distance), -1)
    # Fixed geometry: no gradient, not learned.
    return nn.Parameter(frustum, requires_grad=False)


def plot_uvd_frustum(frustum):
    """Scatter-plot a [D, H, W, 3] (u, v, d) frustum and save it to ``uvd_frustum.png``."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    frustum_np = frustum.detach().cpu().numpy()
    u = frustum_np[..., 0].flatten()
    v = frustum_np[..., 1].flatten()
    d = frustum_np[..., 2].flatten()
    ax.scatter(u, v, d, c=d, cmap='viridis', marker='o')
    ax.set_xlabel('u')
    ax.set_ylabel('v')
    ax.set_zlabel('d')
    path = 'uvd_frustum.png'
    # Save BEFORE show(): a blocking show() closes the figure when its window is
    # dismissed, after which plt.savefig would write a new, empty figure.
    fig.savefig(path)
    plt.show()


def get_geometry_feat(frustum, rots, trans, intrins, post_rots, post_trans):
    """Map frustum points from (u, v, d) image space to ego coordinates.

    Args:
        frustum: [D, H, W, 3] grid of (u, v, d) points.
        rots: [B, N, 3, 3] camera-to-ego rotations.
        trans: [B, N, 3] camera-to-ego translations.
        intrins: [B, N, 3, 3] camera intrinsics.
        post_rots: [B, N, 3, 3] image-space augmentation matrices.
        post_trans: B * N * 3 image-space augmentation offsets (viewed as [B, N, 3]).

    Returns:
        [B, N, D, H, W, 3] tensor of (X, Y, Z) ego coordinates.
    """
    B, N, _ = trans.shape
    # Undo the image-space augmentation: subtract post_trans, then apply post_rots^-1.
    # Broadcasting adds the B and N (camera) dimensions.
    points = frustum - post_trans.view(B, N, 1, 1, 1, 3)
    points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))
    # (u, v, d) -> (u*d, v*d, d): pixel coordinates scaled by depth.
    points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
                        points[:, :, :, :, :, 2:3]), 5)
    # Camera -> ego: X = R @ K^-1 @ [u*d, v*d, d]^T + t.
    combine = rots.matmul(torch.inverse(intrins))
    points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
    points += trans.view(B, N, 1, 1, 1, 3)
    return points


def plot_XYZ_frustum(frustum, path):
    """Scatter-plot the points of every camera in one 3-D figure and save it to *path*.

    ``frustum`` is iterated along its first axis (one item per camera). Each
    item's last axis holds (X, Y, Z).
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for cam_points in frustum:
        frustum_np = cam_points.detach().cpu().numpy()
        x = frustum_np[..., 0].flatten()
        y = frustum_np[..., 1].flatten()
        z = frustum_np[..., 2].flatten()
        ax.scatter(x, y, z, c=z, cmap='viridis', marker='o')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    # Save BEFORE show(); see plot_uvd_frustum.
    fig.savefig(path)
    plt.show()


def voxelize_points(points, dx, bx):
    """Convert metric points [..., 3] to integer voxel indices [..., 3].

    The lower grid edge is ``bx - dx / 2``. Uses floor rather than a bare
    ``.long()``: ``.long()`` truncates toward zero, so a point less than one
    cell below the lower edge (e.g. Z in (-30, -10) m for zbound
    [-10, 10, 20]) would get index 0 instead of -1 and survive the range filter.
    """
    return torch.floor((points - (bx - dx / 2.)) / dx).long()


def cumsum_trick(cam_feat, geom_feat, ranks):
    """Sum the rows of ``cam_feat`` that share the same rank.

    Args:
        cam_feat: [K, C] features, ordered so that equal ``ranks`` are adjacent.
        geom_feat: [K, 4] voxel indices (X, Y, Z, batch) aligned with ``cam_feat``.
        ranks: [K] voxel ids, sorted.

    Returns:
        ``(cam_feat, geom_feat)`` with one row per distinct rank: the summed
        feature and the voxel index of the last point in that run.
    """
    # Prefix sum along dim 0, i.e. over points (not over channels).
    cam_feat = cam_feat.cumsum(0)
    # kept[i] is True at the last point of each run of equal ranks.
    kept = torch.ones(cam_feat.shape[0], device=cam_feat.device, dtype=torch.bool)
    kept[:-1] = (ranks[1:] != ranks[:-1])
    cam_feat, geom_feat = cam_feat[kept], geom_feat[kept]
    # Differencing consecutive run-end prefix sums yields each run's total.
    cam_feat = torch.cat((cam_feat[:1], cam_feat[1:] - cam_feat[:-1]))
    return cam_feat, geom_feat


def plot_bev(bev, name='bev'):
    """Show a BEV tensor in an OpenCV window (assumes shape [1, C, H, W]).

    Values are scaled by 255 and clipped to [0, 255] before the uint8 cast.
    Network outputs may be negative or above 1, and NumPy's float -> uint8
    cast of out-of-range values is undefined.
    """
    array1 = bev.squeeze(0).cpu().detach().numpy()  # [C, H, W]
    mat = np.clip(array1 * 255, 0, 255).astype(np.uint8)
    mat = mat.transpose(1, 2, 0)  # [H, W, C] as expected by OpenCV
    cv2.imshow(name, mat)
    cv2.waitKey(0)


if __name__ == "__main__":
    # 1. Build the 3-D (u, v, d) frustum: 2-D image grid + depth bins.
    uvd_frustum = create_uvd_frustum()
    plot_uvd_frustum(uvd_frustum)

    # 2. Lift the frustum into the ego frame with the camera intrinsics/extrinsics.
    XYZ_frustum = get_geometry_feat(uvd_frustum, rots, trans, intrins, post_rots, post_trans)
    plot_XYZ_frustum(XYZ_frustum[0], path='EGO_XYZ_frustum.png')

    # 3. Voxelize: ego metres -> integer voxel indices.
    dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
    geom_feats = voxelize_points(XYZ_frustum, dx, bx)
    plot_XYZ_frustum(geom_feats[0], path='voxel.png')

    # 4. bev_pool
    # 4.1 Flatten cam_feats and geom_feats to one row per frustum point.
    cam_feats = torch.rand(batch_size, num_cams, D_, feat_H16, feat_W16, semantic_channels)
    B, N, D, H, W, C = cam_feats.shape
    L__ = B * N * D * H * W
    cam_feats = cam_feats.reshape(L__, C)
    geom_feats = geom_feats.view(L__, 3)

    # 4.2 Append the batch index as a 4th column of geom_feats.
    batch_index = torch.cat([torch.full([L__ // B, 1], ix, device=cam_feats.device, dtype=torch.long)
                             for ix in range(B)])
    geom_feats = torch.cat((geom_feats, batch_index), 1)

    # 4.3 Keep only points inside the grid: 0 <= X < 200, 0 <= Y < 200, 0 <= Z < 1.
    kept = ((geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < nx[0])
            & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < nx[1])
            & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < nx[2]))
    cam_feats = cam_feats[kept]
    geom_feats = geom_feats[kept]

    # 4.4 Linear voxel id per point, then sort so that equal ids are adjacent.
    ranks = (geom_feats[:, 0] * (nx[1] * nx[2] * B)  # X
             + geom_feats[:, 1] * (nx[2] * B)        # Y
             + geom_feats[:, 2] * B                  # Z
             + geom_feats[:, 3])                     # batch_index
    sorts = ranks.argsort()
    cam_feats, geom_feats, ranks = cam_feats[sorts], geom_feats[sorts], ranks[sorts]

    # 4.5 Sum the features of points that share a voxel.
    cam_feats, geom_feats = cumsum_trick(cam_feats, geom_feats, ranks)

    # 4.6 Scatter into the dense grid final: [B, C, Z, X, Y] = [1, 64, 1, 200, 200].
    final = torch.zeros((B, C, nx[2], nx[0], nx[1]), device=cam_feats.device)
    final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = cam_feats

    # 4.7 Remove the Z axis (dim 2) by folding it into channels -> BEV map.
    final = torch.cat(final.unbind(dim=2), 1)

    # 5. BEV encoder (a single 1x1 conv here) and visualization.
    bev_encoder = nn.Conv2d(semantic_channels, 1, kernel_size=1, stride=1, padding=0, bias=False)
    bev = bev_encoder(final)
    plot_bev(bev, name='bev')
二、逐步代码讲解+图解
完整流程:
1.创建uv coord + depth estimation (2d image + depth)
2.视锥化(uv coord -> world coord)(根据相机内外参,构建3x4的投影矩阵)
3.体素化(world coord -> voxel coord)(需要用到世界坐标范围的划分以及各维度的网格刻度)
4.bev_pool(voxel coord -> bev coord)(去掉Z轴)
1.创建uv coord + depth estimation (2d image + depth)
uvd_frustum = create_uvd_frustum()  # [D, H, W, 3] grid of (u, v, d) points
plot_uvd_frustum(frustum=uvd_frustum)
注意
1.坐标范围:u、v的范围对应模型输入尺寸(W=352, H=128),即u∈[0,351]、v∈[0,127];d的范围为[4,44](torch.arange不包含终点45,共41个深度值)。
2.u轴有22个柱子(pillar),22=352//16;v轴有8个柱子(pillar),8=128//16;d轴有41个刻度,41=(45-4)//1
2.视锥化(uv coord -> world coord)(根据相机内外参,构建3x4的投影矩阵)
# Lift the frustum into the ego frame for every camera: [B, N, D, H, W, 3].
XYZ_frustum = get_geometry_feat(
    uvd_frustum,
    rots=rots, trans=trans, intrins=intrins,
    post_rots=post_rots, post_trans=post_trans,
)
plot_XYZ_frustum(XYZ_frustum[0], path='EGO_XYZ_frustum.png')
我这里为了看起来更直观点,选了两个相机,实际在使用过程中,可以灵活使用1个,2个,4个,6个相机。
3.体素化(world coord -> voxel coord)(需要用到世界坐标范围的划分以及各维度的网格刻度)
dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
# Metric ego coordinates -> integer voxel indices. Use floor() rather than a
# bare .long(): .long() truncates toward zero, so points less than one cell
# below the lower edge (e.g. Z in (-30, -10) m) would get index 0 instead of -1
# and wrongly survive the range filter in step 4.3.
geom_feats = torch.floor((XYZ_frustum - (bx - dx / 2.)) / dx).long()
plot_XYZ_frustum(geom_feats[0], path='voxel.png')
为什么上面和下面的形状不一样呢?因为:1.相机内外参数的影响;2.(旋转、平移)数据增强(即post_rots、post_trans)的影响。此外,体素化时各轴分别除以网格尺寸dx(x、y除以0.5,z除以20)并取整,坐标轴的比例也随之改变。
注意观察,此时XYZ轴的范围基本已经落在(200,200,1)的bev尺寸范围里了!超出范围的点会在4.3步中被过滤掉。
4.bev_pool(voxel coord -> bev coord)(去掉Z轴)
- 4.1. cam_feats,geom_feats 展平
cam_feats = torch.rand(batch_size, num_cams, D_, feat_H16, feat_W16, semantic_channels)
B, N, D, H, W, C = cam_feats.shape
L__ = B * N * D * H * W  # total number of frustum points
# One row per point: features [L__, C] and voxel indices [L__, 3].
# (The original line fused these two statements, which is a syntax error.)
cam_feats = cam_feats.reshape(L__, C)
geom_feats = geom_feats.view(L__, 3)
- 4.2.geom_feat增加batch维度
# Batch id for every point: B consecutive runs of L__ // B rows (0..0, 1..1, ...).
batch_index = (
    torch.arange(B, device=cam_feats.device, dtype=torch.long)
    .repeat_interleave(L__ // B)
    .unsqueeze(1)
)
geom_feats = torch.cat([geom_feats, batch_index], dim=1)
- 4.3.filter by (X<200,Y<200,Z<1)
# Keep only points whose (X, Y, Z) voxel index lies in [0, nx) on every axis.
kept = ((geom_feats[:, :3] >= 0) & (geom_feats[:, :3] < nx)).all(dim=1)
cam_feats, geom_feats = cam_feats[kept], geom_feats[kept]
- 4.4.voxel index 位置编码,排序
# Linear voxel id per point; the batch index is the fastest-varying term, so
# points in the same voxel and batch share one rank. (The original one-line
# form let the first '#' comment swallow the rest of the expression.)
ranks = (
    geom_feats[:, 0] * (nx[1] * nx[2] * B)  # X
    + geom_feats[:, 1] * (nx[2] * B)        # Y
    + geom_feats[:, 2] * B                  # Z
    + geom_feats[:, 3]                      # batch_index
)
# Sort so that points sharing a voxel become adjacent (required by cumsum_trick).
sorts = ranks.argsort()
cam_feats, geom_feats, ranks = cam_feats[sorts], geom_feats[sorts], ranks[sorts]
可以参考我画的示意图
- 4.5. sum
# Sum the features of all points that fall into the same voxel.
cam_feats, geom_feats = cumsum_trick(cam_feat=cam_feats, geom_feat=geom_feats, ranks=ranks)
- 4.6.根据视锥获取相应的cam_feat, final:[1,64,1,200,200]
# Dense grid [B, C, Z, X, Y] = [1, 64, 1, 200, 200]; write each voxel's summed feature.
final = torch.zeros((B, C, nx[2], nx[0], nx[1]), device=cam_feats.device)
x_idx, y_idx, z_idx, b_idx = geom_feats.unbind(dim=1)
final[b_idx, :, z_idx, x_idx, y_idx] = cam_feats
- 4.7.去掉Z维度, dim_Z维度属于dim=2, 生成bev图
# Fold Z into the channel axis, Z-major: [B, C, Z, X, Y] -> [B, Z*C, X, Y].
final = final.transpose(1, 2).flatten(1, 2)
5.bev_encoder
# BEV encoder: a single bias-free 1x1 convolution mapping the feature channels to 1.
bev_encoder = nn.Conv2d(
    in_channels=semantic_channels,
    out_channels=1,
    kernel_size=1,
    stride=1,
    padding=0,
    bias=False,
)
bev = bev_encoder(final)
plot_bev(bev, name='bev')
bev尺寸为200x200