发表博客之:transformer 架构 推理时候运算流程详细讲解,以及变长推理支持,小白都可以看得懂,AI推理工程师必备技能!

文章目录

  • [发表博客之:transformer 架构 推理时候运算流程详细讲解,以及变长推理支持,小白都可以看得懂,AI推理工程师必备技能!](https://cyj666.blog.csdn.net/article/details/138439826)
    • 总结一下
    • 高性能变长推理

发表博客之:transformer 架构 推理时候运算流程详细讲解,以及变长推理支持,小白都可以看得懂,AI推理工程师必备技能!

  • 大家都知道,这些大模型都是一些单元如此的重复堆叠而已,那么这个单元到底长什么样子呢?
  • 在这里,本张大帅就给你们解释的一清二楚!如果看完了我说的,你还是糊里糊涂的,请在评论区留言来打我!
  • 我们姑且称呼这个单元叫做transfomer block吧!

  • 首先这个transfomer block有一个输入,这个输入的shape是啥呢?
    • 那就是[batch_size, seq_len, hidden_dim]
    • batch_size就是表示批量大小啊!
    • seq_len就是序列长度啊!
    • hidden_dim这个大家意会一下啊!
  • 但是要注意啊,网友们,每个batch的seq_len其实常常是不一样的,这个你在心里面要记得注意啊
    • 例如batch0其实seq_len是10,batch1的seq_len是20,batch2的seq_len是30
    • 但是我们这里把他写成统一的按照最大长度30,
    • 但是你需要在心里知道batch0其实有效长度是10哦,batch1的有效长度是20!

  • transfomer block里面的第一个运算是啥呢?
    • 是个layer_norm啦!这个Op是不改变tensor的shape的!
  • 然后是一个Fc Op,那么权重的shape是啥呢?其实就是[hidden_dim , 3 * hidden_dim]
    • 也就是经过这个Op后,输出tensor的shape是[batch_size, seq_len, 3 * hidden_dim]
  • 这个难吗?这个很简单啊!
  • 也就是说目前
  • 各位看官你们看,上面的难嘛?一点也不难啊!

  • 下面继续运算,拿着这个[batch_size, seq_len, 3 * hidden_dim]的tensor继续往下运算,下面的运算是个很牛的运算方式
  • 首先将它split成三份,QKV,shape分别都是[batch_size, seq_len, hidden_dim]
  • 然后三个东西都reshape成[batch_size, seq_len, num_head, head_dim]
    • 也就是num_head * head_dim = hidden_dim
  • 到目前为止,各位看官还有疑惑吗?我相信都是没有的!
  • 然后再将QKV都transpose成[batch_size, num_head, seq_len, head_dim]
  • 接下来就是最关键的点,attention运算!
  • 先用Q*K得到的tensor shape是[batch_size, num_head, seq_len, seq_len]
    • 然后除以一个sqrt(head_dim)
    • 接着来一个softmax,得到attention_weight
    • 也就是attn_weight = softmax(Q*K / sqrt(head_dim))
    • 有的时候啊,还会多一个attn_mask,他的shape呢就是[batch_size, num_head, seq_len, seq_len]
    • 所以attn_weight = attn_weight + attn_mask
    • 至此我们得到了最终的attn_weight!
  • 最后再用attn_weight和V进行矩阵乘法得到最终的输出tensor!
    • 最终tensor的shape是[batch_size, num_head, seq_len, head_dim]
    • 最后记得把他transpose成[batch_size, seq_len, num_head, head_dim]
    • 然后再reshape成[batch_size, seq_len, hidden_dim]
  • 至此上面的运算过程就完成了!
  • 我们把他叫做attention计算过程!
  • 目前图变成下面这样啦!

  • attention层出来之后的shape就是[batch_size, seq_len, hidden_dim]
  • 然后呢,再来一个全联接层,权重shape是[hidden_dim,hidden_dim]
  • 所以出来的tensor shape还是[batch_size, seq_len, hidden_dim]
  • 至此,模型的图如下图所示。
  • 最后,来一个牛逼哄哄的add操作
  • 图变成下面这样啦!

各位老板请注意,上面的两个fc模块到底有没有bias,取决于每个模型的不同,有可能有,也有可能没有!

  • 下面的几个操作其实都是简单的啦!
  • 首先再来一个layer_norm操作!然后接着是一个fc操作!权重是[hidden_dim, intermediate_size]
    • 这个 intermediate_size 一般都是比hidden_dim大很多的!
    • 然后就是激活啦!
    • 然后又是另一个fc,权重是[intermediate_size, hidden_dim]
    • 最后是一个性感的Add操作
  • 也就是下面的图片的这样,至此我们就把到底啥是transformer block给讲完了!

总结一下

  • transformer block的输入是[batch_size, seq_len, hidden_dim],输出也是这么大,因此可以很方便的堆叠起来,例如把40个这样的block串起来!

高性能变长推理

  • 看官你好,上面的 transformer block的输入shape是[batch_size, seq_len, hidden_dim],但是由于不同的batch的seq_len是不一样的,因此这样搞肯定比较冗余!
  • 例如此时有3个batch,seq_len分别是10,20,30,原本的方案是将输入的shape搞成[3,30,hidden_dim]
  • 我们观察transformer block发现一个细节,也就是除了compute_attn模块外,
    • 其他的计算单元都是不操纵batch和seq_len维度的!例如layer_norm,fc等
    • 而只操纵hidden_dim维度的!
  • 也就是说,对于fc op,我们可以将输入只看成2维,对于layer_norm也是如此
    • 对于add操作,我们甚至可以将输入只看成1维

  • 这样我们只需要将输入搞成[10+20+30, hidden_dim]这么大的输入即可!
  • 但是在算compute_attn模块时候,我们需要额外传入seq_lens=[10,20,30]即可!
  • 如此就实现了变长推理了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/6450.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA面试题--数据库基础

连接查询 1.左连接 (左外连接)以左表为基准进行查询,左表数据会全部显示出来,右表 如果和左表匹配 的数 据则显示相应字段的数据,如果不匹配,则显示为 NULL; 2.右连接 (右外连接)以右表为基准进行查询,右表数据会全部显示出来,右…

初识Vue-脚本架(如何创建vue项目并使用)

一、介绍vue脚本架 Vue 脚手架”通常指的是 Vue CLI,是一个官方提供的命令行工具,用于快速搭建 Vue 项目。Vue CLI 提供了一套标准化的项目模板和一系列的开发工具,使得创建、管理和部署 Vue 项目变得更加简单和高效。以下是 Vue CLI 的一些…

定点乘除法

目录 一、定点乘法 1.串行乘法器 2.并行乘法器 二、定点除法 1.笔算除法 2.机器除法 一、定点乘法 1.串行乘法器 1.符号位单独处理,两数的符号位按异或运算得到,而乘积的数值部分则是两个正数相乘之积。 2.过程 (1) 由乘…

持续总结中!2024年面试必问 100 道 Java基础面试题(二十八)

上一篇地址:持续总结中!2024年面试必问 100 道 Java基础面试题(二十七)-CSDN博客 五十五、Object类有哪些常用的方法? Java中的Object类是所有Java类的根类,它位于类继承层次结构的顶端。Object类提供了一…

「2024年」前端开发常用工具函数总结 TypeScript

前言 在前端开发中,工具函数是提高代码复用率、保持代码整洁和增加开发效率的关键。使用 TypeScript 编写工具函数不仅可以帮助开发者捕捉到更多的类型错误,还可以提供更清晰的代码注释和更智能的代码补全。下面是一些在 TypeScript 中常用的前端开发工…

在Django中实现多用户角色和权限管理的方法

在Django中实现多用户角色和权限管理可以通过以下步骤实现: 定义用户角色模型:首先,定义一个用户角色模型,该模型表示不同的用户角色,例如管理员、普通用户、编辑等。 from django.db import modelsclass Role(model…

移动构造函数是否标记noexcept对性能有重要影响

1. 移动构造标记noexcept时才会被正确调用 #include <iostream> #include <string> #include <vector>class Vehicle{ public:Vehicle(){std::cout << "Vehicle default-ctor called.\n";}Vehicle(const std::string& brand) : brand_(…

Java如何获取当前日期和时间?

Java如何获取当前日期和时间&#xff1f; 本文将为您介绍 Java 中关于日期和时间获取的方法&#xff0c;以及介绍 Java 8 中获取日期和时间的全新API。 1、 System.currentTimeMillis() 获取标准时间可以使用 System.currentTimeMillis() 方法来获取&#xff0c;此方法优势是…

Hadoop生态系统的核心组件探索

理解大数据和Hadoop的基本概念 当我们谈论“大数据”时&#xff0c;我们指的是那些因其体积、速度或多样性而难以使用传统数据处理软件有效管理的数据集。大数据可以来自多种来源&#xff0c;如社交媒体、传感器、视频监控、交易记录等&#xff0c;通常包含了TB&#xff08;太…

【算法】双指针思想

一、Leetcode27.移除元素 1.题目描述 给你一个数组 nums和一个值 val&#xff0c;你需要 [原地] 移除所有数值等于 val的元素&#xff0c;并返回移除后数组的新长度。 不要使用额外的数组空间&#xff0c;你必须仅使用 O(1) 额外空间并 [原地 ]修改输入数组。 元素的顺序可以…

【C语言】详解预处理

、 最好的时光&#xff0c;在路上;最好的生活&#xff0c;在别处。独自上路去看看这个世界&#xff0c;你终将与最好的自己相遇。&#x1f493;&#x1f493;&#x1f493; 目录 •✨说在前面 &#x1f34b;预定义符号 &#x1f34b; #define • &#x1f330;1.#define定义常…

ControlNet官方资源链接【ControlNet论文原文】【持续更新中~】

ControlNet官方资源链接 ControlNet论文原文&#xff1a;https://arxiv.org/abs/2302.05543ControlNet官方GitHub&#xff1a;https://github.com/lllyasviel/ControlNetControlNet 1.1官方GitHub&#xff1a;https://github.com/lllyasviel/ControlNet-v1-1-nightlyControlNe…

phpMyAdmin增加自定义IP登录教程

phpMyAdmin增加自定义IP登录教程 1、打开phpMyAdmin目录&#xff0c; 在此目录下是否有config.sample.inc.php文件&#xff0c;如果存在&#xff0c;那么将其改名为config.inc.php&#xff08;为避免修改失误所造成的损失&#xff0c;强烈建议先备份config.sample.inc.php文件…

4_C语言复杂表达式与指针高级应用

指针数组与数组指针 字面意思来理解指针数组与数组指针 指针数组的实质是一个数组&#xff0c; 这个数组中存储的内容全部是指针变量。 数组指针的实质是一个指针&#xff0c; 这个指针指向的是一个数组。 分析指针数组与数组指针的表达式 int * p[5]; 指针数组 int (*p)[5]…

等保测评考试重点题库分享上

一、单选题 1、下列不属于网络安全测试范畴的是&#xff08;C&#xff09; A&#xff0e;结构安全 B.便捷完整性检查 C.剩余信息保护 D.网络设备防护 2、下列关于安全审计的内容说法中错误的是&#xff08;D&#xff09; A&#xff0e;应对网络系统中的网络设备运行情况、网…

UnityWebGL使用sherpa-ncnn实时语音识别

k2-fsa/sherpa-ncnn&#xff1a;在没有互联网连接的情况下使用带有 ncnn 的下一代 Kaldi 进行实时语音识别。支持iOS、Android、Raspberry Pi、VisionFive2、LicheePi4A等。 (github.com) 如果是PC端可以直接使用ssssssilver大佬的 https://github.com/ssssssilver/sherpa-ncn…

bind、call和apply

bind、call和apply都是 JavaScript 中用于改变函数执行上下文&#xff08;即函数内部的this指向&#xff09;的方法&#xff0c;它们的主要区别如下&#xff1a; bind 方法会创建一个新的函数&#xff0c;并将这个函数的执行上下文绑定到指定的对象。它不会立即执行函数&#x…

[嵌入式系统-62]:RT-Thread-内核:多核CPU SMP的支持与移植

目录 RT-Thread SMP 介绍与移植 1. 多核的优点 2. 多核启动 2.1 概述 2.2 CPU0 启动流程 2.3 次级 CPU 启动流程 3. 多核调度 3.1 任务特性 3.2 调度策略 4. SMP 内核接口 处理器间中断 IPI OS Tick 自旋锁 spinlock 任务绑定 4. SMP移植说明 编译环境准备 创…

配置网关,解决本地连接不上Linux虚拟机的问题

在Window环境下&#xff0c;使用远程终端工具连接不了VMware搭建的Linux虚拟机&#xff08;CentOS 7&#xff09;&#xff0c;并且在命令行ping不通该Linux虚拟机的IP地址。下面通过配置网关解决本地与Linux虚拟机连接问题&#xff1a; 1 查看虚拟机网关地址 在VMware虚拟机上…

opencv merge使用

OpenCV 中的 merge 函数用于将多个单通道或多通道的图像合并成一个多通道的图像。 在C中&#xff0c;OpenCV的merge函数也提供了相同的功能&#xff0c;用于合并多个单通道或多通道的图像。下面是一个使用C的示例&#xff1a; #include <opencv2/opencv.hpp> #include &…