Tensor轴变换 axis 或 dim(transpose、permute、view、reshape、einsum)

操作分类:

  1. 重排维度transpose、swapaxes、permute都是对维度进行重排序,但不改变维度的大小。

  2. 重组维度view、reshape可以重组原始维度,修改维度大小。

  3. 万能运算einsum 通过操作index(dim/axis)匹配对应的矩阵运算

    • dim 与 axis
    • transpose 重排维度
    • permute 重排维度
    • view 重组维度
    • reshape 重组维度
    • einsum 万能运算

dim 与 axis

Tensor的 dim维度axis轴 变换 是 Pytorch深度学习最重要的操作之一(在torch中叫dim多一些,在numpy中叫axis多一些),这些操作不改变内存中的物理存储,只会改变tensor的视图view,即以什么样的顺序或维度来看待这个tensor,越靠后的维度在内存上越相连每个维度都有具体的物理含义。可以通过tensor.shape来查看一个张量的维度。

如加载图像数据后,[32, 3, 64,64]可以理解为[batch_size, channel, hight, weight],如self-attention中[16, 8, 32, 128]可以理解为[batch_szie, heads, seq_len, head_dim]

tensor的dim索引从下标0开始,如shape为[10, 3, 64, 64]的tensor,其dim的取值范围是0,1,2,3

如下例子:

import torch
tensor = torch.randn(10, 3, 64, 64).to("cuda")
tensor.shape  # torch.Size([10, 3, 64, 64])
  • tensor[i]等价于tensor[i, :, :]tensor[i]的shape为[3, 64, 64];
  • tensor[i, j]等价于tensor[i, j, :]tensor[i, j]的shape为[64, 64].

transpose 重排维度

  • 使用方法torch.tanspose(tensor, dim1, dim2)交换 tensor 的 dim1 和 dim2 这两个维度
import torch
tensor = torch.randn(16, 8, 32, 128).to("cuda")
# torch.Size([16, 8, 32, 128])
trans = torch.transpose(tensor, 2, 3).contiguous()
# torch.Size([16, 8, 128, 32])

另外,swapaxes就是tanspose的别名!torch.swapaxes(tensor, dim1, dim2),效果等于上面的tanspose。

permute 重排维度

  • 使用方法transpose和swapaxes只能交换两个维度dim,而permute可以对所有轴进行重排torch.permute(dim1, dim2, dim3...)dim_i是原始维度的索引,将其放到新的位置,就是交换旧维度到新索引位置。
import torch
tensor = torch.randn(16, 8, 32, 128).to("cuda")
# torch.Size([16, 8, 32, 128])
tensor = tensor.permute(0, 2, 1, 3)  # 交换1,2维度
# torch.Size([16, 32, 8, 128])

view 重组维度

  • 使用方法tensor.contiguous().view(dim0, dim1, dim2...) ,将tensor的shape变换为(dim0, dim1, dim2...)dim的个数可以少于或多于原来tensor!,因为所有维度的累积 ∏ i = 0 N d i m i \prod_{i=0}^N{dim_i} i=0Ndimi是不变的,因此当有一个dim=-1时,将自动计算。
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])  (batch_size, heads, seq_len, head_dim)
tensor = tensor.contiguous().view(16, 32, -1)  # 合头heads
# torch.Size([16, 32, 128])  (batch_size, seq_len, dim)
  • contiguous:因为transpose和permute这些操作不改变内存中的物理存储,而torch要求 越靠后的维度在内存上越相连,所以按照新维度索引,tensor在内存中不再是连续存储的,但view操作要求tensor的内存连续存储,需要用tensor.contiguous() 将原始的tensor调整为一个内存连续的tensor。在pytorch 0.4中,增加了torch.reshape()操作,大致相当于 tensor.contiguous().view(),这样就省去了对tensor做view()变换前,调用contiguous()的麻烦;因此建议所有情况都无脑使用 reshape

reshape 重组维度

  • 使用方法tonsor.reshape()tensor.contiguous().view()tensor.reshape(dim0, dim1, dim2...) ,将tensor的shape变换为(dim0, dim1, dim2...)dim的个数可以少于或多于原来tensor!,因为所有维度的累积 ∏ i = 0 N d i m i \prod_{i=0}^N{dim_i} i=0Ndimi是不变的,因此当有一个dim=-1时,将自动计算。
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])  (batch_size, heads, seq_len, head_dim)
tensor = tensor.reshape(16, 32, -1)  # 合头heads
# torch.Size([16, 32, 128])  (batch_size, seq_len, dim)

einsum 万能运算

  • 使用方法:爱因斯坦表达式通过操作index(dim/axis)匹配对应的矩阵运算。和前面几个操作不同的是,torch.einsum不仅可以进行单个矩阵维度的重排、重组,还可以完成多个矩阵的矩阵加法矩阵乘法元素乘法等运算

->左侧表示输入的矩阵shape,->右侧表示输出的矩阵shape

  • permute 重排:单个输入矩阵,->左右维度数量不变,只改变顺序,如交换i和j维度,ij->ji
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])
tensor = torch.einsum("bhsd->bhds", tensor)
# torch.Size([16, 8, 16, 32])
  • sum求和:单个输入矩阵,->右侧缺少哪些维度,就按照哪些维度求和,如按照j维度求和,ij->i
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])
tensor = torch.einsum("bhsd->bh", tensor)
# torch.Size([16, 8])
  • matrix multi 矩阵乘法->左边多个输入矩阵逗号分隔,->左边是单个矩阵,沿左边两者重复出现右边消失的维度进行乘法,如沿k维度进行矩阵乘法,ij,jk->ik
tensor1 = torch.randn(2, 3).to("cuda")
tensor2 = torch.randn(3, 5).to("cuda")
tensor = torch.einsum("ij, jk -> ik", tensor1, tensor2)
# (2,3) @ (3,5) = (2,5)

组合操作:先沿着j维度进行矩阵乘法,再沿着k维度进行求和:

tensor1 = torch.randn(2, 3).to("cuda")
tensor2 = torch.randn(3, 5).to("cuda")
tensor = torch.einsum("ij, jk -> i", tensor1, tensor2)
# (2,3) @ (3,5) = (2,5)

更加复杂的组合操作:模拟attention score,先自动进行转置,然后最后两个维度进行矩阵乘法,其中虽然都有seq_len,但因为output输出矩阵中不能出现两个相同的字母,所以不能都用s命名,因此使用i和j

import torch
# key 和 value 都是[batch_size, heads, seq_len, head_dim]
query = torch.randn(16, 8, 32, 16).to("cuda")
key = torch.randn(16, 8, 32, 16).to("cuda")
attention_score = torch.einsum("bhid, bhjd -> bhij", query, key)  # bhid, bhjd -> bhid, bhdj -> bhij
# torch.Size([16, 8, 32, 32])# 等价操作
attention_score = query @ key.transpose(-2, -1)
attention_score = torch.matmul(query, key.transpose(-2, -1))
  • element-wise multi 元素乘法->左边多个相同shape的矩阵,->右边单个和做左边相同shape的矩阵。矩阵对应元素相乘,也叫hadamard product
import torch
tensor1 = torch.randn(16, 8, 32, 16).to("cuda")
tensor2 = torch.randn(16, 8, 32, 16).to("cuda")tensor = torch.einsum("bhsd,bhsd->bhsd", tensor1, tensor2)
# torch.Size([16, 8, 32, 16])# 等价操作
tensor = tensor1 * tensor2
  • dot product 矩阵点积->左边多个相同shape的矩阵,->是空的(求和sum)。即,先逐元素相乘,然后全部求和
import torch
tensor1 = torch.randn(16, 8, 32, 16).to("cuda")
tensor2 = torch.randn(16, 8, 32, 16).to("cuda")tensor = torch.einsum("bhsd,bhsd-> ", tensor1, tensor2)
# tensor是一个值# 等价操作
tensor = sum(tensor1 * tensor2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/581479.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot2.7-集成Knife4j

Knife4j 是什么 Knife4j是一个集Swagger2 和 OpenAPI3为一体的增强解决方案 添加依赖 <!--引入Knife4j的官方start包,该指南选择Spring Boot版本<3.0,开发者需要注意--> <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knif…

Django Cookie和Session使用(十一)

一、Cookie Cookie具体指一小段信息&#xff0c;它是服务器发送出来存储在浏览器上的一组键值对&#xff0c;下次访问服务器时浏览器会自动携带这些键值对&#xff0c;以便服务器提取有用信息。 Cookie的特性 1、服务器让浏览器进行设置的 2、保存在浏览器本地&#xff0c;…

Permutation Importance重要性

目录 <font colorblue size4 face"楷体">算法解构<font colorblue size4 face"楷体">代码 算法解构 Permutation Importance适用于表格型数据&#xff0c;其对于特征重要性的评判取决于该特征被随机重排后&#xff0c;模型表现评分的下降程度…

免费API-JSONPlaceholder使用手册

官方使用指南快速索引>>点这里 快速导览&#xff1a; 什么是JSONPlaceholder?有啥用?如何使用JSONPlaceholder? 关于“增”关于“改”关于“查”关于“删”关于“分页查”关于“根据ID查多个” 尝试自己搭一个&#xff1f;扩展的可能&#xff1f; 什么是JSONPlaceho…

面向对象(高级)知识点强势总结!!!

文章目录 一、知识点复习1-关键字&#xff1a;static1、知识点2、重点 2-单例模式&#xff08;或单子模式&#xff09;1、知识点2、重点 3-理解main()方法1、知识点2、重点 4-类的成员之四&#xff1a;代码块1、知识点2、重点 5-关键字&#xff1a;final1、知识点2、重点 6-关键…

新建虚拟环境并与Jupyter内核连接

第一步:在cmd里新建虚拟环境,shap38是新建的虚拟环境的名字 ,python=3.x conda create -n shap38 python=3.8第二步,安装ipykernel,打开anconda powershell prompt: 虚拟环境的文件夹位置,我的如图所示: 进入文件夹并复制地址: 输入复制的文件夹地址更改文件夹:…

单挑力扣(LeetCode)SQL题:534. 游戏玩法分析 III(难度:中等)

题目&#xff1a;534. 游戏玩法分析 III &#xff08;通过次数23,825 | 提交次数34,947&#xff0c;通过率68.17%&#xff09; Table:Activity----------------------- | Column Name | Type | ----------------------- | player_id | int | | device_id | int…

交换域系数的选择:图像处理与编码的关键策略

在图像处理和编码领域&#xff0c;选择适当的交换域系数对于实现高效的图像处理和编码至关重要。交换域系数是指在特定的数学变换下产生的频域系数。通过选择合适的交换域系数&#xff0c;可以实现图像的压缩、增强和重构。本文将深入探讨交换域系数的选择在图像处理和编码中的…

你好!Apache Seata

北京时间 2023 年 10 月 29 日&#xff0c;分布式事务开源项目 Seata 正式通过 Apache 基金会的投票决议&#xff0c;以全票通过的优秀表现正式成为 Apache 孵化器项目&#xff01; 根据 Apache 基金会邮件列表显示&#xff0c;在包含 13 个约束性投票 (binding votes) 和 6 个…

Qt学习:Qt的意义安装Qt

Qt 的简介 QT 是一个跨平台的 C图形用户界面应用程序框架。它为程序开发者提供图形界面所需的所有功能。它是完全面向对象的&#xff0c;很容易扩展&#xff0c;并且允许真正地组件编程。 支持平台 xP 、 Vista、Win7、win8、win2008、win10Windows . Unix/Linux: Ubuntu 等…

【ARMv8M Cortex-M33 系列 2.1 -- Cortex-M33 使用 .hex 文件介绍】

文章目录 HEX 文件介绍英特尔十六进制文件格式记录类型hex 示例Cortex-M 系列hex 文件的使用 HEX 文件介绍 .hex 文件通常用于微控制器编程&#xff0c;包括 ARM Cortex-M 系列微控制器。这种文件格式是一种文本记录&#xff0c;用于在编程时传递二进制信息。.hex 文件格式最常…

docker学习笔记02-安装mysql

1.安装mysql8 下载MySQL镜像 docker pull mysql:8.0创建并启动容器 docker run -itd --name mysqltest -p 9999:3306 -e MYSQL_ROOT_PASSWORD123456 mysql其中-it是交互界面 -d是后台执行 -name 指定容器名称 -p指定映射端口 -e设置环境变量 最后mysql是镜像名或者用镜像id如…

Flask 日志

flask 日志 代码源码源自编程浪子flask点餐小程序代码 记录用户访问日志 和 错误日志 这段代码是一个基于Flask框架的日志服务类&#xff0c;用于 记录用户访问日志 和 错误日志。代码中定义了一个名为LogService的类&#xff0c;其中包含了两个静态方法&#xff1a;addAcc…

QT C++ TCP Socket 请求心知天气

0.0 相关连接代码部分头文件具体实现 相关连接 心知天气官方天气图标 心知天气官网 代码部分 头文件 #include <QtNetwork> #include <QNetworkAccessManager> #include <QDebug> #include <QJsonValue> #include <QJsonArray> #include &l…

单挑力扣(LeetCode)SQL题:1285. 找到连续区间的开始和结束数字(难度:中等)

给题目&#xff1a;1285. 找到连续区间的开始和结束数字 &#xff08;通过次数8,111 | 提交次数9,900&#xff0c;通过率81.93%&#xff09; 表&#xff1a;Logs ------------------------ | Column Name | Type | ------------------------ | log_id | int …

XUbuntu22.04之删除多余虚拟网卡和虚拟网桥(二百零四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

在 docker 中安装 GLEE

1、安装 detectron2。 2、git clone GLEE。 git clone https://github.com/FoundationVision/GLEE.git docker 中没有 git&#xff0c;可以通过共享主机文件夹&#xff0c;在主机中做 git clone。 3、删除 GLEE/app/requirements.txt 中的 torch 和 torchvision&#xff0c…

用Xshell连接虚拟机的Ubuntu20.04系统记录。虚拟机Ubuntu无法上网。本机能ping通虚拟机,反之不能。互ping不通

先别急着操作&#xff0c;看完再试。 如果是&#xff1a;本机能ping通虚拟机&#xff0c;反之不能。慢慢看到第8条。 如果是&#xff1a;虚拟机不能上网&#xff08;互ping不通&#xff09;&#xff0c;往下一直看。 系统是刚装的&#xff0c;安装步骤&#xff1a;VMware虚拟机…

【Linux】 last 命令使用

last 命令 用于检索和展示系统中用户的登录信息。它从/var/log/wtmp文件中读取记录&#xff0c;并将登录信息按时间顺序列出。 著者 Miquel van Smoorenburg 语法 last [-R] [-num] [ -n num ] [-adiox] [ -f file ] [name...] [tty...]last 命令 -Linux手册页 选项及作用…

vue项目表单使用正则过滤ip、手机号

import useFormValidate from /hooks/useFormValidatesetup(props, { emit }) {const { validateName, validateIPAndPort } useFormValidate()const state reactive({workFaceInfo: props.info?.id ? props.info : {},sysTypeData: props.sysType,formRules: {name: [{req…