辅导男朋友转算法岗的第2天|self Attention与kv cache

文章目录

    • 公式
    • KV Cache
    • MHA、MQA、GQA
  • 面试题

公式

$ \text{Output} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \times V$ 复杂度是O( n 2 n^2 n2)

KV Cache

推理阶段最常用的缓存机制,用空间换时间。

原理:

在进行自回归解码的时候,新生成的token会加入序列,一起作为下一次解码的输入。

由于单向注意力的存在,新加入的token并不会影响前面序列的计算,因此可以把已经计算过的每层的kv值保存起来,这样就节省了和本次生成无关的计算量。

通过把kv值存储在速度远快于显存的L2缓存中,可以大大减少kv值的保存和读取,这样就极大加快了模型推理的速度。

分别做一个k cache和一个v cache,把之前计算的k和v存起来

以v cache为例:

在这里插入图片描述

存在的问题:存储碎片化

解决方法:page attention(封装在vllm里了)

MHA、MQA、GQA

Multi-Head Attention、Multi-Query Attention、Group-Query Attention

目的:优化KV Cache所需空间大小

原理是共享k和v,但是使用MQA效果会差一些,于是又出现了GQA这种折中的办法

在这里插入图片描述

面试题

为什么除以 d k \sqrt{d_k} dk

压缩softmax输入值,以免输入值过大,进入了softmax的饱和区,导致梯度值太小而难以训练。

Multihead的好处

1、每个head捕获不同的信息,多个头能够分别关注到不同的特征,增强了表达能力。多个头中,会有部分头能够学习到更高级的特征,并减少注意力权重对角线值过大的情况。

比如部分头关注语法信息,部分头关注知识内容,部分头关注近距离文本,部分头关注远距离文本,这样减少信息缺失,提升模型容量。

2、类似集成学习,多个模型做决策,降低误差

decoder-only模型在训练阶段和推理阶段的input有什么不同?

  • 训练阶段:模型一次性处理整个输入序列,输入是完整的序列,掩码矩阵是固定的上三角矩阵。
  • 推理阶段:模型逐步生成序列,输入是一个初始序列,然后逐步添加生成的 token。掩码矩阵需要动态调整,以适应不断增加的序列长度,并考虑缓存机制。

手撕必背-多头注意力

逐头计算

import torch.nn as nn
class MultiHeadAttentionScores(nn.Module):def __init__(self, hidden_size, num_attention_heads, attention_head_size):super(MultiHeadAttentionScores, self).__init__()self.num_attention_heads = num_attention_heads # 8,16, 32, 64# Create a query, key, and value projection layer# for each attention head.  W^Q, W^K, W^Vself.query_layers = nn.ModuleList([nn.Linear(hidden_size, attention_head_size) for _ in range(num_attention_heads)])self.key_layers = nn.ModuleList([nn.Linear(hidden_size, attention_head_size) for _ in range(num_attention_heads)])self.value_layers = nn.ModuleList([nn.Linear(hidden_size, attention_head_size) for _ in range(num_attention_heads)])def forward(self, hidden_states):# Create a list to store the outputs of each attention headall_attention_outputs = []for i in range(self.num_attention_heads): # i.e. 8query_vectors = self.query_layers[i](hidden_states)key_vectors = self.key_layers[i](hidden_states)value_vectors = self.value_layers[i](hidden_states)# softmax(Q&K^T)*Vattention_scores = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))# attention_scores combined with softmax--> normalized_attention_scoreattention_outputs = torch.matmul(attention_scores, value_vectors)all_attention_outputs.append(attention_outputs)return all_attention_outputs

矩阵运算

import torch
import torch.nn as nnclass MultiHeadAttentionScores(nn.Module):def __init__(self, hidden_size, num_attention_heads, attention_head_size):super(MultiHeadAttentionScores, self).__init__()self.num_attention_heads = num_attention_headsself.attention_head_size = attention_head_sizeself.hidden_size = hidden_sizeself.query = nn.Linear(hidden_size, num_attention_heads * attention_head_size)self.key = nn.Linear(hidden_size, num_attention_heads * attention_head_size)self.value = nn.Linear(hidden_size, num_attention_heads * attention_head_size)def forward(self, hidden_states):batch_size = hidden_states.size(0)query_layer = self.query(hidden_states)key_layer = self.key(hidden_states)value_layer = self.value(hidden_states)query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))attention_probs = nn.Softmax(dim=-1)(attention_scores)attention_outputs = torch.matmul(attention_probs, value_layer)attention_outputs = attention_outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.num_attention_heads * self.attention_head_size)return attention_outputs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/844982.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#,JavaScript实现浮点数格式化自动保留合适的小数位数

目标 由于浮点数有漂移问题,转成字符串时 3.6 有可能得到 3.6000000000001,总之很长的一串,通常需要截取,但按照固定长度截取不一定能使用各种情况,如果能根据数值大小保留有效位数就好了。 C#实现 我们可以在基础库里…

【错题集-编程题】过桥(BFS)

牛客对应题目链接&#xff1a;过桥 (nowcoder.com) 一、分析题目 类似层序遍历的思想。 二、代码 //值得学习的代码 #include <iostream>using namespace std;const int N 2010;int n; int arr[N];int bfs() {int left 1, right 1;int ret 0;while(left < right)…

JDK环境配置、安装

DK环境配置&#xff08;备注&#xff1a;分32位与64位JDK&#xff0c;32位电脑只能按照32位JDK&#xff0c;64位电脑兼容32、64位JDK&#xff09; 一、检查自己电脑是否安装过JDK 1.在电脑屏幕左下角&#xff0c;输入命令提示符CMD&#xff0c;打开命令提示符应用 2.在打开界…

vivo X200系列即将发布:首发将搭载天玑最新芯片9400

随着智能手机技术的不断进步&#xff0c;vivo作为全球知名的智能手机制造商&#xff0c;一直在为用户带来创新和惊喜。最近&#xff0c;vivo的粉丝们有理由感到兴奋&#xff0c;因为最新的消息称&#xff0c;vivo X200系列即将发布&#xff0c;并且将首发搭载天玑最新的9400处理…

如何实现一个AI聊天功能

最近公司的网站上需要对接一个AI聊天功能&#xff0c;领导把这个任务分给了我&#xff0c;从最初的调研&#xff0c;学习&#xff0c;中间也踩过一些坑&#xff0c;碰到过问题&#xff0c;但最后对接成功&#xff0c;还是挺有成就感的&#xff0c;今天把这个历程和项目整理一下…

中草药识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型

一、介绍 中草药识别系统。本系统基于TensorFlow搭建卷积神经网络算法&#xff08;ResNet50算法&#xff09;通过对10中常见的中草药图片数据集&#xff08;‘丹参’, ‘五味子’, ‘山茱萸’, ‘柴胡’, ‘桔梗’, ‘牡丹皮’, ‘连翘’, ‘金银花’, ‘黄姜’, ‘黄芩’&…

5.26机器人基础-DH参数 正解

1.建立DH坐标系 1.确定Zi轴&#xff08;关节轴&#xff09; 2.确定基础坐标系 3.确定Xi方向&#xff08;垂直于zi和zi1的平面&#xff09; 4.完全确定各个坐标系 例子&#xff1a; 坐标系的布局是由个人决定的&#xff0c;可以有不同的选择 标准坐标系布局&#xff1a; …

HTML静态网页成品作业(HTML+CSS)——企业装饰公司介绍网页(4个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有4个页面。 二、作品演示 三、代…

笔记:Windows故障转移集群下的oracle打补丁

以下方法比较暴力&#xff0c;请谨慎使用 1&#xff0c;关闭并禁用故障转移集群的服务&#xff0c;如下 2&#xff0c;关闭故障转移集群中资源的自启动 3&#xff0c;重启服务器 4&#xff0c;手动关闭服务 net stop msdtc net stop winmgmt 5&#xff0c;分别对所有节点打…

【Qt秘籍】[001]-从入门到成神-前言

一、Qt是什么&#xff1f;[概念] Qt是一个跨平台的应用程序开发框架&#xff0c;简单来说&#xff0c;它是一套工具和库&#xff0c;帮助软件开发者编写可以在多种操作系统上运行的图形用户界面&#xff08;GUI&#xff09;应用程序。比如&#xff0c;你用Qt写了一个软件&#…

成绩发布小程序哪个好用?

大家好&#xff0c;今天我要来跟大家分享一个超级实用的小秘密——易查分小程序&#xff01;作为老师&#xff0c;你是不是还在为发放成绩而头疼&#xff1f;是不是还在为通知家长而烦恼&#xff1f;别急&#xff0c;易查分小程序来帮你啦&#xff01; 易查分简直是老师们的贴心…

C++的第一道门坎:类与对象(三)

目录 一.再谈构造函数 1.1构造函数体赋值 1.2初始化列表 1.3explicit关键字 二.static成员 2.1概念 ​编辑 2.2特性 三.友元 3.1友元函数 3.2友元类 4.内部类 一.再谈构造函数 1.1构造函数体赋值 class Date { public:Date(int year,int month,int day){_year ye…

内核编译版本号带有+问题

编译内核4.19.163以后 make ARCHarm64 modules_install INSTALL_MOD_PATH../aarch64_modules/ 发现 DEPMOD 4.19.246 修改 scripts/setlocalversion 把那个号给它干掉 解决问题

订单共享模式:开启你的终身财富之旅

在当今这个信息爆炸的时代&#xff0c;每个人都在寻找着属于自己的财富增长之道。而“二人订单共享结束制”作为一种全新的商业模式&#xff0c;正以其独特的魅力吸引着越来越多的目光。只需499元的终身消费&#xff0c;你便能成为平台的会员&#xff0c;开启一段与众不同的赚钱…

范闲通过MD5哈希算法破解庆帝与神庙信件的精彩解析

价值万元免费资料领取欢迎关注 公众号 数据分析螺丝钉 剧情背景 在《庆余年2》中&#xff0c;范闲与庆帝和神庙之间的权谋斗争愈演愈烈。但是其实早在第一季&#xff0c;范闲宫中在找打开箱子钥匙的时候就发现了一封秘信&#xff0c;这封信件可能隐藏着揭露叶轻眉的一些关键信…

基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

第一步&#xff1a;准备数据 5种中草药数据&#xff1a;self.class_indict ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"] &#xff0c;总共有900张图片&#xff0c;每个文件夹单独放一种数据 第二步&a…

Docker搭建FRP内网穿透服务器

使用Docker搭建一个frp内网穿透 在现代网络环境中&#xff0c;由于防火墙和NAT等原因&#xff0c;内网设备无法直接被外网访问。FRP (Fast Reverse Proxy) 是一款非常流行的内网穿透工具&#xff0c;它能够帮助我们将内网服务暴露给外网。本文将介绍如何在Linux服务器上使用Do…

压测工具Jmeter的使用

一、安装 下载地址&#xff1a; 国外地址&#xff1a;jmeter.apache.org&#xff08;下载会很慢&#xff0c;建议使用国内地址&#xff09; 国内地址&#xff1a;apache-jmeter-binaries安装包下载_开源镜像站-阿里云 下载好进入bin文件下&#xff0c;双击jmeter.bat 打开…

哈希传递(PTH)

使用Mimikatz进行PTH Pass The Hash 哈希传递攻击简称 PTH&#xff0c;该方法通过找到与账户相关的密码散列&#xff08;NTLLHash&#xff09;来进 行攻击。由于在Windows系统中&#xff0c;通常会使用NTLM Hash对访问资源的用户进行身份认证&#xff0c;所以该攻 击可以在不需…

算法学习笔记(7.2)-贪心算法(最大容量问题)

目录 ##问题描述 ##问题示例 ##释 ##贪心策略的确定 ##代码示例 ##正确性验证 ##问题描述 输入一个数组 ℎ&#x1d461; &#xff0c;其中的每个元素代表一个垂直隔板的高度。数组中的任意两个隔板&#xff0c;以及它们之间的空间可以组成一个容器。 容器的容量等于高度和宽…