torch.mm()
是 PyTorch 中用于执行矩阵乘法(matrix multiplication)的函数。它能够将两个给定的张量进行矩阵乘法运算,得到结果张量。
这是 torch.mm()
函数的基本语法:
torch.mm(input, mat2, *, out=None)
input
: 第一个输入张量,形状为(N, M)
。mat2
: 第二个输入张量,形状为(M, P)
。out
: 可选参数,用于指定输出张量。
两个输入张量的维度必须满足矩阵乘法的要求,即第一个张量的列数必须等于第二个张量的行数。
下面是一个简单的例子,说明了 torch.mm()
函数的用法:
import torch# 创建两个矩阵
A = torch.tensor([[1, 2],[3, 4]])
B = torch.tensor([[5, 6],[7, 8]])# 执行矩阵乘法运算
C = torch.mm(A, B)print(C)
输出结果是:
tensor([[19, 22],[43, 50]])
在这个例子中,我们定义了两个 2x2 的矩阵 A
和 B
,并通过 torch.mm()
函数执行了矩阵乘法运算,得到了结果矩阵 C
。