简单实现Transformer的自注意力

简单实现Transformer的自注意力

关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!

如果你感觉对你有所帮助,请关注我。

在这里插入图片描述
源码获取:VX关注并回复chatgpt-0获得

  • 实现的功能

假如有八个令牌,现在想让每一个令牌至于其前面的通信,如第5个令牌不与6,7,8位置的令牌通信(这是未来的令牌),只与4,3,2,1位置的令牌通信。因此只能通过以前的上下文信息猜测后面的;一种弱的通信方式是取前面的平局值。如5位置==5,4,3,2,1位置上的平局值。

  • 实现
    • 循环的版本
import torch
from torch.nn import functional as F
import torch.nn as nn
torch.manual_seed(1337)B,T,C = 4,8,2 #batch,time,channels 
x = torch.randn(B,T,C)
xbow = torch.zeros((B,T,C))
print(f'x: {x[0]}')
for b in range(B):for t in range(T):xprev = x[b,:t+1] #()t,Cxbow[b,t] = torch.mean(xprev,0)
print(f'xbow: {xbow[0]}')#结果
x: tensor([[ 0.1808, -0.0700],[-0.3596, -0.9152],[ 0.6258,  0.0255],[ 0.9545,  0.0643],[ 0.3612,  1.1679],[-1.3499, -0.5102],[ 0.2360, -0.2398],[-0.9211,  1.5433]])
xbow: tensor([[ 0.1808, -0.0700],[-0.0894, -0.4926],[ 0.1490, -0.3199],[ 0.3504, -0.2238],[ 0.3525,  0.0545],[ 0.0688, -0.0396],[ 0.0927, -0.0682],[-0.0341,  0.1332]])
# 每一行至于自己以及自己以前的数据进行通信
  • 通过数据矩阵高效实现
a = torch.tril(torch.ones(3,3)) #下三角函数
a = a/torch.sum(a,1,keepdim=True) #对a求平均数
b = torch.randint(0,10,(3,2)).float()
c = a @ bprint(f'a:{a}')
print(f'b:{b}')
print(f'c:{c}')#结果a:tensor([[1.0000, 0.0000, 0.0000],[0.5000, 0.5000, 0.0000],[0.3333, 0.3333, 0.3333]])
b:tensor([[0., 4.],[1., 2.],[5., 5.]])
c:tensor([[0.0000, 4.0000],[0.5000, 3.0000],[2.0000, 3.6667]])
  • 使用Softmax
tril = torch.tril(torch.ones(T,T))  #下三角函数
print(f'tril:{tril}')wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大
print(f'wei: {wei}')
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数
print(f'wei: {wei}')
xbow3 = wei @ xprint(f'xbow3: {xbow3}')
print(torch.allclose(xbow,xbow3))
  • 单头自注意力

    • 上面的自注意力是通过相同的方式获取以往的信息。但是实际上并不希望是统一的方式,因为不同的token标记会发现其他不同的标记。
    • 例如:我是元音,那么也许我正在寻找过去的辅音,或与我想知道这些辅音是什么。希望这些信息流向我,所以我现在想以依赖数据的方式收集过去的信息。这就是自注意力解决的问题。
    • 方式如下:每个节点或每个位置的每个令牌都会发出两个向量,一个发出查询query,一个发出键key。查询向量粗略的说就是我要找的东西,键向量粗略的讲就是我包含什么。
    • 现在在序列中获取这些标记之间的亲和力的方式基本上只是在键和查询之间做一个点乘积。所以我的查询与所有的其他tokens令牌的所有键进行点乘积。并且点积方式变了。如果键和查询有点对齐,它们将交互到非常高的数量,然后我将了解有关特定标记的更多信息,而不是其他不再序列中的任何其他标记。
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #转置时最后两个维度为负 (B,T,16) @ (B,16,T) ---> (B,T,T)tril = torch.tril(torch.ones(T,T))  #下三角函数
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大 主要是为了避免关注后面信息。如果想让所有节点进行交流删除词句。解码器中保留,编码器删除允许所有节点通信
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数 主要为了避免关注过小的信息主要是负数
print(f'wei: {wei[0]}')
out = wei @ x
print(f'out:{out.shape}')
  • 但是在真是中并不聚合到x而是计算一个v.x看作为该令牌的私人信息,与不同头交流的信息存储在v中
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #转置时最后两个维度为负 (B,T,16) @ (B,16,T) ---> (B,T,T)tril = torch.tril(torch.ones(T,T))  #下三角函数
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大 主要是为了避免关注后面信息。如果想让所有节点进行交流删除词句。解码器中保留,编码器删除允许所有节点通信
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数 主要为了避免关注过小的信息主要是负数
print(f'wei: {wei[0]}')
value = nn.Linear(C,head_size,bias=False)
v = value(x)
out = wei @ v
print(f'out:{out.shape}')

简单实现自注意力

关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!

如果你感觉对你有所帮助,请关注我。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/716636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二叉树的右视图,力扣

目录 题目: 我们直接看题解吧: 快速理解解题思路小建议: 审题目事例提示: 解题方法: 解题分析: 解题思路: 代码实现(DFS): 代码1: 补充说明: 代码2&#xff1…

AI:148-开发一种智能语音助手,能够理解和执行复杂任务

🚀点击这里跳转到本专栏,可查阅专栏顶置最新的指南宝典~ 🎉🎊🎉 你的技术旅程将在这里启航! 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 ✨✨✨ 每一个案例都附带关键代码,详细讲解供大家学习,希望…

[技巧]Arcgis之图斑四至点批量计算

前言 上一篇介绍了arcgis之图斑四至范围计算,这里介绍的图斑四至点的计算及获取,两者之间还是有差异的。 [技巧]Arcgis之图斑四至范围计算 这里说的四至点指的是图斑最东、最西、最南、最北的四个地理位置点坐标,如下图: 四至点…

STM32进阶笔记——复位、时钟与滴答定时器

本专栏争取每周三更新直到更新完成,期待大家的订阅关注,欢迎互相学习交流。 目录 一、复位1.1 软件复位1.2 低功耗管理复位 二、时钟2.1 系统时钟(SYSCLK)选择2.2 系统时钟初始化 三、滴答定时器(Systick)3.1 SysTick部分寄存器3.…

部署bpmn项目实现activiti流程图的在线绘制

本教程基于centos7.6环境中完成 github开源项目: https://github.com/Yiuman/bpmn-vue-activiti软件:git、docker 1. 下载源代码 git clone https://github.com/Yiuman/bpmn-vue-activiti.git2. 修改Dockerfile文件 声明基础镜像,将项目打包&#xff…

EasyRecovery数据恢复软件有什么优势呢?

EasyRecovery数据恢复软件具有以下优势: 强大的恢复能力:EasyRecovery采用先进的扫描和恢复技术,能够深度扫描存储设备,寻找并恢复因各种原因丢失的数据。无论是误删除、格式化、分区损坏还是病毒感染,它都能提供有效…

设计模式(十一)策略模式

请直接看原文:设计模式(十一)策略模式_某移动支付系统在实现账户资金转入和转出时需要进行身份验证,该系统为用户提供了-CSDN博客 ----------------------------------------------------------------------------------------------------------------…

SpringMVC 学习(十一)之数据校验

目录 1 数据校验介绍 2 普通校验 3 分组校验 4 参考文档 1 数据校验介绍 在实际的项目中,一般会有两种校验数据的方式:客户端校验和服务端校验 客户端校验:这种校验一般是在前端页面使用 JS 代码进行校验,主要是验证输入数据…

文物预防性保护系统方案的需求分析

没有文物保存环境监测,就不能实施有效的文物预防性保护。因此要建立文物预防性保护体系,一定要先有良好的文物状态监测制度,进而进行科学有效的文物保护管理。所以,导入文物预防性保护监测与调控系统,首先就是要针对文物进行全年温度、湿度、光照等关键参…

使用Zint库生成一维码/条形码

下面代码是是使用 Zint 库生成 Code 128 类型的条形码&#xff0c;并将生成的条形码保存为 output.bmp 文件。下面是对代码的详细解释&#xff1a; #include 和 #include <zint.h>&#xff1a;这两行代码包含了所需的头文件&#xff0c;分别是标准输入输出流的头文件和 Z…

LeetCode---【链表的操作】

目录 206反转链表【链表结构基础】21合并两个有序链表【递归】我的答案【错误】自己修改【超出时间限制】在官方那里学到的【然后自己复写,错误】对照官方【自己修改】 160相交链表【未理解题目目的】在b站up那里学到的【然后自己复写,错误】【超出时间限制】对照官方【自己修改…

(C语言)qsort函数模拟实现

前言 我们需先了解qsort函数 qsort函数详解&#xff1a;http://t.csdnimg.cn/rTNv9 qsort函数可以排序多种数据类型&#xff0c;很是神奇&#xff0c;这是为什么&#xff0c;我们在里模拟实现这样的功能 目录 1. qsort函数模拟实现 2. 我们使用bubble_sort函数排序整形数…

Sunshine v0.21.0 安装卡住,闪退的问题解决

上期博客讲了如何利用 Sunshine 和 Moonlight 实现 iPad 当作 Windows 副屏&#xff0c;用官方 Windows installer 安装 Sunshine 过程中&#xff0c;遇到了安装卡住&#xff08;这个是因为需要国外网络环境&#xff09;&#xff0c;安装后运行闪退的问题。 Sunshine 下载地址…

OpenCV 4基础篇| OpenCV图像的裁切

目录 1. Numpy切片1.1 注意事项1.2 代码示例 2. cv2.selectROI()2.1 语法结构2.2 注意事项2.3 代码示例 3. Pillow.crop3.1 语法结构3.2 注意事项3.3 代码示例 4. 扩展示例&#xff1a;单张大图裁切成多张小图5. 总结 1. Numpy切片 语法结构&#xff1a; retval img[y:yh, x…

以目标检测和分类任务为例理解One-Hot Code

在目标检测和分类任务中&#xff0c;每一个类别都需要一个编码来表示&#xff0c;同时&#xff0c;这个编码会用来计算网络的loss。比如有猫&#xff0c;狗&#xff0c;猪三种动物&#xff0c;这三种动物相互独立&#xff0c;在分类中&#xff0c;将其中任意一种分类为其他都同…

YOLOv9独家原创改进|使用可改变核卷积AKConv改进RepNCSPELAN4

专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;主力高效涨点&#xff01;&#xff01;&#xff01; 一、改进点介绍 AKConv是一种具有任意数量的参数和任意采样形状的可变卷积核&#xff0c;对不规则特征有更好的提取效果。 RepNCSPELAN4是YOLOv9中的…

2023年12月CCF-GESP编程能力等级认证Scratch图形化编程四级真题解析

一、单选题(共15题,共30分) 第1题 现代计算机是指电子计算机,它所基于的是( )体系结构。 A:艾伦图灵 B:冯诺依曼 C:阿塔纳索夫 D:埃克特-莫克利 答案:B 第2题 默认小猫角色,执行下列程序,以下说法正确的是? ( ) A:舞台上会出现无数个小猫 B:舞台只会出现…

java spring 02. AbstractApplicationContext

spring创建对象的顺序&#xff0c;先创建beanfactory&#xff0c;再会把xml文件读取到spring。 public ClassPathXmlApplicationContext(String[] configLocations, boolean refresh, Nullable ApplicationContext parent)throws BeansException {//调用父类的构造方法super(p…

Redis常用指令,jedis与持久化

1.redis常用指令 第一个是key的常用指令&#xff0c;第二个是数据库的常用指令 前面的那些指令都是针对某一个数据类型操作的&#xff0c;现在的都是对所有的操作的 1.key常用指令 key应该设计哪些操作 key是一个字符串&#xff0c;通过key获取redis中保存的数据 对于key…

flink重温笔记(九):Flink 高级 API 开发——flink 四大基石之WaterMark(Time为核心)

Flink学习笔记 前言&#xff1a;今天是学习 flink 的第 9 天啦&#xff01;学习了 flink 四大基石之 Time的应用—> Watermark&#xff08;水印&#xff0c;也称水位线&#xff09;&#xff0c;主要是解决数据由于网络延迟问题&#xff0c;出现数据乱序或者迟到数据现象&…