torch.einsum
是 PyTorch 中一个强大且灵活的张量运算函数,基于爱因斯坦求和约定进行操作。它允许用户通过简单的字符串表达式来定义复杂的张量运算,代替显式的循环或多个矩阵乘法操作。
函数签名
torch.einsum(equation, *operands) → Tensor
参数
equation
: 一个字符串,描述了张量间的操作关系。它使用爱因斯坦求和约定,用逗号分隔不同张量的索引,使用箭头(->
)定义输出的形状。- 左侧部分是输入张量的维度索引,逗号分隔。
- 右侧是输出张量的维度索引。如果没有提供输出维度,函数默认对所有不重复的索引进行求和。
*operands
: 需要操作的张量。张量的维度必须与equation
中的描述匹配。
爱因斯坦求和约定
爱因斯坦求和约定是一种简化张量运算的表示方式。它假设对所有重复的索引进行求和。例如:
'ij,jk->ik'
表示矩阵乘法,其中i
和k
是保留下来的维度,j
是求和的维度。
示例
1. 矩阵乘法
矩阵乘法可以通过 torch.einsum
实现:
import torchA = torch.tensor([[1, 2],