首先,假设我们有一个三行四列的张量 X:
view() 和 reshape() 函数都可以指定并改变张量的维度,它们本质上是相同的,只有两点区别:
1、view() 函数返回的是原始张量的视图,而 reshape() 函数返回的是原始张量的视图或者副本;
2、正因为 view() 函数返回的是视图,它要求张量数据在内存中是连续contiguous 的,而 reshape() 函数则没有这样的要求,它可以对连续张量或者非连续张量进行操作。
假设张量 Y 比 X 多了一个长度等于 1 的维度,如下:
长度为 1 的维度实际上是没有意义的,可以使用 squeeze() 自动消除所有长度为 1 的维度:
与之相反的,就是 unsqueeze() 函数,可以增加一个维度(长度等于 1):
最后,flatten() 函数的作用是展平张量,它还可以指定开始的维度和结束的维度,默认 start_dim=0, end_dim=-1
,例如 Pytorch 中维度顺序为 NCHW,如果 N(batch_size) 不等于 1,就得让 start_dim = 1,只展平每个样本的 CHW,保持每个样本的独立性。
使用 view(-1) 也能实现和 flatten() 一样的功能,它们区别实际上就是 view() 函数和 reshape() 函数的区别,因为 flatten() 就是用 reshape() 实现的,看下源码即可知道。