Pytorch 中 matmul 广播方式
一、两个 1 维,向量内积
a = torch.ones(3)
b = torch.ones(3)
print(torch.matmul(a,b)) # tensor(3.)
二、两个 2 维,矩阵相乘
a = torch.ones(3,4)
b= torch.ones(4,3)
print(torch.matmul(a,b))
# tensor([[4., 4., 4.],
# [4., 4., 4.],
# [4., 4., 4.]])
三、一个 1 维,二个 2 维,矩阵和向量相乘
注意
:相靠近的那个维数要相同,比如(7)和(7,8,5),又比如(7,8,5)和(5)
a = torch.ones(3)
b= torch.ones(3,4)
print(torch.matmul(a,b)) # tensor([3., 3., 3., 3.])a = torch.ones(3,4)
b = torch.ones(4)
print(torch.matmul(a,b)) # tensor([4., 4., 4.])
四、高维情况
注意
:两个都高于 2 维,那么除掉最后两个维度外(最后两个维度满足矩阵乘法,即(m,k)*(k,n)),剩下的满足广播机制
a=torch.ones(5,3,4,1)
b=torch.ones( 3,1,1)
print(torch.matmul(a,b))