view()
x.view()
是 PyTorch 中用于改变张量形状的方法之一,它允许你在保持张量元素总数不变的情况下,重新组织张量的维度和大小。
view()
方法的用法如下:
x.view(*shape)
其中 x
是要进行形状变换的张量,shape
是一个整数或整数元组,用于指定目标形状。*
表示可以接受不定数量的参数。
view()
方法的关键是指定目标形状 shape
。目标形状应该满足以下两个条件:
- 目标形状的元素总数必须与原始张量的元素总数相同,以保持数据完整性。
- 目标形状中的维度可以是具体的整数值,也可以是
-1
,表示该维度的大小由其他维度的大小推断得出。
下面是一些使用 view()
方法的示例:
- 将一个张量的形状从
(2, 3, 4)
调整为(6, 4)
:
x = torch.randn(2, 3, 4)
y = x.view(6, 4)
- 将一个张量的形状从
(1, 10)
调整为(5, 2)
:
x = torch.randn(1, 10)
y = x.view(5, 2)
- 将一个张量的形状从
(3, 4, 5)
调整为(3, -1)
,其中-1
表示自动计算该维度的大小:
x = torch.randn(3, 4, 5)
y = x.view(3, -1)
需要注意的是,view()
方法并不改变张量的数据内容,它只是改变了张量的形状。因此,改变形状后的张量与原始张量共享相同的存储空间,即它们指向同一块内存。这意味着对其中一个张量的修改也会影响到另一个张量。
此外,有时候在进行形状变换时,可能会遇到不能满足形状变换要求的情况,例如无法将一个张量的元素总数与目标形状的元素总数匹配。在这种情况下,会抛出一个错误。因此,确保目标形状与原始张量的元素总数匹配是很重要的。
总结起来,x.view()
方法是用于改变张量形状的函数,它允许你重新组织张量的维度和大小。通过指定目标形状,你可以调整张量的形状以满足特定的需求。
业务示例代码:
# forward中定义数据流def forward(self, x):x = self.conv1(x)x = self.block2(x)x = self.block3(x)x = self.avgpool(x) # 此时这里是(样本量,256,1,1) NCHW# 进入线性层,必须将数据从四维数据变成一个二维数据# 此时需要将x给拉平x = x.view(x.shape[0], 256)x = self.fc(x)