源码
有两种模式,SO3xR3和SE3,代表不同的刚体变换,都是6个参数,表述三个旋转角+3个偏移量
# Initialize learnable parameters.if self.config.mode == "off":passelif self.config.mode in ("SO3xR3", "SE3"):self.pose_adjustment = torch.nn.Parameter(torch.zeros((num_cameras, 6), device=device))else:assert_never(self.config.mode)
通过相应的函数计算出调整结果adj
def exp_map_SO3xR3(tangent_vector: Float[Tensor, "b 6"]) -> Float[Tensor, "b 3 4"]:"""Compute the exponential map of the direct product group `SO(3) x R^3`.This can be used for learning pose deltas on SE(3), and is generally faster than `exp_map_SE3`.Args:tangent_vector: Tangent vector; length-3 translations, followed by an `so(3)` tangent vector.Returns:[R|t] transformation matrices."""# code for SO3 map grabbed from pytorch3d and stripped down to bare-boneslog_rot = tangent_vector[:, 3:]nrms = (log_rot * log_rot).sum(1)rot_angles = torch.clamp(nrms, 1e-4).sqrt()rot_angles_inv = 1.0 / rot_anglesfac1 = rot_angles_inv * rot_angles.sin()fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())skews = torch.zeros((log_rot.shape[0], 3, 3), dtype=log_rot.dtype, device=log_rot.device)skews[:, 0, 1] = -log_rot[:, 2]skews[:, 0, 2] = log_rot[:, 1]skews[:, 1, 0] = log_rot[:, 2]skews[:, 1, 2] = -log_rot[:, 0]skews[:, 2, 0] = -log_rot[:, 1]skews[:, 2, 1] = log_rot[:, 0]skews_square = torch.bmm(skews, skews)ret = torch.zeros(tangent_vector.shape[0], 3, 4, dtype=tangent_vector.dtype, device=tangent_vector.device)ret[:, :3, :3] = (fac1[:, None, None] * skews+ fac2[:, None, None] * skews_square+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None])# Compute the translationret[:, :3, 3] = tangent_vector[:, :3]return retdef exp_map_SE3(tangent_vector: Float[Tensor, "b 6"]) -> Float[Tensor, "b 3 4"]:"""Compute the exponential map `se(3) -> SE(3)`.This can be used for learning pose deltas on `SE(3)`.Args:tangent_vector: A tangent vector from `se(3)`.Returns:[R|t] transformation matrices."""tangent_vector_lin = tangent_vector[:, :3].view(-1, 3, 1)tangent_vector_ang = tangent_vector[:, 3:].view(-1, 3, 1)theta = torch.linalg.norm(tangent_vector_ang, dim=1).unsqueeze(1)theta2 = theta**2theta3 = theta**3near_zero = theta < 1e-2non_zero = torch.ones(1, dtype=tangent_vector.dtype, device=tangent_vector.device)theta_nz = torch.where(near_zero, non_zero, theta)theta2_nz = torch.where(near_zero, non_zero, theta2)theta3_nz = torch.where(near_zero, non_zero, theta3)# Compute the rotationsine = theta.sin()cosine = torch.where(near_zero, 8 / (4 + theta2) - 1, theta.cos())sine_by_theta = torch.where(near_zero, 0.5 * cosine + 0.5, sine / theta_nz)one_minus_cosine_by_theta2 = torch.where(near_zero, 0.5 * sine_by_theta, (1 - cosine) / theta2_nz)ret = torch.zeros(tangent_vector.shape[0], 3, 4).to(dtype=tangent_vector.dtype, device=tangent_vector.device)ret[:, :3, :3] = one_minus_cosine_by_theta2 * tangent_vector_ang @ tangent_vector_ang.transpose(1, 2)ret[:, 0, 0] += cosine.view(-1)ret[:, 1, 1] += cosine.view(-1)ret[:, 2, 2] += cosine.view(-1)temp = sine_by_theta.view(-1, 1) * tangent_vector_ang.view(-1, 3)ret[:, 0, 1] -= temp[:, 2]ret[:, 1, 0] += temp[:, 2]ret[:, 0, 2] += temp[:, 1]ret[:, 2, 0] -= temp[:, 1]ret[:, 1, 2] -= temp[:, 0]ret[:, 2, 1] += temp[:, 0]# Compute the translationsine_by_theta = torch.where(near_zero, 1 - theta2 / 6, sine_by_theta)one_minus_cosine_by_theta2 = torch.where(near_zero, 0.5 - theta2 / 24, one_minus_cosine_by_theta2)theta_minus_sine_by_theta3_t = torch.where(near_zero, 1.0 / 6 - theta2 / 120, (theta - sine) / theta3_nz)ret[:, :, 3:] = sine_by_theta * tangent_vector_linret[:, :, 3:] += one_minus_cosine_by_theta2 * torch.cross(tangent_vector_ang, tangent_vector_lin, dim=1)ret[:, :, 3:] += theta_minus_sine_by_theta3_t * (tangent_vector_ang @ (tangent_vector_ang.transpose(1, 2) @ tangent_vector_lin))return ret
叠加到 camera_to_worlds
矩阵上
return torch.cat([# Apply rotation to directions in world coordinates, without touching the origin.# Equivalent to: directions -> correction[:3,:3] @ directionstorch.bmm(adj[..., :3, :3], camera.camera_to_worlds[..., :3, :3]),# Apply translation in world coordinate, independently of rotation.# Equivalent to: origins -> origins + correction[:3,3]camera.camera_to_worlds[..., :3, 3:] + adj[..., :3, 3:],],dim=-1,)
对比两种方式
在你提供的代码中,exp_map_SO3xR3
和exp_map_SE3
是两个函数,它们都用于计算从切空间到特殊欧几里得群SE(3)的指数映射。SE(3)是结合了旋转和平移的群,用于描述三维空间中的刚体运动。这两个函数的主要差异在于它们处理旋转部分的方式不同,以及它们对切向量的解释略有不同。
exp_map_SO3xR3
这个函数计算的是SO(3)(三维旋转群)和R^3(三维平移空间)的直积群的指数映射。它假定输入的切向量前三个分量是平移分量,后三个分量是SO(3)的切向量(通常表示为旋转向量或轴角表示法)。函数首先计算旋转部分的指数映射,然后与平移部分结合,生成最终的[R|t]变换矩阵。
关键步骤包括:
- 计算旋转角度(
rot_angles
)。 - 使用旋转角度和对应的旋转向量构造旋转矩阵(通过计算
skews
和skews_square
)。 - 结合旋转矩阵和平移向量生成最终的变换矩阵。
exp_map_SE3
这个函数直接计算SE(3)的指数映射。输入的切向量同样包含平移和旋转信息,但是这个函数在处理旋转时采用了不同的方法。它首先计算旋转角度(theta
),然后根据旋转角度的大小采用不同的近似方法来计算旋转矩阵。
关键步骤包括:
- 分离平移向量(
tangent_vector_lin
)和旋转向量(tangent_vector_ang
)。 - 根据旋转角度的大小,选择使用精确计算还是近似计算来得到旋转矩阵。
- 计算旋转矩阵并与平移向量结合,生成最终的变换矩阵。
主要差异
- 旋转处理:
exp_map_SO3xR3
使用了一个简化的方法来直接从旋转向量构造旋转矩阵,而exp_map_SE3
则根据旋转角度的大小采用不同的近似方法。 - 性能:
exp_map_SO3xR3
通常比exp_map_SE3
更快,因为它采用了更直接的方法来构造旋转矩阵。 - 适用性:
exp_map_SO3xR3
适用于学习SE(3)上的位姿增量,而exp_map_SE3
则直接处理SE(3)的切向量。
总的来说,这两个函数都是用于从SE(3)的切空间到SE(3)本身的映射,但是它们在处理旋转部分时采用了不同的策略,这可能会影响它们的性能和适用性。