Transformer self-attention源码及原理理解

 自注意力计算公式:

  • 在公式(1)中Q(query)是输入一个序列中的一个token,K(key)代表序列中所有token的特征。
  •  QK^{T}可以得到当前token与序列中其他token的相关性。
  • 在论文原文中d_{model}=512,表示每个token用512维特征表示(序列符号的embedding长度)。 d_{k}=d_{model}\div h=64,表示每个头的大小为64。

自注意力机制的pytorch实现:

def attention(query, key, value, mask=None, dropout=None):"Compute 'Scaled Dot Product Attention'"d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attn

多头注意力机制的pytorch实现如下:

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]#这段代码首先使用zip函数,将self.linears和(query, key, value)这两个列表打包成一个元组列表,其中每个元组包含一个线性层对象和一个输入张量#对遍历的每一个Linear层,对query key value分别计算,结果放在query key value中输出# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

W_{i}^{Q} \quad W_{i}^{Q} \quad W_{i}^{V}对应Figure2中的三个Linear层的权重,通过训练可得,它们的形状是(需要从代码理解),用来将原始的Q K V投影到下一层做Dot-Production attention计算。

首先Q K V怎么来的?

 和输入序列的token有关

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/753922.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言中大小写字母如何转化

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

Linux TCP参数——tcp_adv_win_scale

文章目录 tcp_adv_win_scaleip-sysctl.txt解释buffering overhead内核缓存和应用缓存示例计算深入理解从2到1(tcp_adv_win_scale的值)总结 tcp_adv_win_scale adv-advise;win-window; 用于指示TCP中接收缓存比例的值。 static inline int tcp_win_from_space(int …

SpringSecurity(SpringBoot2.X版本实现)

资料来源于 SpringSecurity框架教程-Spring SecurityJWT实现项目级前端分离认证授权 侵权删 目录 介绍 快速开始 认证 认证流程 登录校验流程 SpringSecurity完整流程 认证流程详解 代码实现 准备工作 mysql mybatis-plus redis 统一返回类 核心代码 密码加密存…

Mesh网格obj文件构成解析

众所周知,Mesh网格是三维重建的常用手法,通过顶点-三角面的形式来完成对三维物体的表达。其中,最常见的Mesh网格文件格式就是obj格式。看起来复杂的三维形状其实在数值表示上是很简单的,大家跟我一起来做个小实验就好:…

echarts散点图自定义tooltip,鼠标放上去展示多行数据

先放效果图 如图,就是鼠标悬停在散点上(这里的散点我替换成了图片,具体做法参考这篇文章:echarts散点图的散点用自定义图片替代-CSDN博客)时,可以展示多行数据。之前查找资料的时候,很多用字符串…

练习unittest+Fixture实现

练习01 创建⼀个⽬录 case, 作⽤就是⽤来存放⽤例脚本,在这个⽬录中创建 5 个⽤例代码⽂件 , test_case1.py使⽤ TestLoader 去执⾏⽤例 将来的代码 ⽤例都是单独的⽬录中存放的 test_项⽬_模块_功能.py test_case1.py # 1. 导包 unittest import unittest # 2. 定义测试类, 只…

面试经典150题(114-118)

leetcode 150道题 计划花两个月时候刷完之未完成后转,今天完成了5道(114-118)150 gap 了一周,以后就不记录时间了。。 114.(70. 爬楼梯) 题目描述: 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不…

旅行社旅游线路预定管理系统asp.net

旅行社旅游线路预定管理系统 首页 国内游 境外游 旅游景点 新闻资讯 酒店信息―留言板 后台管理 后台管理导航菜单系统管理修改密码留言管理注册会员管理基础数据设置国别设置有份设地区设置 旅行社管理友情链接管理添加友情链接友情链接管理新闻资讯管理添加新闻资讯新闻资讯管…

LayerNormalization 和 RMSNormalization的计算方法和区别

目录 问题来源 Layer Normalization 与 RMSNormalization 的详细计算方法 Layer Normalization(层归一化) RMSNormalization(均方根归一化) Layer Normalization与RMSNormalization的异同 Layer Normalization RMSNormaliza…

24 OpenCV直方图反向投影

文章目录 参考反向投影作用calceackProject 反向投影mixchannels 通道图像分割示例 参考 直方图反向投影 反向投影 反向投影是反映直方图模型在目标图像中的分布情况简单点说就是用直方图模型去目标图像中寻找是否有相似的对象。通常用HSV色彩空间的HS两个通道直方图模型 作用…

基于时空上下文(STC)的运动目标跟踪算法,Matlab实现

博主简介: 专注、专一于Matlab图像处理学习、交流,matlab图像代码代做/项目合作可以联系(QQ:3249726188) 个人主页:Matlab_ImagePro-CSDN博客 原则:代码均由本人编写完成,非中介,提供…

Midjourney 和 Dall-E 的优劣势比较

Midjourney 和 Dall-E 的优劣势比较 Midjourney 和 Dall-E 都是强大的 AI 绘画工具,可以根据文本描述生成图像。 它们都使用深度学习模型来理解文本并将其转换为图像。 但是,它们在功能、可用性和成本方面存在一些差异。 Midjourney 优势: 可以生成更…

js判断对象是否有某个属性

前端判断后端接口是否返回某个字段的时候 <script>var obj { name: "John", age: 30 };console.log(obj.hasOwnProperty("name")); // 输出 trueconsole.log(obj.hasOwnProperty("email")); // 输出 falselet obj11 { name: "Joh…

9. 编程常见错误归类

编程常见错误归类 9.1 编译型错误9.2 链接型错误9.3 运行时错误 9.1 编译型错误 编译型错误⼀般都是语法错误&#xff0c;这类错误⼀般看错误信息就能找到⼀些蛛丝马迹的&#xff0c;双击错误信息也能初步的跳转到代码错误的地方或者附近。编译错误&#xff0c;随着语言的熟练…

力扣栈题:删除最外层括号

char* removeOuterParentheses(char* s) {int stack 0;int num0;for(int i0;i<strlen(s);i){if(s[i](){stack;if(stack>1){s[num]s[i];}}else{--stack;if(stack>0){s[num]s[i];}}}s[num]\0;return s; } 思路&#xff1a;迭代加栈&#xff0c;如果不是第一个的左括号则…

苍穹外卖-day10:Spring Task、订单状态定时处理、来单提醒(WebSocket的应用)、客户催单(WebSocket的应用)

苍穹外卖-day10 课程内容 Spring Task订单状态定时处理WebSocket来单提醒客户催单 功能实现&#xff1a;订单状态定时处理、来单提醒和客户催单 订单状态定时处理&#xff1a; 来单提醒&#xff1a; 客户催单&#xff1a; 1. Spring Task 1.1 介绍 Spring Task 是Spring框…

win32汇编弹出对话框

之前书上有一个win32 asm 的odbc例子&#xff0c;它有一个窗体&#xff0c;可以执行sql&#xff1b;下面看一下弹出一个录入数据的对话框&#xff1b; 之前它在.code段包含2个单独的asm文件&#xff0c;增加第三个&#xff0c;增加的这个里面是弹出对话框的窗口过程&#xff0…

哪些AI知识库比较好用?企业高管必看!

在科技进步的时代&#xff0c;工作效率和知识管理是企业面临的两大挑战。而AI知识库&#xff0c;正是解决这个问题的利剑。接下来&#xff0c;我将与你分享三款好用的AI知识库平台&#xff0c;感兴趣就往下看吧。 首先&#xff0c;我们不得不提的是Helplook。这是一个根据人工智…

使用Python进行数据库连接与操作SQLite和MySQL【第144篇—SQLite和MySQL】

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 使用Python进行数据库连接与操作&#xff1a;SQLite和MySQL 在现代应用程序开发中&#xf…

spring-boot-starter-thymeleaf加载外部html文件

在Spring MVC中&#xff0c;我们可以使用Thymeleaf模板引擎来实现加载外部HTML文件。 1.Thymeleaf介绍 Thymeleaf是一种现代化的服务器端Java模板引擎&#xff0c;用于构建漂亮、可维护且易于测试的动态Web应用程序。它适用于与Spring框架集成&#xff0c;并且可以与Spring M…