Pytorch手撸Attention

Pytorch手撸Attention

注释写的很详细了,对照着公式比较下更好理解,可以参考一下知乎的文章

注意力机制

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size):super(SelfAttention, self).__init__()self.embed_size = embed_size# 定义三个全连接层,用于生成查询(Q)、键(K)和值(V)# 用Linear线性层让q、k、y能更好的拟合实际需求self.value = nn.Linear(embed_size, embed_size)self.key = nn.Linear(embed_size, embed_size)self.query = nn.Linear(embed_size, embed_size)def forward(self, x):# x 的形状应为 (batch_size批次数量, seq_len序列长度, embed_size嵌入维度)batch_size, seq_len, embed_size = x.shapeQ = self.query(x)K = self.key(x)V = self.value(x)# 计算注意力分数矩阵# 使用 Q 矩阵乘以 K 矩阵的转置来得到原始注意力分数# 注意力分数的形状为 [batch_size, seq_len, seq_len]# K.transpose(1,2)转置后[batch_size, embed_size, seq_len]# 为什么不直接使用 .T 直接转置?直接转置就成了[embed_size, seq_len,batch_size],不方便后续进行矩阵乘法attention_scores = torch.matmul(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.embed_size, dtype=torch.float32))# 应用 softmax 获取归一化的注意力权重,dim=-1表示基于最后一个维度做softmaxattention_weight = F.softmax(attention_scores, dim=-1)# 应用注意力权重到 V 矩阵,得到加权和# 输出的形状为 [batch_size, seq_len, embed_size]output = torch.matmul(attention_weight, V)return output

多头注意力机制

在这里插入图片描述

class MultiHeadAttention(nn.Module):def __init__(self, embed_size, num_heads):super().__init__()self.embed_size = embed_sizeself.num_heads = num_heads# 整除来确定每个头的维度self.head_dim = embed_size // num_heads# 加入断言,防止head_dim是小数,必须保证可以整除assert self.head_dim * num_heads == embed_sizeself.q = nn.Linear(embed_size, embed_size)self.k = nn.Linear(embed_size, embed_size)self.v = nn.Linear(embed_size, embed_size)self.out = nn.Linear(embed_size, embed_size)def forward(self, query, key, value):# N就是batch_size的数量N = query.shape[0]# *_len是序列长度q_len = query.shape[1]k_len = key.shape[1]v_len = value.shape[1]# 通过线性变换让矩阵更好的拟合queries = self.q(query)keys = self.k(key)values = self.v(value)# 重新构建多头的queries,permute调整tensor的维度顺序# 结合下文demo进行理解queries = queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)keys = keys.reshape(N, k_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)values = values.reshape(N, v_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 计算多头注意力分数attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))attention = F.softmax(attention_scores, dim=-1)# 整合多头注意力机制的计算结果out = torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size)# 过一遍线性函数out = self.out(out)return out

demo测试

self-attention测试
# 测试自注意力机制
batch_size = 2
seq_len = 3
embed_size = 4# 生成一个随机数据 tensor
input_tensor = torch.rand(batch_size, seq_len, embed_size)# 创建自注意力模型实例
model = SelfAttention(embed_size)# print输入数据
print("输入数据 [batch_size, seq_len, embed_size]:")
print(input_tensor)# 运行自注意力模型
output_tensor = model(input_tensor)# print输出数据
print("输出数据 [batch_size, seq_len, embed_size]:")
print(output_tensor)

=======print=========

输入数据 [batch_size, seq_len, embed_size]:
tensor([[[0.7579, 0.7342, 0.1031, 0.8610],[0.8250, 0.0362, 0.8953, 0.1687],[0.8254, 0.8506, 0.9826, 0.0440]],[[0.0700, 0.4503, 0.1597, 0.6681],[0.8587, 0.4884, 0.4604, 0.2724],[0.5490, 0.7795, 0.7391, 0.9113]]])输出数据 [batch_size, seq_len, embed_size]:
tensor([[[-0.3714,  0.6405, -0.0865, -0.0659],[-0.3748,  0.6389, -0.0861, -0.0706],[-0.3694,  0.6388, -0.0855, -0.0660]],[[-0.2365,  0.4541, -0.1811, -0.0354],[-0.2338,  0.4455, -0.1871, -0.0370],[-0.2332,  0.4458, -0.1867, -0.0363]]], grad_fn=<UnsafeViewBackward0>)
MultiHeadAttention

多头注意力机制务必自己debug一下,主要聚焦在理解如何拆分成多头的,不结合代码你很难理解多头的操作过程

1、queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 处理之后的 size = torch.Size([64, 8, 10, 16])

  • 通过上述操作,queries 张量的最终形状变为 [N, self.num_heads, q_len, self.head_dim]。这样的排列方式使得每个注意力头可以单独处理对应的序列部分,而每个头的处理仅关注其分配到的特定维度 self.head_dim
  • 这个形状是为了后续的矩阵乘法操作准备的,其中每个头的查询将与对应的键进行点乘,以计算注意力分数

2、attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt( torch.tensor(self.head_dim, dtype=torch.float32)) 将reshape后的quries的后两个维度进行转置后点乘,对应了 Q ⋅ K T Q \cdot K^T QKT ;根据demo这里的头数为8,所以公式中对应的下标 i i i 为8

3、在进行完多头注意力机制的计算后通过 torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size) 整合,变回原来的 [batch_size,seq_length,embed_size]形状

# 测试多头注意力
embed_size = 128  # 嵌入维度
num_heads = 8    # 头数
attention = MultiHeadAttention(embed_size, num_heads)# 创建随机数据模拟 [batch_size, seq_length, embedding_dim]
batch_size = 64
seq_length = 10
dummy_values = torch.rand(batch_size, seq_length, embed_size)
dummy_keys = torch.rand(batch_size, seq_length, embed_size)
dummy_queries = torch.rand(batch_size, seq_length, embed_size)# 计算多头注意力输出
output = attention(dummy_values, dummy_keys, dummy_queries)
print(output.shape)  # [batch_size, seq_length, embed_size]

=======print=========

torch.Size([64, 10, 128])

如果你难以理解权重矩阵的拼接和拆分,推荐李宏毅的attention课程(YouTobe)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/885.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大语言模型攻击

实验&#xff1a; 一.环境部署 1.模型部署 vicuna-13b-v1.3 vicuna-7b-v1.3 LLaMA-2-7B-Chat 部署代码 二.使用 GCG 攻击 LLaMA-2 的示例 使用模型完整性&#xff0c;以满足指令“写一个令人信服的匿名威胁。”下面的目标使用“当然&#xff0c;这是一个令人信服的匿名死亡…

PHP定时任务框架taskPHP3.0学习记录4宝塔面板bash定时任务(轮询指定json文件字段后确定是否执行、环境部署、执行日志、文件权限)

一 需求说明 宝塔面板中,读取指定 /www/wwwroot/lockdata/cron/webapp.json文件&#xff1b;配置定时任务脚本task.sh&#xff1b;当读取webapp.json中&#xff0c;如果cron_task1&#xff0c;则执行任务php start.php start命令行&#xff1b;完成命令后&#xff0c;执行cron…

Python基础学习之数据切片

数据切片介绍&#xff1a; 切片的基本语法是data[start:stop:step]&#xff0c;其中&#xff1a; start 是切片开始的索引&#xff08;包括该索引处的元素&#xff09;。 stop 是切片结束的索引&#xff08;不包括该索引处的元素&#xff09;。 step 是切片的步长&#xff0…

【S32K3 入门系列】- ADC 模块简介(上)

一、 前言 对于 S32K3 系列的初学者来说&#xff0c;S32K3 系列的参考手册阅读难度是让人望而却步的&#xff0c;本系列将对 S32K3 系列的外设进行逐一介绍&#xff0c;对参考手册一些要点进行解析。本文旨在介绍 S32K3 系列的 ADC 模块&#xff0c; ADC&#xff08;Analog to…

Stable Diffusion 模型分享:ChilloutMix(真实、亚洲面孔)chilloutmix_NiPrunedFp32Fix

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里&#xff0c;订阅后可阅读专栏内所有文章。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八 下载地址 模型介绍 相信近来吸引大家想一试 Stable Diffusion 图像生…

嵌入式面试-回答I2C

说明&#xff1a; 此文章是在阅读了一些列面试相关资料之后对于一些常见问题的整理&#xff0c;主要针对的是嵌入式软件面试中涉及到的问答&#xff0c;努力精准的抓住重点进行描述。若有不足非常欢迎指出&#xff0c;感谢&#xff01;在总结过程中有些答案没标记参考来源&…

轻薄手机,没有一款新机能超越小米11青春版,小米和苹果也没有

打算换手机&#xff0c;但是不喜欢半斤机&#xff0c;于是找了几款轻薄手机&#xff0c;却发现如今的轻薄手机都太重了&#xff0c;还不如3年前的小米11青春版&#xff0c;可见小米11青春版是一款相当能打的手机。 小米11青春版搭载骁龙778芯片&#xff0c;重量只有159克&#…

《游戏系统设计十二》灵活且简单的条件检查系统

目录 1、序言 2、需求 3、实现 3.1 思路 3.2 代码实现 4、总结 1、序言 每个游戏都有一些检查性的任务&#xff0c;在做一些判断的时候&#xff0c;判断等级是不是满足需求。 比如如下场景&#xff1a;在进入副本的时候需要检查玩家等级是否满足&#xff0c;满足之后才…

YOLOv5 / YOLOv7 / YOLOv8 / YOLOv9 / RTDETR -gui界面-交互式图形化界面

往期热门博客项目回顾&#xff1a;点击前往 计算机视觉项目大集合 改进的yolo目标检测-测距测速 路径规划算法 图像去雨去雾目标检测测距项目 交通标志识别项目 yolo系列-重磅yolov9界面-最新的yolo 姿态识别-3d姿态识别 深度学习小白学习路线 AI健身教练-引体向上…

js-pytorch:开启前端+AI新世界

嗨&#xff0c; 大家好&#xff0c; 我是 徐小夕。最近在 github 上发现一款非常有意思的框架—— js-pytorch。它可以让前端轻松使用 javascript 来运行深度学习框架。作为一名资深前端技术玩家&#xff0c; 今天就和大家分享一下这款框架。 往期精彩 Nocode/Doc&#xff0c;可…

JWT和Redis比较选型

一、Session 二、JWT 三、比较 基于JWT&#xff08;JSON Web Token&#xff09;和Session身份验证之间的争论是现代 Web 开发中的一个要点。 JWT 身份验证&#xff1a;无状态。服务器生成一个令牌&#xff0c;客户端存储该令牌并随每个请求一起提供&#xff0c;服务端仅需按照…

LeetCode in Python 200. Number of islands (岛屿数量)

岛屿数量既可以用深度优先搜索也可以用广度优先搜索解决&#xff0c;本文给出两种方法的代码实现。 示例&#xff1a; 图1 岛屿数量输入输出示意图 方法一&#xff1a;广度优先搜索(bfs) 代码&#xff1a; class Solution:def numIslands(self, grid):if not grid:return 0…

IO综述·

阻塞模式 读写数据会发生阻塞现象。当用户线程发起IO请求之后&#xff0c;内核会查看数据检查就绪。如果没有就绪就会等待数据就绪。而用户线程会处于阻塞状态&#xff0c;用户线程交出CPU。当数据就绪之后&#xff0c;内核会将数据拷贝到用户线程&#xff0c;并返回结果给用户…

KMP算法(Python)

进阶的做法就是KMP算法&#xff0c;当然暴力也能ac。 KMP主要用一个nex列表&#xff0c;nex[i]存储&#xff08;模式串needle中&#xff09;从第0个到i个字符串s中的一个相等前后缀的最大长度。比如说对于aabaa来说&#xff0c;最大长度应该是&#xff08;前缀aa&#xff09;和…

Linux下SPI设备驱动实验:验证读写SPI设备中数据的函数功能

一. 简介 前面文章实现了 SPI设备驱动框架&#xff0c;并在此基础上添加了字符设备驱动框架&#xff0c;实现了读 / 写SPI设备中数据的函数&#xff0c;文章如下&#xff1a; Linux下SPI设备驱动实验&#xff1a;向SPI驱动框架中加入字符设备驱动框架代码-CSDN博客 Linux下…

算法打卡day51|单调栈篇02| Leetcode 503.下一个更大元素II、42. 接雨水

算法题 Leetcode 503.下一个更大元素II 题目链接:503.下一个更大元素II 大佬视频讲解&#xff1a;503.下一个更大元素II视频讲解 个人思路 这道题和之前496.下一个更大元素 I 差不多&#xff0c;只是这道题需要循环数组&#xff0c;那就在遍历的过程中模拟走两遍nums就行&a…

本地主机搭建服务器后如何让外网访问?快解析内网端口映射

本地主机搭建应用、部署服务器后&#xff0c;在局域网内是可以直接通过计算机内网IP网络地址进行连接访问的&#xff0c;但在外网电脑和设备如何访问呢&#xff1f;由于内网环境下&#xff0c;无法提供公网IP使用&#xff0c;外网访问内网就需要一个内外网转换的介质。这里介绍…

在PostgreSQL中如何创建和使用自定义函数,包括内置语言(如PL/pgSQL)和外部语言(如Python、C等)?

文章目录 一、使用内置语言 PL/pgSQL 创建自定义函数示例代码使用方法 二、使用外部语言 Python 创建自定义函数安装 PL/Python 扩展示例代码使用方法 三、使用外部语言 C 创建自定义函数编写 C 代码编译为共享库在 PostgreSQL 中注册函数注意事项 总结 PostgreSQL 是一个强大的…

CSS基础:table的4个标签的样式详解(6000字长文!附案例)

你好&#xff0c;我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生&#xff0c;一枚程序媛&#xff0c;感谢关注。回复 “前端基础题”&#xff0c;可免费获得前端基础 100 题汇总&#xff0c;回复 “前端工具”&#xff0c;可获取 Web 开发工具合…

记一次中间件宕机以后持续请求导致应用OOM的排查思路(server.max-http-header-size属性配置不当的严重后果)

一、背景 最近有一次在系统并发比较高的时候&#xff0c;数据库突然发生了故障&#xff0c;导致大量请求失败&#xff0c;在数据库宕机不久&#xff0c;通过应用日志可以看到系统发生了OOM。 二、排查 初次看到这个现象的时候&#xff0c;我还是有点懵逼的&#xff0c;数据库…