pytorch小记(十):pytorch中torch.tril 和 torch.triu 详解
- PyTorch `torch.tril` 和 `torch.triu` 详解
- 1. `torch.tril`(计算下三角矩阵)
- 📌 作用
- 🔍 语法
- 🔹 参数
- 📌 示例
- 🔍 `diagonal` 参数
- 🔍 `torch.tril` 的应用
- 2. `torch.triu`(计算上三角矩阵)
- 📌 作用
- 🔍 语法
- 🔹 参数
- 📌 示例
- 🔍 `diagonal` 参数
- 3. `torch.tril` vs `torch.triu` 对比
- 总结
PyTorch torch.tril
和 torch.triu
详解
在数值计算和深度学习中,下三角矩阵(Lower Triangular Matrix) 和 上三角矩阵(Upper Triangular Matrix) 是非常常见的矩阵操作。PyTorch 提供了 torch.tril()
和 torch.triu()
这两个函数,分别用于计算下三角矩阵和上三角矩阵。
1. torch.tril
(计算下三角矩阵)
📌 作用
torch.tril
返回输入张量的 下三角部分,即:
- 保留 主对角线及其以下的元素。
- 主对角线以上的元素全部变为 0。
🔍 语法
torch.tril(input, diagonal=0)
🔹 参数
参数 | 说明 |
---|---|
input | 输入张量 |
diagonal | 控制对角线位置(默认 0 ) |
diagonal=0 | 保留主对角线 及其以下的元素 |
diagonal>0 | 向上偏移,保留主对角线以上 diagonal 行 |
diagonal<0 | 向下偏移,移除 -diagonal 行的主对角线元素 |
📌 示例
import torch# 创建一个 4×4 的矩阵
A = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]
])print("原始矩阵 A:")
print(A)# 计算 A 的下三角矩阵
L = torch.tril(A)
print("\nA 的下三角矩阵(diagonal=0):")
print(L)
输出:
原始矩阵 A:
tensor([[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 9, 10, 11, 12],[13, 14, 15, 16]])A 的下三角矩阵(diagonal=0):
tensor([[ 1, 0, 0, 0],[ 5, 6, 0, 0],[ 9, 10, 11, 0],[13, 14, 15, 16]])
💡 说明:主对角线上的元素保留,其上的元素变为
0
。
🔍 diagonal
参数
print(torch.tril(A, diagonal=1)) # 保留主对角线以上 1 行
print(torch.tril(A, diagonal=-1)) # 移除主对角线
输出:
A 的下三角矩阵(diagonal=1):
tensor([[ 1, 2, 0, 0],[ 5, 6, 7, 0],[ 9, 10, 11, 12],[13, 14, 15, 16]])A 的下三角矩阵(diagonal=-1):
tensor([[ 0, 0, 0, 0],[ 5, 0, 0, 0],[ 9, 10, 0, 0],[13, 14, 15, 0]])
🔺 diagonal=1:向上偏移,保留
1
行主对角线以上的元素。
🔻 diagonal=-1:向下偏移,移除主对角线。
🔍 torch.tril
的应用
📌 用于 Masking(掩码)
seq_length = 5
mask = torch.tril(torch.ones(seq_length, seq_length)) # 创建一个下三角 Mask
print(mask)
输出:
tensor([[1., 0., 0., 0., 0.],[1., 1., 0., 0., 0.],[1., 1., 1., 0., 0.],[1., 1., 1., 1., 0.],[1., 1., 1., 1., 1.]])
💡 Transformer 中,这种 Mask 用于防止模型在训练时提前看到未来的信息。
2. torch.triu
(计算上三角矩阵)
📌 作用
torch.triu
返回输入张量的 上三角部分,即:
- 保留 主对角线及其以上的元素。
- 主对角线以下的元素全部变为 0。
🔍 语法
torch.triu(input, diagonal=0)
🔹 参数
参数 | 说明 |
---|---|
input | 输入张量 |
diagonal=0 | 保留主对角线及其以上的元素 |
diagonal>0 | 移除 diagonal 行的主对角线元素 |
diagonal<0 | 保留主对角线以下 -diagonal 行 |
📌 示例
U = torch.triu(A)
print("A 的上三角矩阵(diagonal=0):")
print(U)
输出:
A 的上三角矩阵(diagonal=0):
tensor([[ 1, 2, 3, 4],[ 0, 6, 7, 8],[ 0, 0, 11, 12],[ 0, 0, 0, 16]])
💡 说明:主对角线上的元素及其上的元素保留,下面的元素变为
0
。
🔍 diagonal
参数
print(torch.triu(A, diagonal=1)) # 移除主对角线元素
print(torch.triu(A, diagonal=-1)) # 保留主对角线以下 1 行
输出:
A 的上三角矩阵(diagonal=1):
tensor([[ 0, 2, 3, 4],[ 0, 0, 7, 8],[ 0, 0, 0, 12],[ 0, 0, 0, 0]])A 的上三角矩阵(diagonal=-1):
tensor([[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 0, 10, 11, 12],[ 0, 0, 15, 16]])
🔺 diagonal=1:移除主对角线的元素,仅保留主对角线以上的元素。
🔻 diagonal=-1:允许保留主对角线以下1
行的元素。
3. torch.tril
vs torch.triu
对比
作用 | torch.tril(A) | torch.triu(A) |
---|---|---|
计算结果 | 取下三角部分 | 取上三角部分 |
对角线控制 | diagonal=0 保留主对角线 | diagonal=0 保留主对角线 |
diagonal>0 | 保留主对角线以上元素 | 移除主对角线部分元素 |
diagonal<0 | 移除主对角线部分元素 | 保留主对角线以下部分 |
总结
torch.tril()
取 下三角矩阵,可以用于 Cholesky 分解、Transformer Masking。torch.triu()
取 上三角矩阵,常用于 线性代数计算 和 矩阵变换。
🚀 你可以根据不同的需求选择合适的函数,在 PyTorch 中高效处理矩阵运算!