GiantPandaCV | FasterTransformer Decoding 源码分析(六)-CrossAttention介绍

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:FasterTransformer Decoding 源码分析(六)-CrossAttention介绍

GiantPandaCV | FasterTransformer Decoding 源码分析(一)-整体框架介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(二)-Decoder框架介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(三)-LayerNorm介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(五)-AddBiasResidualLayerNorm介绍-CSDN博客

作者丨进击的Killua

来源丨https://zhuanlan.zhihu.com/p/670739629

编辑丨GiantPandaCV

本文是FasterTransformer Decoding源码分析的第六篇,笔者试图去分析CrossAttention部分的代码实现和优化。由于CrossAttention和SelfAttention计算流程上类似,所以在实现上FasterTransformer使用了相同的底层Kernel函数,因此会有大量重复的概念和优化点,重复部分本文就不介绍了,所以在阅读本文前务必先浏览进击的Killua:FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍这篇文章,一些共性的地方会在这篇文章中做统一介绍,本文着重介绍区别点。

一、模块介绍

如下图所示,CrossAttention模块位于DecoderLayer的第4个模块,输入为经过LayerNorm后的SelfAttention结果和encoder的outputs,经过该模块处理后进行残差连接再输入LayerNorm中。

CrossAttention在decoder中的位置

CrossAttention模块本质上还是要实现如下几个公式,主要的区别在于其中 CrossAttention 的K, V矩阵不是使用 上一个 Decoder block的输出或inputs计算的,而是使用Encoder 的编码信息矩阵计算的,这里还是把公式放出来展示下。

crossAttention 公式

二、设计&优化

整体Block和Thread的执行模型还是和SelfAttention的保持一致,这里不再赘述,主要介绍一下有一些区别的KV Cache。

1. KV Cache

由于在CrossAttention中K,V矩阵是来自于已经计算完成的Encoder输出,所以KV Cache的程度会更大,即第一次运算把KV计算出来之后,后续只要读取Cache即可,不需要用本step的输入再进行线性变换得到增量的部分K,V,如下图所示。

三、源码分析

1. 方法入口

CrossAttention的调用入口如下,解释下这里的输入和输出,具体逻辑在后面。

输入Tensor

  1. input_query:normalize之后的SelfAttention输出,大小是[batch_size,hidden_units_]

  2. encoder_output: encoder模块的输出,大小是[batch_size, mem_max_seq_len, memory_hidden_dimension]

  3. encoder_sequence_length:每个句子的长度,大小是[batch_size]

  4. finished: 解码是否结束的标记,大小是[batch_size]

  5. step: 当前解码的步数

输出Tensor

  1. hidden_features:CrossAttention的输出feature,大小是[batch_size,hidden_units_],和input_query大小一致。

  2. key_cache:CrossAttention中存储key的cache,用于后续step的计算。

  3. value_cache: CrossAttention中存储Value的cache,用于后续step的计算。

        TensorMap cross_attention_input_tensors{{"input_query", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, normed_self_attn_output_}},{"encoder_output", input_tensors->at(1)},{"encoder_sequence_length", input_tensors->at(2)},{"finished", input_tensors->at(3)},{"step", input_tensors->at(4)}}; TensorMap cross_attention_output_tensors{{"hidden_features", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, cross_attn_output_}},{"key_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(3).shape.begin() + 1, output_tensors->at(3).shape.end()),output_tensors->at(3).getPtrWithOffset<T>(mem_cache_offset)}},{"value_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(4).shape.begin() + 1, output_tensors->at(4).shape.end()),output_tensors->at(4).getPtrWithOffset<T>(mem_cache_offset)}}};cross_attention_layer_->forward(&cross_attention_output_tensors,&cross_attention_input_tensors,&decoder_layer_weight->at(l).cross_attention_weights);

2. 主体框架

主体框架代码由三部分构成,分别是该step的QKV生成、output生成和Linear输出。其中第一部分和第三部分都使用了cublas的封装矩阵乘方法gemm,这里就不多介绍了,主要功能逻辑在第二部分output生成。

第一部分:QKV生成

如上所述,代码中Q矩阵是需要每个step生成的,而KV矩阵只有第一个step需要生成,后续步骤读取cache即可。

    cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,hidden_units_,  // n                          batch_size,d_model_,  // k                          attention_weights->query_weight.kernel,hidden_units_,  // n                          attention_input,d_model_,  // k                          q_buf_,hidden_units_ /* n */);if (step == 1) {cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,hidden_units_,batch_size * mem_max_seq_len,encoder_output_tensor.shape[2],attention_weights->key_weight.kernel,hidden_units_,encoder_output_tensor.getPtr<T>(),encoder_output_tensor.shape[2],key_mem_cache,hidden_units_);cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,hidden_units_,batch_size * mem_max_seq_len,encoder_output_tensor.shape[2],attention_weights->value_weight.kernel,hidden_units_,encoder_output_tensor.getPtr<T>(),encoder_output_tensor.shape[2],value_mem_cache,hidden_units_);}

第二部分:output生成

核心函数调用,这里参数较多不一一介绍了,非常多(像一些has_ia3等参数应该是在不断迭代的过程中加入的),在后面函数实现中会将重点参数进行阐述。

    cross_attention_dispatch<T>(q_buf_,attention_weights->query_weight.bias,key_mem_cache,attention_weights->key_weight.bias,value_mem_cache,attention_weights->value_weight.bias,memory_sequence_length,context_buf_,finished,batch_size,batch_size,head_num_,size_per_head_,step,mem_max_seq_len,is_batch_major_cache_,q_scaling_,output_attention_param,has_ia3 ? input_tensors->at("ia3_tasks").getPtr<const int>() : nullptr,has_ia3 ? attention_weights->ia3_key_weight.kernel : nullptr,has_ia3 ? attention_weights->ia3_value_weight.kernel : nullptr,stream_);

第三部分:Linear输出

这里就是简单地对上步输出结果乘以一个权重矩阵。

    cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,d_model_,  // nbatch_size,hidden_units_,  // kattention_weights->attention_output_weight.kernel,d_model_,  // ncontext_buf_,hidden_units_,  // kattention_out,d_model_ /* n */);

3. kernel函数调用

上述output生成步骤中会调用如下代码,这里针对每个head中需要处理的层数进行了分类,这个也是大量优化中的常用方案,针对不同的入参大小选择不同size和配置的kernel函数进行处理,这里有经验的一些成分在里面,我们常用的case是hidden_size_per_head=64(head=8)的情况。

template<typename T, typename KERNEL_PARAMS_TYPE>void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream){switch (params.hidden_size_per_head) {case 32:mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);break;case 48:mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);break;case 64:mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);break;case 80:mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 96:mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 112:mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 128:mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 144:mmha_launch_kernel<T, 144, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 160:mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 192:mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 224:mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 256:mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);break;default:assert(false);}}

4. kernel函数实现

这个函数和SelfAttention中的kernel函数是同一个,流程如图所示,这里只介绍下区别点。

1. CrossAttention中只有第一个step需要将KV存入Cache,其他step不需要。

        const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);if (handle_kv) {// Trigger the stores to global memory.            if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);}}

2. 处理本轮step的KV时,也是从cache中取得KV,无需进行本轮计算得到增量KV。

    if (DO_CROSS_ATTENTION) {// The 16B chunk written by the thread.        int co = tidx / QK_VECS_IN_16B;// The position of the thread in that 16B chunk.        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +// params.timestep*QK_ELTS_IN_16B +                     tlength * QK_ELTS_IN_16B + ci;k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_cache[offset])) :k;}else {if (params.int8_mode == 2) {using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;const auto k_scaling = params.qkv_scale_out[1];const auto k_quant =*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));}else {k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) :k;}}if (DO_CROSS_ATTENTION) {v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache[tlength * Dh]));}

四、总结

本文相对简单,分析了FasterTransformer中CrossAttention模块的设计方法和代码实现,和SelfAttention基本一致,只是对KV Cache的处理细节上有一点区别,整体上看缓存的使用会比SelfAttention多一些,所以速度应该还会快一点。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/845854.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis系统学习篇 - 分页插件

MyBatis是一个非常流行的Java持久层框架&#xff0c;它简化了数据库操作的代码。分页是数据库查询中常见的需求&#xff0c;MyBatis本身并不直接支持分页功能&#xff0c;但可以通过插件来实现&#xff0c;从而帮助我们在查询数据库的时候更加方便快捷 引入依赖 <dependen…

移动端路由切换解决方案 —— 虚拟任务栈让你的 H5 像APP一样丝滑

目录 01: 前言 02: 通用组件&#xff1a;trigger-menu 和 trigger-menu-item 构建方案分析 03: 通用组件&#xff1a;构建 trigger-menu 和 trigger-menu-item 04: 前台业务下 H5 的应用场景 05: 通用组件&#xff1a;transition-router-view 构建方案分析 与 虚拟任务栈…

Java实战:将学生列表写入文件

本实战项目旨在演示如何使用Java语言将学生信息列表写入到一个文本文件中&#xff0c;并进行单元测试以确保代码的正确性。 创建静态方法 定义一个名为writeStudentsToFile的静态方法&#xff0c;该方法接收两个参数&#xff1a;一个Student对象的列表和一个文件路径。使用File…

Python疑难杂症--考试复习

1.排序输出字典中数据 dic1 {Tom:21,Bob:18,Jack:23,Ana:20} dic2 {李雷:21,韩梅梅:18,小明:23,小红:20} nint(input()) if n>len(dic1):nlen(dic1) print(sorted(dic1.keys())[:n]) print(sorted(dic2.items(),keylambda item:item[1])[:n]) 2.罗马数字转换 def F(s):d{…

SQL—DQL(数据查询语言)之小结

一、引言 在前面我们已经学习完了所有的关于DQL&#xff08;数据查询语言&#xff09;的基础语法块部分&#xff0c;现在对DQL语句所涉及的语法&#xff0c;以及需要注意的事项做一个简单的总结。 二、DQL语句 1、基础查询 注意&#xff1a; 基础查询的语法是&#xff1a;SELE…

FineBi导出Excel后台版实现

就是不通过浏览器,在后台运行的导出 参考文档在:仪表板查看接口- FineBI帮助文档 FineBI帮助文档 我这里是将这个帮助文档中导出的excel文件写到服务器某个地方后,对excel进行其他操作后再下载。由于原有接口耦合了HttpServletRequest req, HttpServletResponse res对象,…

海外短剧APP/H5 系统开发搭建

目前已经有多个客户用我们搭建的海外短剧系统&#xff0c;在使用中已经取得了较高的收益。目前一个客户打算做日本区域的海外短剧项目&#xff0c;需求已经理清楚了&#xff0c;系统正在搭建中

[MYSQL] 部门工资最高的员工

表&#xff1a; Employee ----------------------- | 列名 | 类型 | ----------------------- | id | int | | name | varchar | | salary | int | | departmentId | int | ----------------------- 在 SQL 中&#xff0c;id…

Deconfounding Duration Bias in Watch-time Prediction for Video Recommendation

Abstract 观看时间预测仍然是通过视频推荐加强用户粘性的关键因素。然而&#xff0c;观看时间的预测不仅取决于用户与视频的匹配&#xff0c;而且经常被视频本身的持续时间所误导。为了提高观看时间&#xff0c;推荐总是偏向于长时间的视频。在这种不平衡的数据上训练的模型面…

[机器学习]GPT LoRA 大模型微调,生成猫耳娘

往期热门专栏回顾 专栏描述Java项目实战介绍Java组件安装、使用&#xff1b;手写框架等Aws服务器实战Aws Linux服务器上操作nginx、git、JDK、VueJava微服务实战Java 微服务实战&#xff0c;Spring Cloud Netflix套件、Spring Cloud Alibaba套件、Seata、gateway、shadingjdbc…

牛客网刷题 | BC104 翻转金字塔图案

目前主要分为三个专栏&#xff0c;后续还会添加&#xff1a; 专栏如下&#xff1a; C语言刷题解析 C语言系列文章 我的成长经历 感谢阅读&#xff01; 初来乍到&#xff0c;如有错误请指出&#xff0c;感谢&#xff01; 描述 KiKi学习了循环&am…

万字详解 MySQL MGR 高可用集群搭建

文章目录 1、MGR 前置介绍1.1、什么是 MGR1.2、MGR 优点1.3、MGR 缺点1.4、MGR 适用场景 2、MySQL MGR 搭建流程2.1、环境准备2.2、搭建流程2.2.1、配置系统环境2.2.2、安装 MySQL2.2.3、配置启动 MySQL2.2.4、修改密码、设置主从同步2.2.5、安装 MGR 插件 3、MySQL MGR 故障转…

智慧排水监测系统方案

智慧排水监测系统方案 智慧排水监测系统作为现代城市基础设施管理的重要组成部分&#xff0c;旨在通过先进的信息技术手段&#xff0c;实现对城市排水系统的全面、实时、高效的远程监控与管理。该系统整合了物联网技术、大数据分析、云计算平台与人工智能算法&#xff0c;不仅…

告别暗黄,唤醒肌肤

&#x1f3ad; 想象一下&#xff0c;你的皮肤是舞台上的主角&#xff0c;但最近它似乎有些“疲惫”和“黯淡”&#xff0c;仿佛失去了往日的星光✨。别急&#xff0c;今天&#xff0c;我要为你揭秘一个能让肌肤重新焕发光彩的“魔法”——胶原蛋白&#xff01;&#x1f3a9; &a…

docker查看容器目录挂载

查看命令 docker inspect --format{{ json .Mounts }} <container_id_or_name> | jq 示例 docker inspect --format{{ json .Mounts }} af656ae540af | jq输出

FreeRTOS笔记 - 二(正点原子)

一&#xff0c;任务创建和删除 具体的参数&#xff08;看视频&#xff09; 1&#xff0c;动态和静态创建的区别 动态: 任务的任务控制块以及任务的栈空间所需的内存&#xff0c;均由FreeRTOS从 FreeRTOS 管理的堆中分配。 静态: 任务的任务控制块以及任务的栈空间所需的内存&am…

vscode设置编辑器文件自动保存

步骤 1.打开vscode的设置 2.在搜索栏输入关键字“保存”&#xff1b; 在 Files: Auto Save 设置项&#xff0c;选择自动保存的模式

java使用资源过高排查

在生产环境中有可能出现某java程序使用资源特别严重&#xff0c;这就需要找到该java进程&#xff0c;然后通过进程去找到是哪个线程的问题&#xff0c;这里我们就是用pidstat工具来排查一下 安装pidstat工具 yum -y install sysstat 查看java服务的pid jps 通过pid查看线…

C# WinForm —— 25 ProgressBar 介绍与使用

1. 简介 用于显示某个操作的进度 2. 常用属性 属性解释(Name)控件ID&#xff0c;在代码里引用的时候会用到,一般以 pbar 开头ContextMenuStrip右键菜单Enabled控件是否可用ForeColor用于显示进度的颜色MarqueeAnimationSpeed进度条动画更新的速度&#xff0c;以毫秒为单位M…

CSAPP Lab08——Proxy Lab完成思路

蓝色的思念 突然演变成了阳光的夏天 空气中的温暖不会很遥远 ——被风吹过的夏天 完整代码见&#xff1a;CSAPP/proxylab-handout at main SnowLegend-star/CSAPP (github.com) Q&#xff1a;计算机网络中port的作用是什么&#xff1f; A&#xff1a;在计算机网络中&#xff…