GiantPandaCV | FasterTransformer Decoding 源码分析(二)-Decoder框架介绍

本文来源公众号“GiantPandaCV,仅用于学术分享,侵权删,干货满满。

原文链接:FasterTransformer Decoding 源码分析(二)-Decoder框架介绍

作者丨进击的Killua

来源丨https://zhuanlan.zhihu.com/p/669303360

编辑丨GiantPandaCV

Decoder模块是FasterTransformer Decoding model中最核心的处理模块,在GiantPandaCV | FasterTransformer Decoding 源码分析(一)-整体框架介绍一文中详细介绍了Decoder模块在整体中所处的位置,本文试图从流程框架层面对该模块进行源码分析,梳理出主要处理模块,后续再逐步对各个模块实现进行解析。

一、整体框架

Decoder在整体解码过程中的位置

代码地址:link

下图中左边是经典的Transformer Decoder结构,右边是FasterTransformer Decoder结构,主要有以下几点区别

  1. 将最后的LayerNorm提前到入口,这里并不能加速流程,但是这种顺序在实践中表现得比较好,允许模型更好地调整输入的分布,使其更适合通过self-attention进行处理,最后处理完会在调用外层再做一次LayerNorm。

  2. 将 SelfAttention和CrossAttention中最后一个 Linear 的 Add Bias,Add Res(残差连接)以及 LayerNorm 合并成一个 (Add Bias & Add Res & LayerNorm) Kernel,降低 Kernel Launch 开销以及提升访问带宽。

  3. 将 FFN 的最后一个 Linear 的 Add Bias,Add Res(残差连接)合并成一个 (Add Bias & Add Res) Kernel,降低 Kernel Launch 开销以及提升访问带宽。

Decoder具体处理流程

二、数据处理流

接下来结合框架图来解析下forward函数的数据处理流程,整体流程在代码上还是非常清晰的。

Input & Output

template<typename T>
void Decoder<T>::forward(std::vector<Tensor>*                      output_tensors,const std::vector<Tensor>*                input_tensors,const std::vector<DecoderLayerWeight<T>>* decoder_layer_weight)
{// input tensors://      decoder_input [batch_size, hidden_dimension],//      encoder_output [batch_size, mem_max_seq_len, memory_hidden_dimension],//      encoder_sequence_length [batch_size],//      finished [batch_size],//      step [1] on cpu//      sequence_lengths [batch_size]//      cache_indirection [local_batch_size / beam_width, beam_width, max_seq_len]//              Here, local_batch_size contains the beam_width, so local_batch_size / beam_width//              is real local_batch_size.// output tensors://      decoder_output [batch_size, hidden_dimension],//      key_cache [num_layer, batch, head_num, size_per_head // x, max_seq_len, x]//      value_cache [num_layer, batch, head_num, max_seq_len, size_per_head]//      key_mem_cache [num_layer, batch_size, mem_max_seq_len, hidden_dimension],//      value_mem_cache [num_layer, batch_size, mem_max_seq_len, hidden_dimension]

这里初看其实是不知道这些输入输出shape背后的含义的,没关系这里先做个标记,等我们全部都看完了再回过头来看这里的意义。我们可以大致知道Decoder的输入tensor中包含:

  1. batch_size个单词的embedding表示或上一个step的解码输出。[batch_size, hidden_dimension]

  2. encoder层的输出。[batch_size, mem_max_seq_len, memory_hidden_dimension]

  3. encoder层输入序列的实际长度。[batch_size]

  4. batch中是否已经解码完成。[batch_size]

  5. 当前解码的步长。

  6. 已解码句子的序列长度。[batch_size]

  7. 中间缓存。(这个暂时还无法理解)

注:这里的batch_size实际是batch_size * beam_size的结果,即对每个batch的beam_size个词分别解码。

Decoder的输出tensor包含:

  1. batch个解码器的词向量输出。[batch_size, hidden_dimension]

  2. self-attention中前面steps所计算出来的key buffer。[num_layer, batch, head_num, size_per_head // x, max_seq_len, x],其中 x =4(FP32), x=8(FP16).

  3. self-attention中前面steps所计算出来的value buffer。

  4. cross-attention中前面steps所计算出来的key buffer。

  5. cross-attention中前面steps所计算出来的value buffer。

逐层解码

decoder是逐层进行解码的,接下来每层都会使用以下这些模块进行推理。

Cache

        size_t self_key_cache_offset = l;for (auto t = output_tensors->at(1).shape.begin() + 1; t != output_tensors->at(1).shape.end(); ++t) {self_key_cache_offset *= (*t);}size_t self_value_cache_offset = l;for (auto t = output_tensors->at(2).shape.begin() + 1; t != output_tensors->at(2).shape.end(); ++t) {self_value_cache_offset *= (*t);}

这里是对cache的索引,cache是fastertransformer性能优化的一大重点,思想很简单,就是复用前面step计算的结果,避免重复计算,以空间来换时间。代码中对self-attention和cross-attention中线性化处理后的key和value进行了缓存。针对cross-attention,因为key和value是来自于encoder的输出(如图所示),所以每个step上使用的key和value是相同的。

但是针对self-attention,key和value这里笔者还没完全理解为什么可以复用,这里也先留个标记(self_attention的key,value和query的生成逻辑可能不一样)。

LayerNorm

        invokeGeneralLayerNorm(decoder_normed_input_,decoder_input,decoder_layer_weight->at(l).pre_layernorm_weights.gamma,decoder_layer_weight->at(l).pre_layernorm_weights.beta,layernorm_eps_,batch_size,hidden_units_,(float*)nullptr,0,stream_);

这里调用layernorm的kernel函数进行处理,我们后续单独介绍kernel实现。

SelfAttention

        TensorMap self_attention_input_tensors{{"input_query", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, decoder_normed_input_}},{"finished", input_tensors->at(3)},{"sequence_lengths", input_tensors->at(5)},{"step", input_tensors->at(4)}};self_attention_input_tensors.insertIfValid("cache_indirection", input_tensors->at(6));TensorMap self_attention_output_tensors{{"hidden_features", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, self_attn_output_}},{"key_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(1).shape.begin() + 1, output_tensors->at(1).shape.end()),output_tensors->at(1).getPtrWithOffset(self_key_cache_offset)}},{"value_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(2).shape.begin() + 1, output_tensors->at(2).shape.end()),output_tensors->at(2).getPtrWithOffset<T>(self_value_cache_offset)}}};self_attention_layer_->forward(&self_attention_output_tensors,&self_attention_input_tensors,&decoder_layer_weight->at(l).self_attention_weights);

这里以map的方式对输入输出tensor进行了封装,再调用self_attention_layer层进行推理,详细介绍见:进击的Killua:FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍。

Add Bias & Add Res & LayerNorm

        invokeGeneralAddBiasResidualPreLayerNorm(self_attn_output_,normed_self_attn_output_,self_attn_output_,decoder_input,decoder_layer_weight->at(l).self_attn_layernorm_weights.gamma,decoder_layer_weight->at(l).self_attn_layernorm_weights.beta,decoder_layer_weight->at(l).self_attention_weights.attention_output_weight.bias,layernorm_eps_,batch_size,hidden_units_,(float*)nullptr,(float*)nullptr,(float*)nullptr,(float*)nullptr,0,stream_);sync_check_cuda_error();

这里将add bias、add res和laynorm操作合成一个kernel进行处理,也是优化的经典方法,文章进击的Killua:FasterTransformer Decoding 源码分析(五)-AddBiasResidualLayerNorm介绍 做了详细介绍。

CrossAttention

        TensorMap cross_attention_input_tensors{{"input_query", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, normed_self_attn_output_}},{"encoder_output", input_tensors->at(1)},{"encoder_sequence_length", input_tensors->at(2)},{"finished", input_tensors->at(3)},{"step", input_tensors->at(4)}};TensorMap cross_attention_output_tensors{{"hidden_features", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, cross_attn_output_}},{"key_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(3).shape.begin() + 1, output_tensors->at(3).shape.end()),output_tensors->at(3).getPtrWithOffset<T>(mem_cache_offset)}},{"value_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(4).shape.begin() + 1, output_tensors->at(4).shape.end()),output_tensors->at(4).getPtrWithOffset<T>(mem_cache_offset)}}};cross_attention_layer_->forward(&cross_attention_output_tensors,&cross_attention_input_tensors,&decoder_layer_weight->at(l).cross_attention_weights);

这里以map的方式对输入输出tensor进行了封装,再调用cross_attention_layer层进行推理,详见文章:进击的Killua:FasterTransformer Decoding 源码分析(六)-CrossAttention介绍

Add Bias & Add Res & LayerNorm

        invokeGeneralAddBiasResidualPreLayerNorm(cross_attn_output_,normed_cross_attn_output_,cross_attn_output_,self_attn_output_,decoder_layer_weight->at(l).cross_attn_layernorm_weights.gamma,decoder_layer_weight->at(l).cross_attn_layernorm_weights.beta,decoder_layer_weight->at(l).cross_attention_weights.attention_output_weight.bias,layernorm_eps_,batch_size,hidden_units_,(float*)nullptr,(float*)nullptr,(float*)nullptr,(float*)nullptr,0,stream_);sync_check_cuda_error();

和上述类似。

FFN

        TensorMap ffn_input_tensors({{"ffn_input", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, normed_cross_attn_output_}}});TensorMap ffn_output_tensors({{"ffn_output", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, decoder_output}}});ffn_layer_->forward(&ffn_output_tensors, &ffn_input_tensors, &decoder_layer_weight->at(l).ffn_weights);

FFN详细介绍如下文所示。

进击的Killua:FasterTransformer Decoding 源码分析(七)-FFNLayer MoE(上篇)

进击的Killua:FasterTransformer Decoding 源码分析(八)-FFNLayer MoE(下篇)

Add Bias & Add Res

        invokeAddBiasResidual(decoder_output,cross_attn_output_,decoder_layer_weight->at(l).ffn_weights.output_weight.bias,batch_size,hidden_units_,stream_);sync_check_cuda_error();

这里将add bias、add res操作合成一个kernel进行处理,属于fused op的常用操作。

三、总结

总体来看fastertransformer的decoder主要用了小OP融合、大OP重写、重复计算缓存化这几个优化策略来进行加速,接下来开始逐步剖析内部细节。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/831641.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python编程实践1/3】模块

目录 目标 模块 import ​编辑 代码小结 题目 from...import 随机模块 代码小结 randint函数 骰子大战 choice函数 总结 目标 拧一颗螺丝&#xff0c;只会用到螺丝刀&#xff1b;但是修一台汽车&#xff0c;需要一整套汽修的工具。函数就像螺丝刀&#xff0c;可以帮…

python项目==一个web项目,配置模板指定文件清洗规则,调用模板规则清洗文件

代码地址 一个小工具。 一个web项目&#xff0c;配置模板指定文件清洗规则&#xff0c;调用模板规则清洗文件 https://github.com/hebian1994/csv-transfer-all 技术栈&#xff1a; SQLite python flask vue3 elementplus 功能介绍&#xff1a; A WEB tool for cleaning…

JavaScript:Web APIs(三)

本篇文章的内容包括&#xff1a; 一&#xff0c;事件流 二&#xff0c;移除事件监听 三&#xff0c;其他事件 四&#xff0c;元素尺寸与位置 一&#xff0c;事件流 事件流是什么呢&#xff1f; 事件流是指事件执行过程中的流动路径。 我们发现&#xff0c;一个完整的事件执行…

Delta lake with Java--利用spark sql操作数据1

今天要解决的问题是如何使用spark sql 建表&#xff0c;插入数据以及查询数据 1、建立一个类叫 DeltaLakeWithSparkSql1&#xff0c;具体代码如下&#xff0c;例子参考Delta Lake Up & Running第3章内容 import org.apache.spark.sql.SaveMode; import org.apache.spark.…

区域文本提示的实时文本到图像生成;通过一致性自注意力机制的视频生成工具保持视频的一致性;专门为雪佛兰汽车设计的客服聊天机器人

✨ 1: StreamMultiDiffusion StreamMultiDiffusion是首个基于区域文本提示的实时文本到图像生成框架&#xff0c;实现了高速且互动的图像生成。 StreamMultiDiffusion 旨在结合加速推理技术和基于区域的文本提示控制&#xff0c;以克服之前解决方案中存在的速度慢和用户交互性…

约瑟夫问题新解法

前言 又碰到了约瑟夫问题&#xff0c;这样的题目本来用环形链表模拟的话就能做出来。然而&#xff0c;最近新学习了一种做法&#xff0c;实在是有点震惊到我了。无论是思路上&#xff0c;还是代码量上&#xff0c;都是那么的精彩。就想也震惊一下其他人。谁能想到原来模拟出来四…

C/C++程序设计实验报告综合作业 | 小小计算器

本文整理自博主本科大一《C/C程序设计》专业课的课内实验报告&#xff0c;适合C语言初学者们学习、练习。 编译器&#xff1a;gcc 10.3.0 ---- 注&#xff1a; 1.虽然课程名为C程序设计&#xff0c;但实际上当时校内该课的内容大部分其实都是C语言&#xff0c;C的元素最多可能只…

深度解析 Spring 源码:探寻Bean的生命周期

文章目录 一、 Bean生命周期概述二、Bean生命周期流程图三、Bean生命周期验证3.1 代码案例3.2 执行结果 四、Bean生命周期源码4.1 setBeanName()4.2 setBeanFactory()4.3 setApplicationContext()4.4 postProcessBeforeInitialization()4.5 afterPropertiesSet()4.6 postProces…

力扣刷题第1天:消失的数字

大家好啊&#xff0c;从今天开始将会和大家一起刷题&#xff0c;从今天开始小生也会开辟新的专栏。&#x1f61c;&#x1f61c;&#x1f61c; 目录 第一部分&#xff1a;题目描述 第二部分&#xff1a;题目分析 第三部分&#xff1a;解决方法 3.1 思路一&#xff1a;先排序…

十、多模态大语言模型(MLLM)

1 多模态大语言模型&#xff08;Multimodal Large Language Models&#xff09; 模态的定义 模态&#xff08;modal&#xff09;是事情经历和发生的方式&#xff0c;我们生活在一个由多种模态(Multimodal)信息构成的世界&#xff0c;包括视觉信息、听觉信息、文本信息、嗅觉信…

MySQL技能树学习——数据库组成

数据库组成&#xff1a; 数据库是一个组织和存储数据的系统&#xff0c;它由多个组件组成&#xff0c;这些组件共同工作以确保数据的安全、可靠和高效的存储和访问。数据库的主要组成部分包括&#xff1a; 数据库管理系统&#xff08;DBMS&#xff09;&#xff1a; 数据库管理系…

MySQL45讲(一)(40)

回顾binlog_formatstatement STATEMENT 记录SQL语句。日志文件小&#xff0c;节约IO&#xff0c;但是对一些系统函数不能准确复制或不能复制&#xff0c;如now()、uuid()等 在RR隔离级别下&#xff0c;binlog_formatstatement 如果执行insert select from 这条语句是对于一张…

OpenCV如何为等值线创建边界旋转框和椭圆(63)

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇:OpenCV 为轮廓创建边界框和圆(62) 下一篇:OpenCV的图像矩(64) 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 cv::minAreaRect使用 OpenCV 函数 cv::fitEllipse cv::min…

Gradle 进阶学习 之 build.gradle 文件

build.gradle 是什么&#xff1f; 想象一下&#xff0c;你有一个大型的乐高项目&#xff0c;你需要一个清单来列出所有的乐高积木和它们如何组合在一起。在软件开发中&#xff0c;build.gradle 就是这个清单&#xff0c;它告诉计算机如何构建&#xff08;组合&#xff09;你的软…

这是一个简单的照明材料网站,后续还会更新

1、首页效果图 代码 <!DOCTYPE html> <html><head><meta charset"utf-8" /><title>爱德照明网站首页</title><style>/*外部样式*/charset "utf-8";*{margin: 0;padding: 0;box-sizing: border-box;}a{text-dec…

开源版本管理系统的搭建一:SVN

作者&#xff1a;私语茶馆 1.Windows搭建SVN版本管理系统 1.1.SVN概要和组成 背景介绍 Svn是一个开源版本管理系统&#xff0c;由CollabNet公司于2000年发布&#xff0c;23年12月发布最新版本Apache Subversion 1.14.3。官方网站&#xff1a;Apache Subversion。 Svn可以直…

G1 - 生成对抗网络(GAN)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 目录 理论知识生成器判别器基本原理 环境步骤环境设置数据准备模型设计模型训练模型效果展示 总结与心得体会 理论知识 生成对抗网络&#xff08;Generative …

U盘惊现“USBC乱码”?别急,数据恢复有妙招!

近日&#xff0c;不少用户反馈在将U盘插入电脑后&#xff0c;出现了一个令人困惑的问题&#xff1a;U盘里的文件或文件夹名突然变成了无法识别的乱码&#xff0c;甚至整个U盘的文件系统显示为“USBC乱码”。面对这种情况&#xff0c;用户往往感到无从下手&#xff0c;担心重要数…

【知识加油站】——机电产品数字孪生机理模型构建

明确一种多领域、多层次、参数化、一致性的机电一体化装备数字孪生机理模型构建准则&#xff01; 关键词英文简称&#xff1a; 数字孪生&#xff1a;DT物联网&#xff1a;IoT网络物理系统&#xff1a;CPS高级架构&#xff1a;HLA统一建模语言&#xff1a;UML数控机床&#xf…

webpack打包工具

目录 1. yarn包管理器 1.1 yarn 是什么, 有什么用? 1.2 yarn的使用 ​​​​​​2. webpack基本概述 2.1 webpack是什么&#xff1f; 2.2 什么是打包&#xff1f; 2.3 webpack能做什么&#xff1f; 3. webpack基本使用步骤 3.1 webpack基本使用步骤 3.2 package.jso…