input.unsqueeze(0)
是 PyTorch 张量(Tensor)的方法之一,用于增加张量的维度。具体来说,它会在索引为 0 的位置上插入一个维度。
假设 input
是一个形状为 (n,)
的一维张量,其中 n
是任意长度。调用 unsqueeze(0)
后,它会返回一个形状为 (1, n)
的二维张量,新插入的维度的大小为 1。
以下是一个示例:
import torchinput = torch.tensor([1, 2, 3, 4])# 调用 unsqueeze(0) 增加维度
output = input.unsqueeze(0)print(input.shape) # 输出: torch.Size([4])
print(output.shape) # 输出: torch.Size([1, 4])
在上述示例中,input
是一个长度为 4 的一维张量。通过 unsqueeze(0)
将其转换为一个形状为 (1, 4)
的二维张量 output
。新插入的维度位于索引 0 的位置。
unsqueeze(0)
的应用场景通常是在需要对张量进行运算或与其他张量进行操作时,需要调整张量的维度匹配。例如,将一维张量作为输入传递给大小为 (batch_size, ...)
的神经网络,就通常需要在维度上插入一个批次大小的维度。
需要注意的是,unsqueeze(0)
并不会在原地修改输入张量,而是返回一个新的张量。因此,我们在示例中将结果赋值给 output
,以便进行打印输出。