在PyTorch中view()
, transpose()
和 permute()
函数都是用于改变张量(Tensor)维度结构的,但它们的作用和使用场景有所不同。
- torch.view()
- 功能:该函数用于将一个张量重塑为新的形状,但它必须保持原有元素数量不变。它主要用于改变张量的维度布局,而不仅仅是交换维度。
- 用法:通常用于简化或展开张量的维度,例如将三维张量展平成一维或二维。
import torchbatch = 3
seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (batch, seq_size, embed)).float()
print(x)
# tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.],
# [24., 4., 11., 21., 16., 21., 12., 24.]],
#
# [[14., 1., 13., 5., 0., 16., 5., 22.],
# [ 9., 2., 21., 6., 15., 1., 16., 15.]],
#
# [[23., 4., 4., 16., 1., 18., 0., 20.],
# [ 9., 1., 1., 7., 13., 21., 12., 12.]]])# 将后两维度张量展平,将每一词的词嵌入按行连接
z = x.view(batch, -1)
print(z)
# tensor([[7., 14., 2., 10., 5., 17., 11., 7., 24., 4., 11., 21., 16., 21., 12., 24.],
# [14., 1., 13., 5., 0., 16., 5., 22., 9., 2., 21., 6., 15., 1., 16., 15.],
# [23., 4., 4., 16., 1., 18., 0., 20., 9., 1., 1., 7., 13., 21., 12., 12.]])# transformer中多头注意力机制常用,把最后一维词嵌入的维度进行两次切割
# 切割出来多余的那部分做为batch放在第一个维度上
y = x.view(batch * 2, -1, embed // 2)
print(y)
# tensor([[[ 7., 14., 2., 10.],
# [ 5., 17., 11., 7.]],
#
# [[24., 4., 11., 21.],
# [16., 21., 12., 24.]],
#
# [[14., 1., 13., 5.],
# [ 0., 16., 5., 22.]],
#
# [[ 9., 2., 21., 6.],
# [15., 1., 16., 15.]],
#
# [[23., 4., 4., 16.],
# [ 1., 18., 0., 20.]],
#
# [[ 9., 1., 1., 7.],
# [13., 21., 12., 12.]]])
- torch.transpose()
- 功能:该函数用于交换两个指定的维度(转置),其中给定轴上的元素被互换。
- 用法:传入两个指定的参数维度且参数循序无关。
import torchbatch = 3
seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (batch, seq_size, embed)).float()
print(x)
# tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.],
# [24., 4., 11., 21., 16., 21., 12., 24.]],
#
# [[14., 1., 13., 5., 0., 16., 5., 22.],
# [ 9., 2., 21., 6., 15., 1., 16., 15.]],
#
# [[23., 4., 4., 16., 1., 18., 0., 20.],
# [ 9., 1., 1., 7., 13., 21., 12., 12.]]])# 将后两维交换(转置),将每一词的词嵌入按列展示
z = x.transpose(1, 2) # 等价 x.transpose(2, 1)
print(z)
# tensor([[[ 7., 24.],
# [14., 4.],
# [ 2., 11.],
# [10., 21.],
# [ 5., 16.],
# [17., 21.],
# [11., 12.],
# [ 7., 24.]],
#
# [[14., 9.],
# [ 1., 2.],
# [13., 21.],
# [ 5., 6.],
# [ 0., 15.],
# [16., 1.],
# [ 5., 16.],
# [22., 15.]],
#
# [[23., 9.],
# [ 4., 1.],
# [ 4., 1.],
# [16., 7.],
# [ 1., 13.],
# [18., 21.],
# [ 0., 12.],
# [20., 12.]]])
- torch.permute()
- 功能:该函数允许一次性重新排列多个维度,理解成transpose的扩展。
- 用法:传入张量的所有维度,可以同时交换任意两个及以上的维度。
import torchbatch = 3
seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (batch, seq_size, embed)).float()
print(x)
# tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.],
# [24., 4., 11., 21., 16., 21., 12., 24.]],
#
# [[14., 1., 13., 5., 0., 16., 5., 22.],
# [ 9., 2., 21., 6., 15., 1., 16., 15.]],
#
# [[23., 4., 4., 16., 1., 18., 0., 20.],
# [ 9., 1., 1., 7., 13., 21., 12., 12.]]])# 将后两维重新排序
# 注意这样是报错x.permute(2, 1)或者permute(1, 2, 1)都是非法的
z = x.permute(0, 2, 1) # 等价 x.transpose(2, 1),
# print(z)# 如果我们想要三个维度都交换transpose是做不到的
# 至于有什么实际意义就不讨论了
y = x.permute(2, 1, 0)
print(y)
# tensor([[[ 7., 14., 23.],
# [24., 9., 9.]],
#
# [[14., 1., 4.],
# [ 4., 2., 1.]],
#
# [[ 2., 13., 4.],
# [11., 21., 1.]],
#
# [[10., 5., 16.],
# [21., 6., 7.]],
#
# [[ 5., 0., 1.],
# [16., 15., 13.]],
#
# [[17., 16., 18.],
# [21., 1., 21.]],
#
# [[11., 5., 0.],
# [12., 16., 12.]],
#
# [[ 7., 22., 20.],
# [24., 15., 12.]]])
- torch.unsqueeze()
- 功能:增加一个新的维度。
- 用法:增加维度指定的位置。
import torchseq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (seq_size, embed)).float()
print(x)
# tensor([[ 7., 14., 2., 10., 5., 17., 11., 7.],
# [24., 4., 11., 21., 16., 21., 12., 24.]])z = x.unsqueeze(0) # 等价 torch.unsqueeze(x, dim=0)
print(z)
# tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.],
# [24., 4., 11., 21., 16., 21., 12., 24.]]])
总结:
-
view()
更侧重于保持数据不变的前提下改变张量的维度形状,常用于展平、重塑等操作。 -
transpose()
是特定的维度交换操作,只涉及两个维度的变换。 -
permute()
则提供了更灵活的维度重排功能,可以处理多维度情况下的整体维度顺序调整。 -
unsqueeze()
指定位置增加张量维度。