【PyTorch函数解析】einsum的用法示例

一、前言

einsum 是一个非常强大的函数,用于执行张量(Tensor)运算。它的名称来源于爱因斯坦求和约定(Einstein summation convention),在PyTorch中,einsum 可以方便地进行多维数组的操作和计算。

在Transfomer中,einsum用的非常多,比如使用 einsum 实现自注意力机制中注意力权重的获取,也就是Q和K的内积:

  • Q(Query):形状为 (batch_size, seq_len, d_k)

  • K(Key):形状为 (batch_size, seq_len, d_k)

import torch
import torch.nn.functional as FQ = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)
K = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)# (batch_size, seq_len, seq_len)
attention_scores = torch.einsum('bqd,bkd->bqk', Q, K) / torch.sqrt(torch.tensor(64.0))
# (batch_size, seq_len, seq_len)   
attention_weights = F.softmax(attention_scores, dim=-1)  

二、常见用法示例

2.1 向量点积

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result)  # 输出 32

这里,'i,i->' 表示对向量 a 和 b 进行点积操作,其中 i 是索引表示,-> 之后为空表示求和。

2.2 矩阵乘法

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ij,jk->ik', A, B)
print(result)  # 输出 tensor([[19, 22], [43, 50]])

这里,'ij,jk->ik' 表示矩阵乘法,其中 i 和 k 是结果的维度,j 是求和维度。

2.3 批量矩阵乘法

A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)
result = torch.einsum('bij,bjk->bik', A, B)

这里,'bij,bjk->bik' 表示对批量的矩阵进行乘法运算。

解释:

bij,bjk分别是A和B的3个维度,用字符串的形式指代。

为什么最后得到的是bik呢?这个和线性代数的矩阵运算规则有关系。

矩阵乘法规则:

  • 给定矩阵 A 的形状为 (m,n)

  • 给定矩阵 B 的形状为 (n,p)

  • 矩阵乘法 A×B 的结果矩阵 C 的形状为 (m,p)

在矩阵乘法中,结果矩阵的每个元素 Cik 是通过 A 的第 i 行和 B 的第 k 列的对应元素相乘并求和得到的,即:

C_{ik}=\sum_{j=1}^nA_{ij}\cdot B_{jk}

计算过程:

1. 匹配批次维度 (b)

  • 对于每个批次,独立进行矩阵乘法运算。

2. 求和维度 (j):

  • j 是两个张量中共同的维度,根据线性代数中的矩阵乘法规则,需要对 j 维度进行求和。

3. 保留和产生的维度:

  • i 来自 A,表示保留 A 的第一个维度。

  • k 来自 B,表示保留 B 的第二个维度。

经过上述分析,einsum 的结果保留了 b(批次维度)、i(来自 A 的第一个维度)和 k(来自 B 的第二个维度)。因此,结果张量的形状为 (batch_size, seq_len_i, seq_len_k),也就是 bik。

同样,延伸到4维计算的话。

torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

首先,假设 queries 和 keys 的形状为:

  • queries: (batch_size, seq_len_q, num_heads, head_dim)

  • keys: (batch_size, seq_len_k, num_heads, head_dim)

用具体变量名表示:

  • n: batch_size,批次大小。

  • q: seq_len_q,查询序列的长度。

  • k: seq_len_k,键序列的长度。

  • h: num_heads,多头注意力中的头数。

  • d: head_dim,每个头的维度。

1. 匹配批次维度 (n) 和头部维度 (h):

  • 批次大小和头部数量在两个输入张量中都是相同的,保持不变。

2. 求和维度 (d):

  • d 表示每个头的维度。在 queries 和 keys 中,d 都是最后一个维度,对这个维度进行点积运算后求和。

3. 保留和产生的维度:

  • q 来自 queries,表示查询序列的长度。

  • k 来自 keys,表示键序列的长度。

所以最后是nhqk。

2.4 转置操作

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
result = torch.einsum('ij->ji', A)
print(result)  # 输出 tensor([[1, 4], [2, 5], [3, 6]])

这里,'ij->ji' 表示将矩阵进行转置操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/35169.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL中服务器状态变量全解(一)

MySQL 服务器维护了许多状态变量,这些变量提供了关于其操作的信息。您可以使用 SHOW [GLOBAL | SESSION] STATUS 语句来查看这些变量及其值。 GLOBAL 关键字(可选)用于显示所有连接的聚合值。这些值通常表示自MySQL服务器启动以来的累计统计…

DWC USB2.0协议学习1--产品概述

本章开始学习记录DWC_otg控制器(新思USB2.0)的特点、功能和应用。 新思USB 2.0 IP主要有两个文档需要参考: 《DesignWare Cores USB 2.0 Hi-Speed On-TheGo (OTG) Data book》 《DesignWare Cores USB 2.0 Hi-Speed On-TheGo (OTG) Progra…

Leetcode85 01矩阵中的最大矩形

题目描述 给定一个仅包含 0 和 1 、大小为 rows x cols 的二维二进制矩阵,找出只包含 1 的最大矩形,并返回其面积。 解题思路 动态规划的思想,记录每一个位置向上能到达的最大高度,和向左能到达的最大宽度。 在一个点进行遍历时…

解决IMX6ULL GPIO扩展板PWM7/8中的pwm0/period后卡死问题

前言 本篇文章主要是记录解决百问网论坛上面设置 IMX6ULL GPIO扩展板PWM7/8中的pwm0/period后卡死问题,如下图: 一、查看原理图,找出对应引脚 在这里我们如何确定哪个扩展口中的引脚输出PWM波呢?我们可以通过查看原理图。 查看…

作业6.20

1.已知网址www.hqyj.com截取出网址的每一个部分(要求,该网址不能存入文件中) 2.将配置桥接网络的过程整理成文档,发csdn 步骤i:在虚拟机设置中启用桥接模式 1. 打开VMware虚拟机软件。 2. 选择您想要配置的虚拟机,点击菜单栏中的“…

C++ 基础:指针和引用浅谈

指针 基本概念 在C中,指针是存储其他变量的内存地址的变量。 我们在程序中声明的每个变量在内存中都有一个关联的位置,我们称之为变量的内存地址。 如果我们的程序中有一个变量 var,那么&var 返回它的内存地址。 int main() {int var…

北大医院副院长李建平:用AI解决临床心肌缺血预测的难点、卡点和痛点

2024年6月14日,第六届北京智源大会在中关村展示中心开幕,海内外的专家学者围绕人工智能关键技术路径和应用场景,展开了精彩演讲与尖峰对话。在「智慧医疗和生物系统:影像、功能与仿真」论坛上,北京大学第一医院副院长、…

孩子不想上学,父母应如何教育?“强迫教育”会激起孩子反抗心理

上周末朋友聚会,都是家有上学娃的年纪,闲聊中,话题自然少不了孩子的上学问题。其中,不少朋友都有抱怨过同一个问题:孩子不想上学,即使人到了学校,心也不在学校。   事实上,孩子出现…

java复习宝典,jdbc与mysql数据库

一.java 1.面向对象知识 (1)类和对象 类:若干具有相同属性和行为的对象的群体或者抽象,类是创建对象的模板,由属性和行为两部分组成。 类是对象的概括或者抽象,对象是类的实例化。 举例:例如车有很多类型&#xf…

安卓ROM修改默认开启adb调试

要在安卓ROM中修改默认开启ADB调试,你可以遵循以下步骤。请注意,这些步骤可能因设备型号、ROM版本和具体定制方式而有所不同,但基本流程是相似的: 准备工作: 确保你有一台可root的安卓手机,并且已经解锁了…

计算机系统基础知识(下)

嵌入式系统以及软件 嵌入式系统是为了特定应用而专门构建且将信息处理过程和物理过程紧密结合为一体的专用计算机系统,这个系统目前以涵盖军事,自动化,医疗,通信,工业控制,交通运输等各个应用领域&#xff…

鸿蒙开发Ability Kit(程序框架服务):【Stage模型绑定FA模型ServiceAbility】

Stage模型绑定FA模型ServiceAbility 本小节介绍Stage模型的两种应用组件如何绑定FA模型ServiceAbility组件。 UIAbility关联访问ServiceAbility UIAbility关联访问ServiceAbility和UIAbility关联访问ServiceExtensionAbility的方式完全相同。 import { common, Want } from…

【Matlab 六自由度机器人】机器人动力学之推导拉格朗日方程(附MATLAB机器人动力学拉格朗日方程推导代码)

【Matlab 六自由度机器人】机器人动力学概述 近期更新前言正文一、拉格朗日方程的推导1. 单自由度系统2. 单连杆机械臂系统3. 双连杆机械臂系统 二、MATLAB实例推导1. 机器人模型的建立2. 动力学代码 总结参考文献 近期更新 【汇总】 【Matlab 六自由度机器人】系列文章汇总 …

Linux——31个普通信号

每种信号的含义 在Linux操作系统中,信号是一种进程间通信的方式,用于通知进程发生了某种事件。Linux中的普通信号(standard signals)有31个,每个信号都有特定的用途。以下是这31个普通信号的列表及其描述:…

为什么序列化???

跨进程调用,进行数据传输时,无法直接传递对象,需要将对象通过序列化的方式转为字节流或字符流(json),所以需要进行序列化。 需要共享数据时,直接传递对象通常是不可行的,因为对象的…

C# UDP网络通信

TCP和UDP基本概念 TCP:(Transmission Control Protocol)是一种面向连接、可靠的、基于字节流的传输层通信协议。并且提供了全双工通信,允许俩个应用直接建立一个可靠的连接 以进行数据交换 /UDP:(User Datagram Protocol):是一种无连接、不可靠、基于数据报文传输层…

每日复盘-20240625

今日关注: 20240625 六日涨幅最大: ------1--------300930--------- 屹通新材 五日涨幅最大: ------1--------300930--------- 屹通新材 四日涨幅最大: ------1--------300386--------- 飞天诚信 三日涨幅最大: ------1--------300386--------- 飞天诚信 二日涨幅最大: ------…

JVM专题十:JVM中的垃圾回收机制

在JVM专题九:JVM分代知识点梳理中,我们主要介绍了JVM为什么采用分代算法,以及相关的概念,本篇我们将详细拆分各个算法。 垃圾回收的概念 垃圾回收(Garbage Collection,GC)确实是计算机编程中的…

【Android面试八股文】你能说一说RecycleView与ListView的对比吗?着重说一下缓存策略,优缺点。

文章目录 一、考察的知识点二、RecycleView与ListView的对比2.1 布局效果2.2 Item点击事件2.3 空数据处理2.4 头尾布局2.5 局部刷新2.6 动画效果2.7 缓存机制2.7.1 层级不同2.7.2 缓存内容不同2.7.3 缓存机制2.7.4 ListView与RecyclerView缓存级别的对比2.7.4.1 ListView(两级缓…

【自然语言处理系列】探索NLP:使用Spacy进行分词、分句、词性标注和命名实体识别,并以《傲慢与偏见》与全球恐怖活动两个实例文本进行分析

本文深入探讨了scaPy库在文本分析和数据可视化方面的应用。首先,我们通过简单的文本处理任务,如分词和分句,来展示scaPy的基本功能。接着,我们利用scaPy的命名实体识别和词性标注功能,分析了Jane Austen的经典小说《傲…