文章目录
- torch.nn.Flatten
- torch.flatten()
官方API文档:点击跳转。torch.nn.Flatten
是Pytorch提供的类,常用于将输入数据进行展平,而torch.flatten()
函数与之功能相同。
torch.nn.Flatten
类初始化方式:
torch.nn.Flatten(start_dim=1, end_dim=-1)
start_dim=1
:开始维度,默认为1.end_dim=1
:结束维度,默认为-1.
torch.nn.Flatten
对象常用于参与构造神经网络,对神经网络模型的输出进行处理,将连续的维度范围展平为张量。例:
# 定义输入数据
input = torch.randn(32, 1, 5, 5)# 创建Flatten对象
flatten = nn.Flatten()
output = flatten(input)# 输出结果
print('output.shape=',output.shape)
从dim=0
展开直至dim=3
,有关系: [ 32 , 1 , 5 , 5 ] − > [ 32 , 1 ∗ 5 ∗ 5 ] = [ 32 , 25 ] [32,1,5,5]->[32,1*5*5]=[32,25] [32,1,5,5]−>[32,1∗5∗5]=[32,25]。这是因为在模型中,dim=0
通常表示数据的batch_size
,故需要将一个个数据拉平为一维。而若从dim=0
开始拉平,则所有数据被混淆在一起而存在一个一维向量中。
torch.flatten()
函数声明如下:
torch.flatten(start_dim = 0, end_dim = -1)
用于指定维度拉平数据,作用与torch.nn.Flatten
类对象相同,但默认从dim=0
开始拉平。例:
import torch
x=torch.randn(2,4,2)# 默认start_dim=0,end_dim=-1,将Tensor拉为一维。
z=torch.flatten(x)
print(z.shape)# [16]=[2*4*2]# 模拟torch.nn.Flatten功能
w=torch.flatten(x,1)
print(w.shape)# [2,8]=[2,4*2]