【Pytorch】各种维度变换函数总结

维度变换千万不要混着用,尤其是交换维度的transpose和更改观察视角的view或者reshape!混用了以后虽然不会报错,但是数据是乱的, 建议用einops中的rearrange,符合人的直观,不容易出错。
一个例子:

>>>t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
>>> t.transpose(0,1)	# 交换t的前两个维度,即对t进行转置。
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>> a.reshape(4,3)     # 使用reshape()/view()的方法,虽然形状一样,但是数据排列完全不同
tensor([[ 0,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])>>> from einops import rearrange
>>> rearrange(t, 'r c -> c r')
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])

下面是转载的这篇文章:PyTorch各种维度变换函数总结,侵删

介绍

本文对于PyTorch中的各种维度变换的函数进行总结,包括reshape()view()resize_()transpose()permute()squeeze()unsqeeze()expand()repeat()函数的介绍和对比。

contiguous

区分各个维度转换函数的前提是需要了解contiguous。在PyTorch中,contiguous指的是Tensor底层一维数组的存储顺序和其元素顺序一致

Tensor是以一维数组的形式存储的,C/C++使用行优先(按行展开)的方式,Python中的Tensor底层实现使用的是C,因此PyThon中的Tensor也是按行展开存储的,如果其存储顺序按行优先展开的一维数组元素顺序一致,就说这个Tensor是连续(contiguous)的。

形式化定义:

对于任意的 d d d 维张量 t t t ,如果满足对于所有的 i i i ,第 i i i 维相邻元素间隔=第 i + 1 i+1 i+1 维相邻元素间隔 × \times × i + 1 i+1 i+1 维长度的乘积,则 t t t 是连续的:
stride  [ i ] = stride  [ i + 1 ] × size ⁡ [ i + 1 ] , ∀ i = 0 , … , d − 1 ( i ≠ d − 1 ) \text { stride }[i]=\text { stride }[i+1] \times \operatorname{size}[i+1], \forall i=0, \ldots, d-1(i \neq d-1)  stride [i]= stride [i+1]×size[i+1],i=0,,d1(i=d1)

  • stride [ i ] [i] [i] 表示第 i i i 维相邻元素之间间隔的位数,称为步长,可通过 stride () 方法获得。
  • size [ i ] [i] [i] 表示固定其他维度时,第 i i i 维的元素数量,即第 i i i 维的长度,通过 size () 方法获得。

Python中的多维张量按照行优先展开的方式存储,访问矩阵中下一个元素是通过偏移来实现的,这个偏移量称为步长(stride),比如python中,访问 2 × 3 2 \times 3 2×3 矩阵的同一行中的相邻元素,物理结构需要偏移 1 个位置,即步长为 1 ,同一列中的两个相邻元素则步长为 3 。

举例说明:

>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
>>>t.stride(),t.stride(0),t.stride(1) # 返回t两个维度的步长,第0维的步长,第1维的步长
((4,1),4,1)
# 第0维的步长,表示沿着列的两个相邻元素,比如‘0’和‘4’两个元素的步长为4
>>>t.size(1)
4
# 对于i=0,满足stride[0]=stride[1] * size[1]=1*4=4,那么t是连续的。

PyTorch提供了两个关于contiguous的方法:

is_contiguous() : 判断Tensor是否是连续的
contiguous() : 返回新的Tensor,重新开辟一块内存,并且是连续的

举例说明(参考[1]):

>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
>>>t2 = t.transpose(0,1)
>>>t2
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>>t.data_ptr() == t2.data_ptr()  # 返回两个张量的首元素的内存地址
True    	#说明底层数据是同一个一维数组
>>>t.is_contiguous(),t2.is_contiguous()  # t连续,t2不连续
(True, False)

可以看到,t和t2共享内存中的数据。如果对t2使用contiguous()方法,会开辟新的内存空间:

>>>t3 = t2.contiguous()
>>>t3
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>>t3.data_ptr() == t2.data_ptr() # 底层数据不是同一个一维数组
False
>>>t3.is_contiguous()
True

关于contiguous的更深入的解释可以参考[1].

view()/reshape()

view()

tensor.view()函数返回一个和tensor共享底层数据,但不同形状的tensor。使用view()函数的要求是tensor必须是contiguous的

用法如下:

>>>t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
>>>t2 = t.view(2,6)
>>>t2
tensor([[ 0,  1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10, 11]])
>>>t.data_ptr() == t2.data_ptr()	# 二者的底层数据是同一个一维数组
True

reshape()

tensor.reshape()类似于tensor.contigous().view()操作,如果tensor是连续的,则reshape()操作和view()相同,返回指定形状、共享底层数据的tensor;如果tensor是不连续的,则会开辟新的内存空间,返回指定形状的tensor,底层数据和原来的tensor是独立的,相当于先执行contigous(),再执行view()

如果不在意底层数据是否使用新的内存,建议使用reshape()代替view().

resize_()

tensor.resize_()函数,返回指定形状的tensor,与reshape()view()不同的是,resize_()可以只截取tensor一部分数据,或者是元素个数大于原tensor也可以,会自动扩展新的位置。

resize_()函数对于tensor的连续性无要求,且返回的值是共享的底层数据(同view()),也就是说只返回了指定形状的索引,底层数据不变的。

transpose()/permute()

permute()transpose()还有t()是PyTorch中的转置函数,其中t()函数只适用于2维矩阵的转置,是这三个函数里面最”弱”的。

transpose()

tensor.transpose(),返回tensor的指定维度的转置,底层数据共享,与view()/reshape()不同的是,transpose()只能实现维度上的转置,不能任意改变维度大小。

对于维度交换来说,view()/reshape()transpose()有很大的区别,一定不要混用!混用了以后虽然不会报错,但是数据是乱的,血坑。

reshape()/view()transpose()的区别在于对于维度改变的方式不同,前者是在存储顺序的基础上对维度进行划分,也就是说将存储的一维数组根据shape大小重新划分,而transpose()则是真正意义上的转置,比如二维矩阵的转置。

举个例子:

>>>t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
>>> t.transpose(0,1)	# 交换t的前两个维度,即对t进行转置。
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>> a.reshape(4,3)     # 使用reshape()/view()的方法,虽然形状一样,但是数据排列完全不同
tensor([[ 0,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

permute()

tensor.permute()函数,以view的形式返回矩阵指定维度的转置,和transpose()功能相同。

transpose()不同的是,permute()同时对多个维度进行转置,且参数是期望的维度的顺序,而transpose()只能同时对两个维度转置,即参数只能是两个,这两个参数没有顺序,只代表了哪两个维度进行转置。

举个例子:

>>> t				# t的形状为(2,3,2)
tensor([[[ 0,  1],[ 2,  3],[ 4,  5]],[[ 6,  7],[ 8,  9],[10, 11]]])
>>> t.transpose(0,1)   # 使用transpose()将前两个维度进行转置,返回(3,2,2)
tensor([[[ 0,  1],[ 6,  7]],[[ 2,  3],[ 8,  9]],[[ 4,  5],[10, 11]]])
>>> t.permute(1,0,2)   # 使用permute()按照指定的维度序列对t转置,返回(3,2,2)
tensor([[[ 0,  1],[ 6,  7]],[[ 2,  3],[ 8,  9]],[[ 4,  5],[10, 11]]])

squeeze()/unsqueeze()

squeeze()

tensor.squeeze()返回去除size为1的维度的tensor,默认去除所有size=1的维度,也可以指定去除某一个size=1的维度,并返回去除后的结果。

举个例子:

>>> t.shape 
torch.Size([3, 1, 4, 1])
>>> t.squeeze().shape  # 去除所有size=1的维度
torch.Size([3, 4])
>>> t.squeeze(1).shape  # 去除第1维
torch.Size([3, 4, 1])
>>> t.squeeze(0).shape  # 如果指定的维度size不等于1,则不执行任何操作。
torch.Size([3, 1, 4, 1])

unsqueeze()

tensor.unsqueeze()与squeeze()相反,是在tensor插入新的维度,插入的维度size=1,用于维度扩展。

举个例子:

>>> t.shape
torch.Size([3, 1, 4, 1])
>>> t.unsqueeze(1).shape   # 在指定的位置上插入新的维度,size=1
torch.Size([3, 1, 1, 4, 1]) 
>>> t.unsqueeze(-1).shape  # 参数为-1时表示在最后一维添加新的维度,size=1
torch.Size([3, 1, 4, 1, 1])
>>> t.unsqueeze(4).shape   # 和dim=-1等价
torch.Size([3, 1, 4, 1, 1])

expand()/repeat()

expand()

tensor.expand()的功能是扩展tensor中的size为1的维度,且只能扩展size=1的维度。以view的形式返回tensor,即不改变原来的tensor,只是以视图的形式返回数据。

举个例子:

>>> t
tensor([[[0, 1, 2],[3, 4, 5]]])
>>> t.shape
torch.Size([1, 2, 3])
>>> t.expand(3,2,3)  # 将第0维扩展为3,可见其将第0维复制了3次
tensor([[[0, 1, 2],[3, 4, 5]],[[0, 1, 2],[3, 4, 5]],[[0, 1, 2],[3, 4, 5]]])
>>> t.expand(3,-1,-1) # dim=-1表示固定这个维度,效果是一样的,这样写更方便
tensor([[[0, 1, 2],[3, 4, 5]],[[0, 1, 2],[3, 4, 5]],[[0, 1, 2],[3, 4, 5]]])
>>> t.expand(3,2,3).storage()    # expand不扩展新的内存空间012345
[torch.LongStorage of size 6]

repeat()

tensor.repeat()用于维度复制,可以将size为任意大小的维度复制为n倍,和expand()不同的是,repeat()会分配新的存储空间,是真正的复制数据。

举个例子:

>>> t
tensor([[0, 1, 2],[3, 4, 5]])
>>> t.shape
torch.Size([2, 3])
>>> t.repeat(2,3)  # 将两个维度分别复制2、3倍
tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2],[3, 4, 5, 3, 4, 5, 3, 4, 5],[0, 1, 2, 0, 1, 2, 0, 1, 2],[3, 4, 5, 3, 4, 5, 3, 4, 5]])
>>> t.repeat(2,3).storage()   # repeat()是真正的复制,会分配新的空间012012012345......345
[torch.LongStorage of size 36]

如果维度size=1的时候,repeat()expand()的作用是一样的,但是expand()不会分配新的内存,所以优先使用expand()函数。

总结

  1. view()/reshape()两个函数用于将tensor变换为任意形状,本质是将所有的元素重新分配
  2. t()/transpose()/permute()用于维度的转置,转置和reshape()操作是有区别的,注意区分。
  3. squeeze()/unsqueeze()用于压缩/扩展维度,仅在维度的个数上去除/添加,且去除/添加的维度size=1。
  4. expand()/repeat()用于数据的复制,对一个或多个维度上的数据进行复制。
  5. 以上提到的函数仅有两种会分配新的内存空间:reshape()操作处理非连续的tensor时,返回tensor的copy数据会分配新的内存;repeat()操作会分配新的内存空间。其余的操作都是返回的视图,底层数据是共享的,仅在索引上重新分配。

Reference

1. PyTorch中的contiguous

2. stackoverflow-pytorch-contiguous

3. PyTorch官方文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/700670.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何提高测试人员在公司的影响力?“小题大做”少不了!

背景:一件不太合理却又常见的小诉求 某天,一个新入职的运营人员在公司聊天软件里找到我,希望我团队的QA同学能在test环境给他介绍一下当前业务APP里的功能操作,提及了多个模块。 我把聊天内容转给我团队中的小H看,小H的反应是“这活儿也太杂了吧,连运营人员想了解业务…

Ubuntu20.04开启/禁用ipv6

文章目录 Ubuntu20.04开启/禁用ipv61.ipv62. 开启ipv6step1. 编辑sysctl.confstep2. 编辑网络接口配置文件 3. 禁用ipv6(sysctl)4. 禁用ipv6(grub)附:总结linux网络配置 Ubuntu20.04开启/禁用ipv6 1.ipv6 IP 是互联网…

C/C++ PrefixToNetmask、NetmaskToPrefix(网络掩码、Prefix 互转)

作用: Netmask:255.255.255.0 互转 Prefix:24 新实现: static int NetmaskToPrefix(unsigned char* bytes, int bytes_size) noexcept {if (NULL bytes || bytes_…

openGauss学习笔记-227 openGauss性能调优-系统调优-其他因素对LLVM性能的影响

文章目录 openGauss学习笔记-227 openGauss性能调优-系统调优-其他因素对LLVM性能的影响 openGauss学习笔记-227 openGauss性能调优-系统调优-其他因素对LLVM性能的影响 LLVM优化效果不仅依赖于数据库内部具体的实现,还与当前所选择的硬件环境等有关。 表达式调用C…

视频监控集中管理的好处

在视频监控建设的进程中,许多单位采用分期、分批的方式进行,不同的部门、项目所采纳的监控系统和平台各有不同。面对这样一个复杂的状况,大型企业经常需应对多平台、多品牌录像机和监控设备的挑战。这无疑给后续的应用管理带来了不小的困扰。…

2.23C语言学习

P1480 A/B Problem 高精度数除以非高精度数 #include<bits/stdc.h> long long b[66660],c[66660],sum0; char a[66660]; int n; int main(){scanf("%s",a);scanf("%d",&n);int lenstrlen(a);for(int i1;i<len;i){b[i]a[i-1]-0;}for(int i1;…

Qt之Qstring元素访问

和之前讲述的访问QByteArray类对象中某个元素的方式类似&#xff0c;访问QString 类对象方式的某个元素采用类似的4种主要方式&#xff0c;分别为[、at&#xff08;&#xff09;、data[]和 constData[]。其中&#xff0c;[]和data[]方式为可读可写&#xff0c;at&#xff08;&a…

【力扣白嫖日记】176.第二高的薪水

前言 练习sql语句&#xff0c;所有题目来自于力扣&#xff08;https://leetcode.cn/problemset/database/&#xff09;的免费数据库练习题。终于把所有的简单题刷完&#xff0c;进入第一道中等题。 今日题目&#xff1a; 176.第二高的薪水 表&#xff1a;Employee 列名类型…

CSS 的圆角矩形

CSS 的圆角矩形 通过 border-radius 属性使矩形边框带圆角效果成为圆角矩形 语法&#xff1a;border-radius: length; length 是内切圆的半径&#xff0c;其数值越大, 弧线越明显 border-radius 属性值描述length定义圆角的形状%以百分比定义圆角的形状 生成圆形 让 border-…

村镇医院医疗中心污废水如何处理达标

污废水处理是村镇医院医疗中心运营中不可忽视的重要环节。如何有效处理污废水&#xff0c;使其达到相关标准&#xff0c;是保障医疗中心环境卫生的关键之一。 首先&#xff0c;村镇医院医疗中心应建立科学的废水处理系统。该系统应包括预处理、初级处理、中级处理和高级处理等环…

整型数组按个位值排序/最低位排序(C语言)

题目来自于博主算法大师的专栏&#xff1a;最新华为OD机试C卷AB卷OJ&#xff08;CJavaJSPy&#xff09; https://blog.csdn.net/banxia_frontend/category_12225173.html 题目描述 给定一个非空数组&#xff08;列表&#xff09;&#xff0c;其元素数据类型为整型&#xff0c…

JVM(1)

JVM简介 JVM是Java Virtual Machine的简称,意为Java虚拟机. 在java中,它归属于jre(java运行时环境), 而jre归属于jdk(java开发工具包). 虚拟机是指通过软件模拟的具有完整硬件功能的,运行在一个完全隔离的环境中的完整计算机系统. 常见的虚拟机:JVM, VMwave, VirtualBox. J…

2024 Impeller:快速了解 Flutter 的渲染引擎的优势

参考原文 &#xff1a;https://tomicriedel.medium.com/understanding-impeller-a-deep-dive-into-flutters-rendering-engine-ba96db0c9614 最近&#xff0c;在 Flutter 2024 路线规划里明确提出了&#xff0c;今年 Flutter Team 将计划删除 iOS 上的 Skia 的支持&#xff0c;…

python 打包 apk

转换之前python代码需要使用指定的框架才能转换&#xff0c;列如&#xff1a;kivy from kivy.app import App from kivy.uix.boxlayout import BoxLayout from kivy.uix.button import Buttonimport time import pyautogui import threadingstatus False# 这是一个将被线程执…

七种设计原则

1.开闭原则&#xff1a;&#xff08;面向对象编程中&#xff0c;最核心最基础的一个原则&#xff0c;所有设计模式都是围绕这一原则去实践&#xff09;对原有的类不做修改&#xff0c;只做扩展 2.单一职责&#xff1a;说的是类的职责要单子。也就是说一个类最好只负责一方面的…

踩坑:SpringBoot连接Mysql的时区报错

解决方法&#xff1a;1.修改时区2.修改连接版本 目录 1.修改时区 2.切换版本 1.修改时区 查看mysql的默认时区 SELECT global.time_zone AS Global Time Zone, session.time_zone AS Session Time Zone; 查看mysqk的默认是时区返回两个结果 Global Time Zone:表示Mysql…

【数据结构】C语言实现二叉树的相关操作

树 定义 树&#xff08;Tree&#xff09;是 n (n > 0) 个结点的有限集 若 n 0&#xff0c;称为空树 若 n > 0&#xff0c;则它满足如下两个条件&#xff1a; 有且仅有一个特定的称为根&#xff08;Root&#xff09;的结点其余结点可分为 m(m>0) 个互不相交的有限…

剪辑视频调色怎么让画质变得清晰 视频剪辑调色技巧有哪些方面 剪辑视频免费的软件有哪些 会声会影调色在哪里 会声会影模板素材

视频调色的作用有很多&#xff0c;除了进行风格化剪辑以外&#xff0c;还可以让作品的画质变得清晰。通过调色来增强画面的清晰度&#xff0c;在观感上也会显得十分自然。视频调色的技巧有很多&#xff0c;并且原理大都十分简单。有关剪辑视频调色怎么让画质变得清晰&#xff0…

Mybatis总结--传参二

#叫做占位符 Mybatis是封装的JDBC 增强版 内部还是用的jdbc 每遇到一个#号 这里就会变为&#xff1f;占位符 一个#{}就是对应一个问号 一个占位符 用这个对象执行sql语句没有sql注入的风险 八、多个参数-使用Param 当 Dao 接口方法有多个参数&#xff0c;需要通过名称使…

Vue3_基础使用_4_路由器Router

概念&#xff1a; 路由&#xff1a;是一个key-value的对应关系叫路由。 路由器&#xff1a;管理多个路由的集合或者叫设备称为路由器。 由于现在组件替代了以前的mvc中的cshtml, 组件的菜单切换也不用我手动去写&#xff0c;vue给我们通过配置完成。 实现简单的路由跳转&…