Transformer self-attention源码及原理理解

 自注意力计算公式:

  • 在公式(1)中Q(query)是输入一个序列中的一个token,K(key)代表序列中所有token的特征。
  •  QK^{T}可以得到当前token与序列中其他token的相关性。
  • 在论文原文中d_{model}=512,表示每个token用512维特征表示(序列符号的embedding长度)。 d_{k}=d_{model}\div h=64,表示每个头的大小为64。

自注意力机制的pytorch实现:

def attention(query, key, value, mask=None, dropout=None):"Compute 'Scaled Dot Product Attention'"d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attn

多头注意力机制的pytorch实现如下:

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]#这段代码首先使用zip函数,将self.linears和(query, key, value)这两个列表打包成一个元组列表,其中每个元组包含一个线性层对象和一个输入张量#对遍历的每一个Linear层,对query key value分别计算,结果放在query key value中输出# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

W_{i}^{Q} \quad W_{i}^{Q} \quad W_{i}^{V}对应Figure2中的三个Linear层的权重,通过训练可得,它们的形状是(需要从代码理解),用来将原始的Q K V投影到下一层做Dot-Production attention计算。

首先Q K V怎么来的?

 和输入序列的token有关

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/753922.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言中大小写字母如何转化

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

Linux TCP参数——tcp_adv_win_scale

文章目录 tcp_adv_win_scaleip-sysctl.txt解释buffering overhead内核缓存和应用缓存示例计算深入理解从2到1(tcp_adv_win_scale的值)总结 tcp_adv_win_scale adv-advise;win-window; 用于指示TCP中接收缓存比例的值。 static inline int tcp_win_from_space(int …

【字符串算法题】541. 反转字符串 II

题目链接 思考 把字符串以2k的步长分成count_reverse个子区间。考虑最后一个子区间的字符数量:1)如果大于等于k,则它要和前面的子区间一样,要对区间内的前k个字符进行反转;2)如果小于k,则它的…

SpringSecurity(SpringBoot2.X版本实现)

资料来源于 SpringSecurity框架教程-Spring SecurityJWT实现项目级前端分离认证授权 侵权删 目录 介绍 快速开始 认证 认证流程 登录校验流程 SpringSecurity完整流程 认证流程详解 代码实现 准备工作 mysql mybatis-plus redis 统一返回类 核心代码 密码加密存…

Mesh网格obj文件构成解析

众所周知,Mesh网格是三维重建的常用手法,通过顶点-三角面的形式来完成对三维物体的表达。其中,最常见的Mesh网格文件格式就是obj格式。看起来复杂的三维形状其实在数值表示上是很简单的,大家跟我一起来做个小实验就好:…

echarts散点图自定义tooltip,鼠标放上去展示多行数据

先放效果图 如图,就是鼠标悬停在散点上(这里的散点我替换成了图片,具体做法参考这篇文章:echarts散点图的散点用自定义图片替代-CSDN博客)时,可以展示多行数据。之前查找资料的时候,很多用字符串…

【兆易创新GD32H759I-EVAL开发板】 LUT功能

颜色查找表(LUT, Lookup Table)模式在图像处理和显示中是一种有效的数据表示和压缩方式。它通过将图像中的颜色映射到一个预定义的颜色表来实现,这样每个像素不是直接存储完整的颜色值,而是存储一个指向颜色表中特定颜色的索引。这…

练习unittest+Fixture实现

练习01 创建⼀个⽬录 case, 作⽤就是⽤来存放⽤例脚本,在这个⽬录中创建 5 个⽤例代码⽂件 , test_case1.py使⽤ TestLoader 去执⾏⽤例 将来的代码 ⽤例都是单独的⽬录中存放的 test_项⽬_模块_功能.py test_case1.py # 1. 导包 unittest import unittest # 2. 定义测试类, 只…

面试经典150题(114-118)

leetcode 150道题 计划花两个月时候刷完之未完成后转,今天完成了5道(114-118)150 gap 了一周,以后就不记录时间了。。 114.(70. 爬楼梯) 题目描述: 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不…

24计算机考研调剂 | 集美大学(11408)

[2024考研调剂]集美大学计算机工程学院智慧城市创新实验室招收学硕 考研调剂招生信息 学校:集美大学 专业:工学->软件工程 年级:2024 招生人数:8 招生状态:正在招生中 联系方式: ********* (为保护个人隐私,联系方式仅限APP查看) 补充内容 实验…

旅行社旅游线路预定管理系统asp.net

旅行社旅游线路预定管理系统 首页 国内游 境外游 旅游景点 新闻资讯 酒店信息―留言板 后台管理 后台管理导航菜单系统管理修改密码留言管理注册会员管理基础数据设置国别设置有份设地区设置 旅行社管理友情链接管理添加友情链接友情链接管理新闻资讯管理添加新闻资讯新闻资讯管…

LayerNormalization 和 RMSNormalization的计算方法和区别

目录 问题来源 Layer Normalization 与 RMSNormalization 的详细计算方法 Layer Normalization(层归一化) RMSNormalization(均方根归一化) Layer Normalization与RMSNormalization的异同 Layer Normalization RMSNormaliza…

24 OpenCV直方图反向投影

文章目录 参考反向投影作用calceackProject 反向投影mixchannels 通道图像分割示例 参考 直方图反向投影 反向投影 反向投影是反映直方图模型在目标图像中的分布情况简单点说就是用直方图模型去目标图像中寻找是否有相似的对象。通常用HSV色彩空间的HS两个通道直方图模型 作用…

探索未来科技:量子计算的前沿与挑战

随着信息技术的飞速发展,传统的计算模式已经难以满足日益增长的数据处理需求。在这个背景下,量子计算作为一种全新的计算模式,逐渐进入人们的视野。本文将探讨量子计算的前沿技术以及在软件开发领域所面临的挑战。 量子计算的前沿技术 量子计…

基于时空上下文(STC)的运动目标跟踪算法,Matlab实现

博主简介: 专注、专一于Matlab图像处理学习、交流,matlab图像代码代做/项目合作可以联系(QQ:3249726188) 个人主页:Matlab_ImagePro-CSDN博客 原则:代码均由本人编写完成,非中介,提供…

Android Framework开发之Linux +Vim命令

一、linux常用命令 在Android源码开发中,Linux命令的运用是至关重要的。这些命令不仅帮助开发者有效管理文件、目录和系统资源,还能在源码编译、调试和排错过程中发挥关键作用。以下是对Android源码开发中常用Linux命令的更详细介绍: 当然可…

Midjourney 和 Dall-E 的优劣势比较

Midjourney 和 Dall-E 的优劣势比较 Midjourney 和 Dall-E 都是强大的 AI 绘画工具,可以根据文本描述生成图像。 它们都使用深度学习模型来理解文本并将其转换为图像。 但是,它们在功能、可用性和成本方面存在一些差异。 Midjourney 优势: 可以生成更…

js判断对象是否有某个属性

前端判断后端接口是否返回某个字段的时候 <script>var obj { name: "John", age: 30 };console.log(obj.hasOwnProperty("name")); // 输出 trueconsole.log(obj.hasOwnProperty("email")); // 输出 falselet obj11 { name: "Joh…

9. 编程常见错误归类

编程常见错误归类 9.1 编译型错误9.2 链接型错误9.3 运行时错误 9.1 编译型错误 编译型错误⼀般都是语法错误&#xff0c;这类错误⼀般看错误信息就能找到⼀些蛛丝马迹的&#xff0c;双击错误信息也能初步的跳转到代码错误的地方或者附近。编译错误&#xff0c;随着语言的熟练…

本地mysql测试成功后上传至云服务器出现了这么多问题?

本地MySQL数据库迁移至云服务器的过程中可能出现多种问题,以下是常见的一些原因及其解决思路: 权限问题: 账户权限:本地MySQL数据库的用户权限设置可能与云服务器上的MySQL实例不同,比如未授权远程连接或赋予了错误的权限。你需要确认云服务器MySQL数据库的用户是否有从远…