transformer_多头注意力机制代码笔记

transformer_多头注意力机制代码笔记

以GPT-2中多头注意力机制代码为例

class CausalSelfAttention(nn.Module):"""因果掩码+多头自注意力机制A vanilla multi-head masked self-attention layer with a projection at the end.It is possible to use torch.nn.MultiheadAttention here but I am including anexplicit implementation here to show that there is nothing too scary here."""def __init__(self, config):super().__init__()assert config.n_embd % config.n_head == 0# key, query, value projections for all heads, but in a batchself.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)  # Q/K/V的线性映射# output projectionself.c_proj = nn.Linear(config.n_embd, config.n_embd)  # 输出线性映射# regularizationself.attn_dropout = nn.Dropout(config.attn_pdrop)  # dropout正则化self.resid_dropout = nn.Dropout(config.resid_pdrop)# causal mask to ensure that attention is only applied to the left in the input sequenceself.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))  # 构建casual mask矩阵self.n_head = config.n_headself.n_embd = config.n_embddef forward(self, x):B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)# calculate query, key, values for all heads in batch and move head forward to be the batch dimq, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)  # 同时计算QKV# 多头切分, nh*hs=C# 一定要先view (B, T, nh, hs), 再transpose,不能直接view到(B, nh, T, hs)  k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)  q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)# 除以根号下dk保持方差稳定att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))# 因果掩码att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))att = F.softmax(att, dim=-1)att = self.attn_dropout(att)y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)# contiguous:张量在底层存储时必须连续才可以使用viewy = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side# output projectiony = self.resid_dropout(self.c_proj(y))return y

以下为对多头注意力机制代码做分步笔记
上述代码中初始化中定义了构建多头注意力机制代码的组件(结构),在forward的方法中将使用初始化中的组件构建多头注意力机制。
从forward方法开始阅读,当使用到初始化方法中的代码时再进行阅读

# x是输入的数据
# B是批大小
# T是序列长度
# C是embedding的维度
B, T, C = x.size()
# 通过self.c_atten初始化q,k,v
q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)# 初始化q,k,v
# 验证数据维度与多头的数量是否一致
assert config.n_embd % config.n_head == 0
# 因为q,k,v是相同的大小的不同矩阵,通过线性映射获得初始化的q,k,v
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
# 多头切分, nh*hs=C,其中nh为头的个数,hs为头的维度
# 一定要先view (B, T, nh, hs), 再transpose,不能直接view到(B, nh, T, hs)  
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)  
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)# 其中self.n_head是配置文件中预先设置好的
self.n_head = config.n_head
# transpose是为了转换不同维度上的数据
# 此处是attention的公式
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
# 除以根号下dk保持方差稳定
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# 因果掩码
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
# contiguous:张量在底层存储时必须连续才可以使用view
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side# 前面讲C拆解为两个矩阵相乘,现在讲得到的结果返回最初的数据形状
# contiguous()是深拷贝
# output projection
y = self.resid_dropout(self.c_proj(y))# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd)  # 输出线性映射
# regularization
self.attn_dropout = nn.Dropout(config.attn_pdrop)  # dropout正则化
self.resid_dropout = nn.Dropout(config.resid_pdrop)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/660445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】const修饰指针的不同作用

目录 const修饰变量 const修饰指针变量 ①不用const修饰 ②const放在*的左边 ③const放在*的右边 ④*的左右两边都有const 结论 const修饰变量 变量是可以修改的,如果把变量的地址交给⼀个指针变量,通过指针变量的也可以修改这个变量。 但…

电脑文件打不开是什么原因?常见原因有这9点

在日常生活和工作中,我们经常会使用电脑来处理文件。然而,有时候我们会遇到电脑文件打不开的情况,这给我们的工作和生活带来了很大的不便。本文将为大家介绍电脑文件打不开的原因,帮助大家更好地应对这一问题。 原因1、文件格式问…

交易策略开发:如何揣摩投资心理,研究交易策略

文章目录 揣摩其他投资者的心理,首先要知道他们学习了什么投资知识。永远记住策略一定是弱于机制的。每种交易技术是如何做交易的,各位可以对号入座**马丁网格类均线,MACD等指标类价格行为类缠论类对冲套利类基本面类订单流资金流秘籍类 揣摩…

论文解读:DeepBDC小样本图像分类

Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification 摘要 由于每个新任务只给出很少的训练样例,所以few -shot分类是一个具有挑战性的问题。解决这一挑战的有效研究路线之一是专注于学习由查询图像和某些类别的少数支持…

shell脚本自动备份数据库表

今日目标:shell脚本自动备份数据库中的表并记录执行日志和mysql输出日志 编写思路: (1)shell脚本运行mysql命令 (2)脚本输出记录到日志中 (3)定时任务自动执行shell脚本 1、she…

【Tomcat与网络9】提高Tomcat启动速度的八大措施

本文我们来看一下如何对Tomcat进行调优,我们对于Tomcat的调优主要集中在三个方面:提高启动速度、提高系统稳定性和提高并发能力,后两者很多时候是相辅相成的,我们放在一起看。 Tomcat现在一般都嵌入在SpringBoot里,因…

Linux 驱动开发基础知识——总线设备驱动模型(八)

个人名片: 🦁作者简介:学生 🐯个人主页:妄北y 🐧个人QQ:2061314755 🐻个人邮箱:2061314755qq.com 🦉个人WeChat:Vir2021GKBS 🐼本文由…

动网格-尺寸函数耦合运动(五)

尺寸函数 **尺寸函数(Size Function)**通常和局部体网格重构结合使用,尺寸函数用于控制重构过程中的网格分布。简单地说,尺寸函数的功能就是在运动边界处约束网格,使其维持在一个较小的尺度,在远离运动边界处,逐步将其…

Windows存储空间不足局域网文件共享 Dism备份系统空间不足

问题情景 在日常使用中难免遇到Windows的空间不足的情况,常用办法是清理垃圾释放空间,部分场景例如我们需要使用Dism备份完整系统,所以需要非常大的存储空间不够,如果空间不够什么才是最有效的方案呢? 我们假设身边没有…

如何使用docker部署Swagger Editor并实现无公网ip远程协作编辑文档

文章目录 Swagger Editor本地接口文档公网远程访问1. 部署Swagger Editor2. Linux安装Cpolar3. 配置Swagger Editor公网地址4. 远程访问Swagger Editor5. 固定Swagger Editor公网地址 Swagger Editor本地接口文档公网远程访问 Swagger Editor是一个用于编写OpenAPI规范的开源编…

【方案】TSINGSEE青犀智能分析网关V4+EasyCVR智慧服务区一体化监控平台

随着年关将近,春运大潮已然开启,届时又伴随着大雨暴雪天气,高速路况的新闻层出不穷。由于长期驾车且高速拥堵严重,不少人就聚集在服务区休息,导致服务区流量爆满,空前的拥堵极易导致服务区瘫痪。如何利用智…

计算机毕业设计 | springboot 多功能商城 购物网站(附源码)

1, 概述 国家大力推进信息化建设的大背景下,城市网络基础设施和信息化应用水平得到了极大的提高和提高。特别是在经济发达的沿海地区,商业和服务业也比较发达,公众接受新事物的能力和消费水平也比较高。开展商贸流通产业的信息化…

OpenHarmony—编辑器使用技巧

DevEco Studio支持使用多种语言进行应用/服务的开发,包括ArkTS、JS和C/C。在编写应用/服务阶段,可以通过掌握代码编写的各种常用技巧,来提升编码效率。 代码高亮 支持对代码关键字、运算符、字符串、类、标识符、注释等进行高亮显示&#x…

少儿编程教育市场分析:行业规模有望在2025年达到约500亿元

少儿编程教育是通过编程游戏启蒙、可视化图形编程等课程,培养学生的计算思维和创新解难能力的课程。与成人的编程不同,少儿编程教育并非高等教育那样学习如何写代码、编制应用程序,而是通过编程游戏启蒙、可视化图形编程等课程,培…

C语言——标准输入函数(scanf、getchar和gets)

目录 1. 标准输入输出头文件2. scanf2.1 scanf2.1.1 函数申明2.1.2 基本用法2.1.3 返回值2.1.4 占位符2.1.5 赋值忽略符 3. getchar3.1 函数申明3.2 基本用法 4. gets4.1 函数申明4.2 基本用法 1. 标准输入输出头文件 #include <stdio.h>在使用标准输入输出函数的时候都…

摄影分享|基于Springboot的摄影分享网站设计与实现(源码+数据库+文档)

摄影分享网站目录 目录 基于Springboot的摄影分享网站设计与实现 一、前言 二、系统功能设计 三、系统实现 1、用户信息管理 2、图片素材管理 3、视频素材管理 4、公告信息管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐…

WebRTC系列-H264视频帧组包(视频花屏问题)

文章目录 工具函数是否满足组帧条件函数PotentialNewFrame更新丢失包记录 UpdateMissingPackets重要属性1. InsertPacket2. FindFramesWebRTC在弱网环境下传输较大的视频数据,比如:屏幕共享数据;会偶发的出现黑屏的问题;也就是说当视频的码率比较大且视频的分辨率比较高的时…

企业网络基础架构监控工具

IT 基础架构已成为提供基本业务服务的基石&#xff0c;无论是内部管理操作还是为客户托管的应用程序服务&#xff0c;监控 IT 基础设施至关重要&#xff0c;并且已经建立起来&#xff0c;SMB IT 基础架构需要简单的网络监控工具来监控性能和报告问题。通常&#xff0c;几个 IT …

Linux系统各目录作用

/etc文件系统 /etc 目录包含各种系统配置文件&#xff0c;下面说明其中的一些。其他的你应该知道它们属于哪个程序&#xff0c;并阅读该程序的m a n页。许多网络配置文件也在/etc 中。 1. /etc/rc或/etc/rc.d或/etc/rc?.d 启动、或改变运行级时运行的脚本或脚本的目录。 2. /…

UE5 虚幻游戏报错常用解决方法(幻兽帕鲁UE5报错)

在体验使用虚幻引擎5、4&#xff08;UE5/UE4&#xff09;开发的游戏如《幻兽帕鲁》时&#xff0c;玩家可能会遇到各种报错情况&#xff0c;例如黑屏、闪退、C运行时错误等。本博客将汇集一系列有效解决方案&#xff0c;通过调整虚幻引擎内置命令行参数以及优化系统环境&#xf…