目录
unbind拆分子张量
1. 沿着第n个维度拆分(即按“批次”拆分)
split分割张量
常用用法:
总结:
unbind拆分子张量
import torchquaternions = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
result = torch.unbind(quaternions, -1)
print(result)
1. 沿着第n个维度拆分(即按“批次”拆分)
假设你有一个形状为 (batch_size, n)
的张量,你可以沿着第一个维度(即批次维度)拆分它。
split分割张量
返回一个元组,其中包含分割后的子张量。
常用用法:
-
按指定大小分割: 当
split_size_or_sections
为一个整数时,表示每个子张量的大小。import torch# 创建一个张量 tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])# 按照大小 3 分割 result = torch.split(tensor, 3)# 输出分割后的结果 for i, part in enumerate(result):print(f"Part {i}: {part}")
这个例子中,张量被分割成了 3 个大小为 3 的子张量和一个大小为 1 的子张量。
-
按指定分割长度分割: 当
split_size_or_sections
是一个列表或元组时,表示沿指定维度分割的块数。每个数值对应要分割的子张量的大小。
# 创建一个张量
tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])# 按照每个块的大小为 2 和 3 分割
result = torch.split(tensor, [2, 3, 5])# 输出分割后的结果
for i, part in enumerate(result):print(f"Part {i}: {part}")
指定维度进行分割: 可以通过 dim
参数指定沿哪个维度进行分割。
# 创建一个二维张量
tensor = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8]])# 沿着维度 1(列)分割,每个子张量包含 2 列
result = torch.split(tensor, 2, dim=1)# 输出分割后的结果
for i, part in enumerate(result):print(f"Part {i}: {part}")
总结:
torch.split
是一个非常实用的工具,能够根据指定的大小或者分割长度将张量分割成多个子张量。常用的应用场景包括:
-
将数据按批次(batch)分割。
-
在处理大张量时按一定的块进行分割,方便并行计算或逐块处理。