【PyTorch函数解析】einsum的用法示例

一、前言

einsum 是一个非常强大的函数,用于执行张量(Tensor)运算。它的名称来源于爱因斯坦求和约定(Einstein summation convention),在PyTorch中,einsum 可以方便地进行多维数组的操作和计算。

在Transfomer中,einsum用的非常多,比如使用 einsum 实现自注意力机制中注意力权重的获取,也就是Q和K的内积:

  • Q(Query):形状为 (batch_size, seq_len, d_k)

  • K(Key):形状为 (batch_size, seq_len, d_k)

import torch
import torch.nn.functional as FQ = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)
K = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)# (batch_size, seq_len, seq_len)
attention_scores = torch.einsum('bqd,bkd->bqk', Q, K) / torch.sqrt(torch.tensor(64.0))
# (batch_size, seq_len, seq_len)   
attention_weights = F.softmax(attention_scores, dim=-1)  

二、常见用法示例

2.1 向量点积

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result)  # 输出 32

这里,'i,i->' 表示对向量 a 和 b 进行点积操作,其中 i 是索引表示,-> 之后为空表示求和。

2.2 矩阵乘法

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ij,jk->ik', A, B)
print(result)  # 输出 tensor([[19, 22], [43, 50]])

这里,'ij,jk->ik' 表示矩阵乘法,其中 i 和 k 是结果的维度,j 是求和维度。

2.3 批量矩阵乘法

A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)
result = torch.einsum('bij,bjk->bik', A, B)

这里,'bij,bjk->bik' 表示对批量的矩阵进行乘法运算。

解释:

bij,bjk分别是A和B的3个维度,用字符串的形式指代。

为什么最后得到的是bik呢?这个和线性代数的矩阵运算规则有关系。

矩阵乘法规则:

  • 给定矩阵 A 的形状为 (m,n)

  • 给定矩阵 B 的形状为 (n,p)

  • 矩阵乘法 A×B 的结果矩阵 C 的形状为 (m,p)

在矩阵乘法中,结果矩阵的每个元素 Cik 是通过 A 的第 i 行和 B 的第 k 列的对应元素相乘并求和得到的,即:

C_{ik}=\sum_{j=1}^nA_{ij}\cdot B_{jk}

计算过程:

1. 匹配批次维度 (b)

  • 对于每个批次,独立进行矩阵乘法运算。

2. 求和维度 (j):

  • j 是两个张量中共同的维度,根据线性代数中的矩阵乘法规则,需要对 j 维度进行求和。

3. 保留和产生的维度:

  • i 来自 A,表示保留 A 的第一个维度。

  • k 来自 B,表示保留 B 的第二个维度。

经过上述分析,einsum 的结果保留了 b(批次维度)、i(来自 A 的第一个维度)和 k(来自 B 的第二个维度)。因此,结果张量的形状为 (batch_size, seq_len_i, seq_len_k),也就是 bik。

同样,延伸到4维计算的话。

torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

首先,假设 queries 和 keys 的形状为:

  • queries: (batch_size, seq_len_q, num_heads, head_dim)

  • keys: (batch_size, seq_len_k, num_heads, head_dim)

用具体变量名表示:

  • n: batch_size,批次大小。

  • q: seq_len_q,查询序列的长度。

  • k: seq_len_k,键序列的长度。

  • h: num_heads,多头注意力中的头数。

  • d: head_dim,每个头的维度。

1. 匹配批次维度 (n) 和头部维度 (h):

  • 批次大小和头部数量在两个输入张量中都是相同的,保持不变。

2. 求和维度 (d):

  • d 表示每个头的维度。在 queries 和 keys 中,d 都是最后一个维度,对这个维度进行点积运算后求和。

3. 保留和产生的维度:

  • q 来自 queries,表示查询序列的长度。

  • k 来自 keys,表示键序列的长度。

所以最后是nhqk。

2.4 转置操作

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
result = torch.einsum('ij->ji', A)
print(result)  # 输出 tensor([[1, 4], [2, 5], [3, 6]])

这里,'ij->ji' 表示将矩阵进行转置操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/35169.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DWC USB2.0协议学习1--产品概述

本章开始学习记录DWC_otg控制器(新思USB2.0)的特点、功能和应用。 新思USB 2.0 IP主要有两个文档需要参考: 《DesignWare Cores USB 2.0 Hi-Speed On-TheGo (OTG) Data book》 《DesignWare Cores USB 2.0 Hi-Speed On-TheGo (OTG) Progra…

解决IMX6ULL GPIO扩展板PWM7/8中的pwm0/period后卡死问题

前言 本篇文章主要是记录解决百问网论坛上面设置 IMX6ULL GPIO扩展板PWM7/8中的pwm0/period后卡死问题,如下图: 一、查看原理图,找出对应引脚 在这里我们如何确定哪个扩展口中的引脚输出PWM波呢?我们可以通过查看原理图。 查看…

作业6.20

1.已知网址www.hqyj.com截取出网址的每一个部分(要求,该网址不能存入文件中) 2.将配置桥接网络的过程整理成文档,发csdn 步骤i:在虚拟机设置中启用桥接模式 1. 打开VMware虚拟机软件。 2. 选择您想要配置的虚拟机,点击菜单栏中的“…

C++ 基础:指针和引用浅谈

指针 基本概念 在C中,指针是存储其他变量的内存地址的变量。 我们在程序中声明的每个变量在内存中都有一个关联的位置,我们称之为变量的内存地址。 如果我们的程序中有一个变量 var,那么&var 返回它的内存地址。 int main() {int var…

北大医院副院长李建平:用AI解决临床心肌缺血预测的难点、卡点和痛点

2024年6月14日,第六届北京智源大会在中关村展示中心开幕,海内外的专家学者围绕人工智能关键技术路径和应用场景,展开了精彩演讲与尖峰对话。在「智慧医疗和生物系统:影像、功能与仿真」论坛上,北京大学第一医院副院长、…

java复习宝典,jdbc与mysql数据库

一.java 1.面向对象知识 (1)类和对象 类:若干具有相同属性和行为的对象的群体或者抽象,类是创建对象的模板,由属性和行为两部分组成。 类是对象的概括或者抽象,对象是类的实例化。 举例:例如车有很多类型&#xf…

计算机系统基础知识(下)

嵌入式系统以及软件 嵌入式系统是为了特定应用而专门构建且将信息处理过程和物理过程紧密结合为一体的专用计算机系统,这个系统目前以涵盖军事,自动化,医疗,通信,工业控制,交通运输等各个应用领域&#xff…

【Matlab 六自由度机器人】机器人动力学之推导拉格朗日方程(附MATLAB机器人动力学拉格朗日方程推导代码)

【Matlab 六自由度机器人】机器人动力学概述 近期更新前言正文一、拉格朗日方程的推导1. 单自由度系统2. 单连杆机械臂系统3. 双连杆机械臂系统 二、MATLAB实例推导1. 机器人模型的建立2. 动力学代码 总结参考文献 近期更新 【汇总】 【Matlab 六自由度机器人】系列文章汇总 …

JVM专题十:JVM中的垃圾回收机制

在JVM专题九:JVM分代知识点梳理中,我们主要介绍了JVM为什么采用分代算法,以及相关的概念,本篇我们将详细拆分各个算法。 垃圾回收的概念 垃圾回收(Garbage Collection,GC)确实是计算机编程中的…

【自然语言处理系列】探索NLP:使用Spacy进行分词、分句、词性标注和命名实体识别,并以《傲慢与偏见》与全球恐怖活动两个实例文本进行分析

本文深入探讨了scaPy库在文本分析和数据可视化方面的应用。首先,我们通过简单的文本处理任务,如分词和分句,来展示scaPy的基本功能。接着,我们利用scaPy的命名实体识别和词性标注功能,分析了Jane Austen的经典小说《傲…

discuz插件之优雅草超级列表互动增强v1.2版本更新

https://doc.youyacao.com/9/2142 v1.2更新 discuz插件之优雅草超级列表互动增强v1.2版本更新 [title]20220617 v1.2发布[/title] 增加了对php8的支持 增加了 对discuz3.5的支持

RocketMQ源码学习笔记:Broker启动流程

这是本人学习的总结,主要学习资料如下 马士兵教育rocketMq官方文档 目录 1、Broker启动流程2、一些重要的类2.1、MappedFile2.2、MessgeStore2.3、MessageStore的加载启动流程 3、技术亮点3.1、 内存映射3.1.1、简介3.1.2、源码 1、Broker启动流程 Broker启动流程…

RabbitMQ中lazyqueue队列

lazyqueue队列非常强悍 springboot注解方式开启 // 使用注解的方式lazy.queue队列模式 非常GoodRabbitListener(queuesToDeclare Queue(name "lazy.queue",durable "true",arguments Argument(name "x-queue-mode",value "lazy&…

3.蓝牙模块HC-08

目录 一.简介​编辑 二.主要参数 三.模块引脚说明 四、LED指示灯状态 五.AT指令 5.1AT指令重点 5.2 AT指令注意点 5.3 AT指令集 六.AT常用指令 6.1 测试指令 AT 6.2 查询当前参数ATRX 6.3设置主从模式 ATROLE 6.4设置蓝牙模式 ATNAME 6.5 设置波特率 …

YOLOv5改进(八)--引入Soft-NMS非极大值抑制

文章目录 1、前言2、各类NMS代码实现2.1、general.py 3、各类NMS实现3.1、Soft-NMS3.2、GIoU-NMS3.3、DIoU-NMS3.4、CIoU-NMS3.5、EIoU-NMS 4、目标检测系列文章 1、前言 目前yolov5使用的是NMS进行极大值抑制,本篇文章是要将各类NMS添加到yolov5中,同时…

6.25作业

1.整理思维导图 2.终端输入两个数,判断两数是否相等,如果不相等,判断大小关系 #!/bin/bash read num1 read num2 if [ $num1 -eq $num2 ] then echo num1num2 elif [ $num1 -gt $num2 ] then echo "num1>num2" else echo &quo…

200.回溯算法:子集||(力扣)

class Solution { public:vector<int> res; // 当前子集vector<vector<int>> result; // 存储所有子集void backtracing(vector<int>& nums, int index, vector<bool>& used) {result.push_back(res); // 将当前…

【嵌入式Linux】<总览> 进程间通信(更新中)

文章目录 前言 一、管道 1. 概念 2. 匿名管道 3. 有名管道 二、内存映射区 1. 概念 2. mmap函数 3. 进程间通信&#xff08;有血缘关系&#xff09; 4. 进程间通信&#xff08;没有血缘关系&#xff09; 5. 拷贝文件 前言 在文章【嵌入式Linux】&#xff1c;总览&a…

浏览器断点调试(用图说话)

浏览器断点调试&#xff08;用图说话&#xff09; 1、开发者工具2、添加断点3、查看变量值 浏览器断点调试 有时候我们需要在浏览器中查看 html页面的js中的变量值。1、开发者工具 打开浏览器的开发者工具 按F12 &#xff0c;没反应的话按FnF12 2、添加断点 3、查看变量值

清理占道经营商贩自砸西瓜?智慧城管AI视频方案助力城市街道管理

一、背景分析 近日有新闻报道&#xff0c;在山西太原&#xff0c;城管凌晨3时许查处商贩占道经营&#xff0c;商贩将西瓜砸碎一地&#xff0c;引起热议。据悉&#xff0c;事件发生的五龙口街系当地主要街道&#xff0c;来往车辆众多。该商贩长期在该地段占道经营&#xff0c;影…