einops和einsum:直接操作张量的利器

einops和einsum:直接操作张量的利器

einops和einsum是Vision Transformer的代码实现里出现的两个操作tensor维度和指定tensor计算的神器,在卷积神经网络里不多见,本文将介绍简单介绍一下这两样工具,方便大家更好地理解Vision Transformer的代码。

einops:直接操作tensor维度的神器

github地址:https://github.com/arogozhnikov/einops

einops:灵活和强大的张量操作,可读性强和可靠性好的代码。支持numpy、pytorch、tensorflow等。

有了他,研究者们可以自如地操作张量的维度,使得研究者们能够简单便捷地实现并验证自己的想法,在Vision Transformer等需要频繁操作张量维度的代码实现里极其有用。

这里简单地介绍几个最常用的函数。

安装

einops的安装非常简单,直接pip即可:

pip install einops

rearrange

import torch
from einops import rearrangei_tensor = torch.randn(16, 3, 224, 224)		# 在CV中很常见的四维tensor: (N,C,H,W)
print(i_tensor.shape)
o_tensor = rearrange(i_tensor, 'n c h w -> n h w c')
print(o_tensor.shape)

输出:

torch.Size([16, 3, 224, 224])
torch.Size([16, 224, 224, 3])

在CV中很常见的四维tensor:(N,C,H,W),即表示(批尺寸,通道数,图像高,图像宽),在Vision Transformer中,经常需要对tensor的维度进行变换操作,rearrange函数可以很方便地、很直观地操作tensor的各个维度。

除此之外,rearrange还有稍微进阶一点的玩法:

 
i_tensor = torch.randn(16, 3, 224, 224)
o_tensor = rearrange(i_tensor, 'n c h w -> n c (h w)')
print(o_tensor.shape)  
o_tensor = rearrange(i_tensor, 'n c (m1 p1) (m2 p2) -> n c m1 p1 m2 p2', p1=16, p2=16)
print(o_tensor.shape)  

输出:

torch.Size([16, 3, 50176])
torch.Size([16, 3, 14, 16, 14, 16])

可以进行指定维度的合并和拆分,注意拆分时需要在变换规则后面指定参数。

repeat

from einops import repeati_tensor = torch.randn(3, 224, 224)  
print(i_tensor.shape)
o_tensor = repeat(i_tensor, 'c h w -> n c h w', n=16)  
print(o_tensor.shape)

repeat时记得指定右侧repeat之后的维度值

输出:

torch.Size([3, 224, 224])
torch.Size([16, 3, 224, 224])

reduce

from einops import reducei_tensor = torch.randn((16, 3, 224, 224))
o_tensor = reduce(i_tensor, 'n c h w -> c h w', 'mean')
print(o_tensor.shape)
o_tensor_ = reduce(i_tensor, 'b c (m1 p1) (m2 p2)  -> b c m1 m2 ', 'mean', p1=16, p2=16)
print(o_tensor_.shape)

输出:

torch.Size([3, 224, 224])
torch.Size([16, 3, 14, 14])

reduce时记得指定左侧要被reduce的维度值

Rearrange

import torch
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrangemodel = Sequential(Conv2d(3, 64, kernel_size=3),MaxPool2d(kernel_size=2),Rearrange('b c h w -> b (c h w)'),      # 相当于 flatten 展平的作用Linear(64*15*15, 120), ReLU(),Linear(120, 10)
)i_tensor = torch.randn(16, 3, 32, 32)
o_tensor = model(i_tensor)
print(o_tensor.shape)

输出:

torch.Size([16, 10])

einops.layers.torch.Rearrange 是nn.Module的子类,可以放在网络里面直接当作一层。

torch.einsum:爱因斯坦简记法

爱因斯坦简记法:是一种由爱因斯坦提出的,对向量、矩阵、张量的求和运算 ∑\sum求和简记法

在该简记法当中,省略掉的部分是:

  1. 求和符号 ∑\sum
  2. 求和号的下标 iii

省略规则为:默认成对出现的下标(如下例1中的 iii 和例2中的 kkk )为求和下标,被省略。

1)xiyix_iy_ixiyi简化表示内积 <x,y><\mathbf{x},\mathbf{y}><x,y>
xiyi:=∑ixiyi=ox_iy_i := \sum_i x_iy_i = o xiyi:=ixiyi=o

其中o为输出。

  1. XikYkjX_{ik}Y_{kj}XikYkj 简化表示矩阵乘法 XY\mathbf{X}\mathbf{Y}XY
    XikYkj:=∑kXikYkj=OijX_{ik}Y_{kj}:=\sum_k X_{ik}Y_{kj}=\mathbf{O}_{ij} XikYkj:=kXikYkj=Oij
    其中 Oij\mathbf{O}_{ij}Oij 为输出矩阵的第ij个元素。

这样的求和简记法,能够以一种统一的方式表示各种各样的张量运算(内积、外积、转置、点乘、矩阵的迹、其他自定义运算),为不同运算的实现提供了一个统一模型。

einsum在numpy和pytorch中都有实现,下面我们以在torch中为例,展示一下最简单的用法

import torchi_a = torch.randn(16, 32, 4, 8)
i_b = torch.randn(16, 32, 8, 16)out = torch.einsum('b h i j, b h j d -> b h i d', i_a, i_b)
print(out.shape)

输出:

torch.Size([16, 32, 4, 16])

可以看到,torch.einsum可以简便地指定tensor运算,输入的两个tensor维度分别为 bhijb\ h\ i\ jb h i jbhjdb\ h\ j\ db h j d ,经过tensor运算后,得到的张量维度为 bhidb\ h\ i\ db h i d 。代码运行结果与我们的预期一致。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532858.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

php的filter input,php中filter_input函数用法分析

本文实例分析了php中filter_input函数用法。分享给大家供大家参考。具体分析如下&#xff1a;在 php5.2 中,内置了filter 模块,用于变量的验证和过滤,过滤变量等操作&#xff0c;这里我们看下如何直接过滤用户输入的内容.fliter 模块对应的 filter_input 函数使用起来非常的简单…

COCO 数据集格式及mmdetection中的转换方法

COCO 数据集格式及mmdetection中的转换方法 COCO格式 CV中的目标检测任务不同于分类&#xff0c;其标签的形式稍为复杂&#xff0c;有几种常用检测数据集格式&#xff0c;本文将简要介绍最为常见的COCO数据集的格式。 完整的官方样例可自行查阅&#xff0c;以下是几项关键的…

php获取h1,jQuery获取h1-h6标题元素值方法实例

本文主要介绍了jQuery实现获取h1-h6标题元素值的方法,涉及$(":header")选择器操作h1-h6元素及事件响应相关技巧,需要的朋友可以参考下&#xff0c;希望能帮助到大家。1、问题背景&#xff1a;查找到h1-h6&#xff0c;并遍历它们&#xff0c;打印出内容2、实现代码&am…

在导入NVIDIA的apex库时报错 ImportError cannot import name ‘UnencryptedCookieSessionFactoryConfig‘ from

在导入NVIDIA的apex库时报错 ImportError: cannot import name ‘UnencryptedCookieSessionFactoryConfig’ from ‘pyramid.session’ (unknown location) 报错 在使用NVIDIA的apex库时报错 ImportError: cannot import name ‘UnencryptedCookieSessionFactoryConfig’ fro…

php怎么取request,PHP-如何在Guzzle中获取Request对象?

我需要使用Guzzle检查数据库中的很多项目.例如,项目数量为2000-5000.将其全部加载到单个数组中太多了,因此我想将其分成多个块&#xff1a;SELECT * FROM items LIMIT100.当最后一个项目发送到Guzzle时,则请求下一个100个项目.在“已满”处理程序中,我应该知道哪个项目得到了响…

[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析

[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析 论文&#xff1a;https://arxiv.org/abs/2104.00323 代码&#xff1a;https://github.com/dvlab-research/JigsawClustering 总结 本文提出了一种单批次&#xff0…

java jps都卡死,java长时间运行后,jps失效

在部署完应用后&#xff0c;原本jps使用的好好的&#xff0c;能正确的查询到自己正在运行的java程序。但&#xff0c;过了一段时间后&#xff0c;再使用jps来查看运行的应用时&#xff0c;自己运行的程序都看不到&#xff0c;但是自己也没有关闭这些程序啊&#xff01;然而使用…

指针(*)、取地址()、解引用(*)与引用()

指针(*)、取地址(&)、解引用(*)与引用(&) C 提供了两种指针运算符&#xff0c;一种是取地址运算符 &&#xff0c;一种是间接寻址运算符 *。 指针是一个包含了另一个变量地址的变量&#xff0c;您可以把一个包含了另一个变量地址的变量说成是"指向"另一…

matlab电类,985电气研二,有发过考研经验贴 电气电力类的有

该楼层疑似违规已被系统折叠 隐藏此楼查看此楼clc;clear;p[2.259;2.257;2.256;2.254;2.252;2.248;2.247;2.245;2.244;2.243;2.239;2.238;2.236;2.235;2.234;2.231;2.229;2.228;2.226;2.225;2.221;2.220;2.219;2.217;2.216;2.211;2.209;2.208;2.207;2.206;2.202;2.201;2.199;2.1…

matlab legend 分块,matlab legend 分块!

matlab legend 分块&#xff01;(2013-03-26 18:07:38)%%%压差clc;clear all;figure(55);set (gcf,Position,[116 123 275 210],color,w);P[25 26 27 28 29 30 31 32 33 34 35];%理论q0.00006*pi*28*P*10^(6)*0.03^3/(12*0.028448*5);q1110.00006*pi*28*P*10^(6)*0.03^3/(12*0.…

利用opencv-python绘制多边形框或(半透明)区域填充(可用于分割任务mask可视化)

利用opencv-python绘制多边形框或&#xff08;半透明&#xff09;区域填充&#xff08;可用于分割任务mask可视化&#xff09; 本文主要就少opencv中两个函数polylines和fillPoly分别用于绘制多边形框或区域填充&#xff0c;并会会以常见用途分割任务mask&#xff08;还是笔者…

matlab与maple互联,Matlab,Maple和Mathematica三款主流科学计算软件的互操作

本文根据网上零散的信息以及这三款软件自带的说明文档整理而成&#xff0c;为备忘而记录。记录了Matlab和Maple之间的相互调用&#xff0c;以及Matlab和Mathematica之间相互调用的安装配置方法。为何需要互操作&#xff1f; 数值计算和图形方面Matlab毫无疑问是最强的&a…

PyTorch中的topk方法以及分类Top-K准确率的实现

PyTorch中的topk方法以及分类Top-K准确率的实现 Top-K 准确率 在分类任务中的类别数很多时&#xff08;如ImageNet中1000类&#xff09;&#xff0c;通常任务是比较困难的&#xff0c;有时模型虽然不能准确地将ground truth作为最高概率预测出来&#xff0c;但通过学习&#…

java高级语言特性,Java高级语言特性之注解

注解的定义Java 注解(Annotation)又称 Java 标注&#xff0c;是 JDK1.5 引入的一种注释机制。注解是元数据的一种形式&#xff0c;提供有关于程序但不属于程序本身的数据。注解对它们注解的代码的操作没有直接影响。注解本身没有任何意义&#xff0c;单独的注解就是一种注释&am…

C/C++中的typedef 和 #define

C/C中的typedef 和 #define typedef C/C中的关键字typedef允许用户为类型名来起一个新名字&#xff0c;通常会是缩写或者能够清晰表明类型含义的新名字。 例&#xff1a; typedef unsigned int UINT; UINT 100;值得注意的是&#xff0c;typedef除了为C/C内置的数据类型取别…

php3.2.3 升级,thinkphp3.2.3 升级到3.2.4时出错问题

有些项目最初用OneThink做的&#xff0c;而OneThink 默认使用的TP 是3.2.0 的&#xff0c;没事的时候就想给升级一下&#xff0c;但是直接复制进去的时候&#xff0c;有错误&#xff0c;导致OneThink 不能运行&#xff0c;排查后&#xff0c;需要修改两个地方1、修改 Applicati…

Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 1

Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来&#xff0c;屠杀了各大CV榜单。对其做各种改进的顶会论文也是层出不穷&#xff0c;本文将聚焦于各种最新的视觉trans…

mysql 分析查询语句,MySQL教程之SQL语句分析查询优化

怎么获取有功能问题的SQL1、经过用户反应获取存在功能问题的SQL2、经过慢查询日志获取功能问题的SQL3、实时获取存在功能问题的SQL运用慢查询日志获取有功能问题的SQL首要介绍下慢查询相关的参数1、slow_query_log 发动定制记载慢查询日志设置的办法&#xff0c;能够经过MySQL指…

关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性。利用它&#xff0c;我们可以不必改变网络输入输出的结构&#xff0c;方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 f…

geoda权重矩阵导入matlab,空间计量经济学-分析解析.ppt

厦门大学 邓明 空间截面回归模型 地理加权回归模型 地理加权回归模型扩展了普通线性回归模型。在GWR模型中&#xff0c;特定区位的回归系数不再是利用全部信息获得的假定常数&#xff0c;而是利用邻近观测值的子样本数据信息进行局域(Local)回归估计而得&#xff0c;并随着空间…