文章目录
- 一、计算两组点之间的欧式距离
- 二、举例
- 三、中间结果输出
一、计算两组点之间的欧式距离
def square_distance(src, dst):"""Calculate Euclid distance between each two points.src^T * dst = xn * xm + yn * ym + zn * zm;sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dstInput:src: source points, [B, N, C]dst: target points, [B, M, C]Output:dist: per-point square distance, [B, N, M]"""B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return dist
🍉解释:
B, N, _ = src.shape
:获取输入源点和目标点的形状信息,其中 B 表示批量大小,N 表示源点的数量
_, M, _ = dst.shape
:M 表示目标点的数量,C 表示每个点的维度
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
:这一步计算了两组点之间的叉乘积
dst.permute(0, 2, 1)
:将目标点张量 dst 的第二维和第三维进行交换,以便进行点积
同理,src为N x C,dst为M x C,需要将M x C转置成C x M才可以进行点积(N x C)·(C x M)torch.matmul
:计算源点和目标点之间的点积,结果 dist 是一个形状为 [B, N, M] 的张量,表示每对源点和目标点之间的点积
dist += torch.sum(src ** 2, -1).view(B, N, 1)
:计算了源点和目标点的平方和,并将其广播到与 dist 相同的形状
torch.sum(src ** 2, -1)
:计算张量 src 中每个点的平方和,src ^2 将 src 中的每个元素都平方,然后 torch.sum 函数对最后一个维度(即 -1 所代表的维度)进行求和,最后一个维度被求和消除。- 假设 src 张量的形状是 [B, N, D],其中 B 表示批量大小,N 表示点的数量,D 表示每个点的维度。那么 torch.sum(src ** 2, -1) 的结果形状将是 [B, N],其中每个元素表示了原张量中相应位置点的平方和
view(B, N, 1)
:对张量调整到[B, N, 1],以便与后续的计算相兼容
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
:将这些平方和加到 dist 上,以完成欧氏距离的计算
dist
:张量,函数返回每对源点和目标点之间的欧氏距离的平方,形状为 [B, N, M]
计算欧式距离的平方等价于下方等式
( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 + ( z 1 − z 2 ) 2 (x_{1}-x_{2})^{2}+(y_{1}-y_{2})^{2}+(z_{1}-z_{2})^{2} (x1−x2)2+(y1−y2)2+(z1−z2)2=
x 1 2 + y 1 2 + z 1 2 + x 2 2 + y 2 2 + z 2 2 − 2 x 1 x 2 − 2 y 1 y 2 − 2 z 1 z 2 x_{1}^{2}+y_{1}^{2}+z_{1}^{2}+x_{2}^{2}+y_{2}^{2}+z_{2}^{2}-2x_{1}x_{2}-2y_{1}y_{2}-2z_{1}z_{2} x12+y12+z12+x22+y22+z22−2x1x2−2y1y2−2z1z2
二、举例
假设有两组点,分别是 src
和 dst
:
import torchdef square_distance(src, dst):B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return dist# 定义源点和目标点
src = torch.tensor([[[1, 2, 3], [4, 5, 6]]]) # shape: [1, 2, 3]
dst = torch.tensor([[[7, 8, 9], [10, 11, 12], [13, 14, 15]]]) # shape: [1, 3, 3]dist = square_distance(src, dst)
print(dist)
结果
例如
( 7 − 1 ) 2 + ( 8 − 2 ) 2 + ( 9 − 3 ) 2 = 108 (7-1)^{2}+(8-2)^{2}+(9-3)^{2}=108 (7−1)2+(8−2)2+(9−3)2=108
三、中间结果输出
- 对于
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
对于torch.sum(src ** 2, -1)
tensor([[14, 77]])
:这是一个形状为 (1, 2) 的张量,表示一个批次中有两个源点,每个源点有两个坐标分量。具体地,它包含了以下信息:第一个源点的坐标是 (14, 77)。- 对于torch.sum(src ** 2, -1).view(B, N, 1)
tensor([[[14], [77]]])
:这是一个形状为 (1, 2, 1) 的张量,表示一个批次中有两个目标点,每个目标点有一个坐标分量。具体地,它包含了以下信息:
第一个目标点的坐标是 (14)
第二个目标点的坐标是 (77)- 对于
dist += torch.sum(src ** 2, -1).view(B, N, 1)
- 对于
torch.sum(dst ** 2, -1)
- 对于
torch.sum(src ** 2, -1).view(B, N, 1)
- 对于
dist += torch.sum(src ** 2, -1).view(B, N, 1)