大模型基架:Transformer如何做优化?

大模型的基础模式是transformer,所以很多芯片都实现先专门的transformer引擎来加速模型训练或者推理。本文将拆解Transformer的算子组成,展开具体的数据流分析,结合不同的芯片架构实现,分析如何做性能优化。

Transformer结构

transformer结构包含两个过程,Encoder和Decoder。其中Decoder较Encoder结构相同,多了对于kv_cache的处理。

如下图经典的结构示意图,可以看到在Decoder阶段的Multi-Head Attentiond的三个输入箭头其中两个来自Encoderde输出,关于kv-cache对内容管理的优化也是一个很重要的研究方向。本文暂时重点关注与Transformer的Encoder阶段的优化分析。

Transformer的数据流图

下图对应上面transformer的左边Encoder阶段。不同颜色表示不同的算子,其中linear, 其实也是一种matmul算子,只不过它的两个输入一个来自tensor, 一个来自常量。蓝色标记的matmul算子则两个输入全部是tensor。

包含的算子为:linear, matmul, transpose, softmax, add_layernorm。

通过代入参数,了解具体的数据流执行过程,可以让我们更加直观的理解下面的优化之后,得到相同的输出数据的思路。

优化设计1:图优化

根据上面的数据流图可以发现,transpose算子只是对数据进行重排,并不需要计算,但是过多的transpose算子需要不停从内存搬移数据,消耗紧缺的带宽资源,所以一个简单的优化点就是通过硬件架构的设计,来减少transpose层。

对硬件来说,在实现GEMM算子是的时候,对两个矩阵取数过程,增加一个transpose的逻辑, 不会消耗很多的资源,所以可以对GEMM的两个输入数据,分别设计是否打开transpose的参数。

假设GEMM算子原始的数据存放排布矩阵A为(batch, M, K), 矩阵B为(batch, K, N)。得到的输出为(Batch, M , N)。下面对transpose的多头注意力模块进行优化,示例了两种方案,来减少单独的transpose算子开销。

transpose前置(A_transpose_en)

利用矩阵A的transpose开关,将q, k, v的transpose前置, 数据流图如下,这样可以将原本的5个transpose操作减小为2个。

注意图中用红色和蓝色标记了GEMM算子的矩阵A,矩阵B的设定,当一个linear或者matmul算子的两个输入中显示(Batch, K, M)时候,即认为打开了GEMM算子的A矩阵transpose开关

transpose内置(B_transpose_en)

当利用B矩阵的transpose_en功能,优化后的数据流图如下。在QV的matmul计算过程,逆向利用矩阵B的transpose开关,这样可以将原本的5个transpose操作减小为1个。

当一个linear或者matmul算子的两个输入中显示(Batch, N, K)时候,即认为打开了GEMM算子的B矩阵transpose开关

 ​​​​​​​​​​​​​​

通过上面两个方案,大家可能会对attnV_matmul那一步的数据流关于head位置有点疑问,在这里我们不妨这样考虑,将head分给多个thread线程来做,只要thread的数据位置取的对,是可以将(batch, head,seq_len_q, seq_len_k)和(batch, seq_len_v, head, hidden/head)进行矩阵乘得到(batch, head, seq_len_q, hidden/head)的输出的。

优化设计2:任务并行拆解

模型的分布式并行策略有数据并行,张量并行,pipline 并行等,这些策略的一个要点就是合理利用集群资源,让更多的任务并行基础上,减少中间节点的数据通信。

当我们在一个有很多节点的集群上部署大模型时候,因为模型数据维度较大,往往需要将其拆解到不同的芯片(集群)运行,尤其是GEMM算子,不同的拆分方案对应不同的通信开销。下面我们来具体分析一个任务并行的拆解方案。

如图,首先针对attention模块的多头特征,选择在qkv_linear的weights的outZ方向切分为head份,假设有head个计算节点,每个节点计算1个head的matmul任务,因为没有在累加的维度拆分,所以这样每个节点可以顺序执行下一层任务,不需要交互数据。直到attnV_matmul之后,需要做fc0_linear的任务,要把所有的head合并起来累加运算,所以增加了all_gather的通信开销。接着为了避免通信开销,fc0和add_layernorm选择在seq维度拆分。当到达fc1_linear,对depth_hidden进行了拆分,但是fc2_linear需要对所有的depth_hidden进行累加,所以fc2_linear之前需要再一次的all_gather通信。

当然根据具体的硬件条件限制,还可以有其他的任务拆解方案,总之,需要具体场景具体分析。这里仅做简单的优化示例参考。

欢迎评论交流,如果觉得内容有帮助,需要您的点赞鼓励!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/22988.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

go的反射和断言

在go中对于一个变量,主要包含两个信息变量类型(type)和变量值(value) 可以通过reflect包在运行的时候动态获取变量信息,并能够进行操作 对于Type可以通过reflect.TypeOf()获取到变量的类型信息 reflect.Ty…

13_前端工程化_ES6

1.前端工程化概念 前端工程化是使用软件工程的方法来单独解决前端的开发流程中模块化、组件化、规范化、自动化的问题,其主要目的为了提高效率和降低成本。 前后端分离(前端代码工程化独立出来形成一个单独的app) 1.开发分离 2.部署分离 3.服务器分离…

信号(上)

本节目标: 1. 掌握Linux信号的基本概念 2. 掌握信号产生的一般方式 3. 理解信号递达和阻塞的概念,原理。 4. 掌握信号捕捉的一般方式。 5. 重新了解可重入函数的概念。 6. 了解竞态条件的情景和处理方式 7. 了解SIGCHLD信号, 重新编写信号处理…

ChatGPT基本原理详细解说

ChatGPT基本原理详细解说 引言 在人工智能领域,自然语言处理(NLP)一直是研究的热点之一。随着技术的发展,我们见证了从简单的聊天机器人到复杂的语言模型的演变。其中,ChatGPT作为一项突破性技术,以其强大…

2004NOIP普及组真题 2. 花生采摘

线上OJ: 【04NOIP普及组】花生采摘 核心思想: 1、本题为贪心即可。 2、因为本题严格限制了顺序,所以先把每个节点的花生数量按降序排序。然后逐一判断下一个花生是否需要去采摘即可 3、每一次采摘完,记录耗时 t 以及采集的花…

基于web的垃圾分类回收系统的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,用户管理,公告管理,运输管理,基础数据管理 用户账户功能包括:系统首页,个人中心,运输管理,公告…

pyqt QlineEdit内部增加按钮方法

pyqt QlineEdit内部增加按钮方法 def addButton(self,lineEdit):btn QtWidgets.QPushButton("")icon1 QtGui.QIcon()icon1.addPixmap(QtGui.QPixmap(":/image/images/th.png"), QtGui.QIcon.Normal, QtGui.QIcon.Off)btn.setIcon(icon1)btn.setStyleShe…

全光谱led灯的危害有哪些?曝光低质量全光谱led灯产生的四大风险

眼睛是人类获取信息最重要的感官器官之一,而近视则会导致视力模糊,进而影响学习效果和生活品质。因此,如何保护眼睛,尤其是在学习和使用电子设备时,成为了一个迫切需要解决的问题。然而在护眼领域上,护眼台…

SCAU 数据结构 实验六 排序算法

![[Pasted image 20240 8638 直接插入排序 Description 用函数实现直接插入排序,并输出每趟排序的结果. 输入格式 第一行:键盘输入待排序关键的个数n 第二行:输入n个待排序关键字,用空格分隔数据 输出格式 每行输出一趟排序…

十三、resultMap解析

分为两部分:解析和使用 解析 1.解析XML的时候单独解析所有的resultMap标签,封装成ResultMap对象存入configuration中 2.解析XML中的SQL语句,封装MappedStatement对象,这里会根据SQL的返回类型是resultMap还是resultType做处理。如…

C语言 | Leetcode C语言题解之第133题克隆图

题目: 题解: struct Node** visited; int* state; //数组存放结点状态 0:结点未创建 1:仅创建结点 2:结点已创建并已填入所有内容void bfs(struct Node* s) {if (visited[s->val] && state[s->val] 2…

Python Lambda函数的应用实例教程

在Python编程中,lambda函数是一种简洁且强大的工具,用于创建小型匿名函数。它们在需要快速定义简单函数时特别有用。本文将详细介绍lambda函数的语法及其多种应用实例,帮助读者更好地理解和使用lambda函数。 一、lambda函数的基本概念 1.1 什…

c++(内存分配,构造,析构)

#include <iostream>using namespace std; class Per { private:string name;int age;double *height;double *weigh; public://无参构造Per(){cout << "Per::无参构造" << endl;}//有参构造Per(string name,int age,double height,double weigh):…

【TB作品】 51单片机8x8点阵显示滚动汉字仿真

功能 题目5基于51单片机LED8x8点阵显示 流水灯 直接滚动显示HELLO 直接滚动显示老师好 代码 void main( void ) {/** 移位后&#xff0c;右边的是第一个595&#xff0c;接收0X02&#xff0c;显示出0X02* 移位后&#xff0c;左边的是第2个595&#xff0c;接收0Xfe&#xff0c…

创建常规DLL的动态链接库

本文仅供学习交流&#xff0c;严禁用于商业用途&#xff0c;如本文涉及侵权请及时联系本人将于及时删除 【例9.3】创建一个MFC 常规DLL的动态链接库Areadll&#xff0c;在该动态链接库中添加一个导出类CArea&#xff0c;通过该类获取正方形和圆的面积。 (1) 使用“MFC动态链接…

Allegro器件角度倾斜如何回正?

Allegro器件角度倾斜,坐标含有小数点调整为45度整数倍的方法 Allegro器件角度倾斜回正的方法。 在用Allero进行PCB设计过程中,有时候由于误操作;或者刚开始器件需要非45度整数倍的角度,后又需要调整为整数倍的角度。器件角度倾斜含有小数点调整为45度整数倍的方法。 1、如…

Arduino网页服务器:如何将Arduino开发板用作Web服务器

大家好&#xff0c;我是咕噜铁蛋&#xff01;今天&#xff0c;我将和大家分享一个有趣且实用的项目——如何使用Arduino开发板搭建一个简易的网页服务器。通过这个项目&#xff0c;你可以将Arduino连接到互联网&#xff0c;并通过网页控制或查询Arduino的状态。 一、项目背景与…

vue实现pdf下载——html2canvas

html2canvas 官方文档https://html2canvas.hertzen.com/getting-started html2canvas 的原理是通过遍历DOM树,将每一个HTML元素转化为Canvas对象,并叠加到一起形成一张完整的图片或者PDF文件。 1. 安装插件 npm install html2canvas jspdf --save 2.使用&#xff08;页面已经…

git 的基本操作 Master and branch的版本合并 @ VS 1019

前言&#xff1a; 在VS 2019有git 的可视化管理,但&#xff0c;感觉微软其实就是在git上包了一层。版本冲突后&#xff0c;还是要靠git 的命令行代码搞。本文记录了一次&#xff0c;branch和master的版本合并的过程。作为&#xff0c;后续的参考。 【注意&#xff0c;这个是一…

【二进制部署k8s-1.29.4】十三、metrics-server的安装部署

文章目录 简介 一.metrics-server的安装 简介 本章节主要讲解metrics-server的安装&#xff0c;metrics-server主要是用于采集k8s中节点和pod的内存和cpu指标&#xff0c;在观察几点和pod的实时资源使用情况还是比较有用的&#xff0c;如果需要记录历史信息&#xff0c;建议采用…