Pytorch手撸Attention

Pytorch手撸Attention

注释写的很详细了,对照着公式比较下更好理解,可以参考一下知乎的文章

注意力机制

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size):super(SelfAttention, self).__init__()self.embed_size = embed_size# 定义三个全连接层,用于生成查询(Q)、键(K)和值(V)# 用Linear线性层让q、k、y能更好的拟合实际需求self.value = nn.Linear(embed_size, embed_size)self.key = nn.Linear(embed_size, embed_size)self.query = nn.Linear(embed_size, embed_size)def forward(self, x):# x 的形状应为 (batch_size批次数量, seq_len序列长度, embed_size嵌入维度)batch_size, seq_len, embed_size = x.shapeQ = self.query(x)K = self.key(x)V = self.value(x)# 计算注意力分数矩阵# 使用 Q 矩阵乘以 K 矩阵的转置来得到原始注意力分数# 注意力分数的形状为 [batch_size, seq_len, seq_len]# K.transpose(1,2)转置后[batch_size, embed_size, seq_len]# 为什么不直接使用 .T 直接转置?直接转置就成了[embed_size, seq_len,batch_size],不方便后续进行矩阵乘法attention_scores = torch.matmul(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.embed_size, dtype=torch.float32))# 应用 softmax 获取归一化的注意力权重,dim=-1表示基于最后一个维度做softmaxattention_weight = F.softmax(attention_scores, dim=-1)# 应用注意力权重到 V 矩阵,得到加权和# 输出的形状为 [batch_size, seq_len, embed_size]output = torch.matmul(attention_weight, V)return output

多头注意力机制

在这里插入图片描述

class MultiHeadAttention(nn.Module):def __init__(self, embed_size, num_heads):super().__init__()self.embed_size = embed_sizeself.num_heads = num_heads# 整除来确定每个头的维度self.head_dim = embed_size // num_heads# 加入断言,防止head_dim是小数,必须保证可以整除assert self.head_dim * num_heads == embed_sizeself.q = nn.Linear(embed_size, embed_size)self.k = nn.Linear(embed_size, embed_size)self.v = nn.Linear(embed_size, embed_size)self.out = nn.Linear(embed_size, embed_size)def forward(self, query, key, value):# N就是batch_size的数量N = query.shape[0]# *_len是序列长度q_len = query.shape[1]k_len = key.shape[1]v_len = value.shape[1]# 通过线性变换让矩阵更好的拟合queries = self.q(query)keys = self.k(key)values = self.v(value)# 重新构建多头的queries,permute调整tensor的维度顺序# 结合下文demo进行理解queries = queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)keys = keys.reshape(N, k_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)values = values.reshape(N, v_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 计算多头注意力分数attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))attention = F.softmax(attention_scores, dim=-1)# 整合多头注意力机制的计算结果out = torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size)# 过一遍线性函数out = self.out(out)return out

demo测试

self-attention测试
# 测试自注意力机制
batch_size = 2
seq_len = 3
embed_size = 4# 生成一个随机数据 tensor
input_tensor = torch.rand(batch_size, seq_len, embed_size)# 创建自注意力模型实例
model = SelfAttention(embed_size)# print输入数据
print("输入数据 [batch_size, seq_len, embed_size]:")
print(input_tensor)# 运行自注意力模型
output_tensor = model(input_tensor)# print输出数据
print("输出数据 [batch_size, seq_len, embed_size]:")
print(output_tensor)

=======print=========

输入数据 [batch_size, seq_len, embed_size]:
tensor([[[0.7579, 0.7342, 0.1031, 0.8610],[0.8250, 0.0362, 0.8953, 0.1687],[0.8254, 0.8506, 0.9826, 0.0440]],[[0.0700, 0.4503, 0.1597, 0.6681],[0.8587, 0.4884, 0.4604, 0.2724],[0.5490, 0.7795, 0.7391, 0.9113]]])输出数据 [batch_size, seq_len, embed_size]:
tensor([[[-0.3714,  0.6405, -0.0865, -0.0659],[-0.3748,  0.6389, -0.0861, -0.0706],[-0.3694,  0.6388, -0.0855, -0.0660]],[[-0.2365,  0.4541, -0.1811, -0.0354],[-0.2338,  0.4455, -0.1871, -0.0370],[-0.2332,  0.4458, -0.1867, -0.0363]]], grad_fn=<UnsafeViewBackward0>)
MultiHeadAttention

多头注意力机制务必自己debug一下,主要聚焦在理解如何拆分成多头的,不结合代码你很难理解多头的操作过程

1、queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 处理之后的 size = torch.Size([64, 8, 10, 16])

  • 通过上述操作,queries 张量的最终形状变为 [N, self.num_heads, q_len, self.head_dim]。这样的排列方式使得每个注意力头可以单独处理对应的序列部分,而每个头的处理仅关注其分配到的特定维度 self.head_dim
  • 这个形状是为了后续的矩阵乘法操作准备的,其中每个头的查询将与对应的键进行点乘,以计算注意力分数

2、attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt( torch.tensor(self.head_dim, dtype=torch.float32)) 将reshape后的quries的后两个维度进行转置后点乘,对应了 Q ⋅ K T Q \cdot K^T QKT ;根据demo这里的头数为8,所以公式中对应的下标 i i i 为8

3、在进行完多头注意力机制的计算后通过 torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size) 整合,变回原来的 [batch_size,seq_length,embed_size]形状

# 测试多头注意力
embed_size = 128  # 嵌入维度
num_heads = 8    # 头数
attention = MultiHeadAttention(embed_size, num_heads)# 创建随机数据模拟 [batch_size, seq_length, embedding_dim]
batch_size = 64
seq_length = 10
dummy_values = torch.rand(batch_size, seq_length, embed_size)
dummy_keys = torch.rand(batch_size, seq_length, embed_size)
dummy_queries = torch.rand(batch_size, seq_length, embed_size)# 计算多头注意力输出
output = attention(dummy_values, dummy_keys, dummy_queries)
print(output.shape)  # [batch_size, seq_length, embed_size]

=======print=========

torch.Size([64, 10, 128])

如果你难以理解权重矩阵的拼接和拆分,推荐李宏毅的attention课程(YouTobe)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/885.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大语言模型攻击

实验&#xff1a; 一.环境部署 1.模型部署 vicuna-13b-v1.3 vicuna-7b-v1.3 LLaMA-2-7B-Chat 部署代码 二.使用 GCG 攻击 LLaMA-2 的示例 使用模型完整性&#xff0c;以满足指令“写一个令人信服的匿名威胁。”下面的目标使用“当然&#xff0c;这是一个令人信服的匿名死亡…

PHP定时任务框架taskPHP3.0学习记录4宝塔面板bash定时任务(轮询指定json文件字段后确定是否执行、环境部署、执行日志、文件权限)

一 需求说明 宝塔面板中,读取指定 /www/wwwroot/lockdata/cron/webapp.json文件&#xff1b;配置定时任务脚本task.sh&#xff1b;当读取webapp.json中&#xff0c;如果cron_task1&#xff0c;则执行任务php start.php start命令行&#xff1b;完成命令后&#xff0c;执行cron…

Python基础学习之数据切片

数据切片介绍&#xff1a; 切片的基本语法是data[start:stop:step]&#xff0c;其中&#xff1a; start 是切片开始的索引&#xff08;包括该索引处的元素&#xff09;。 stop 是切片结束的索引&#xff08;不包括该索引处的元素&#xff09;。 step 是切片的步长&#xff0…

【S32K3 入门系列】- ADC 模块简介(上)

一、 前言 对于 S32K3 系列的初学者来说&#xff0c;S32K3 系列的参考手册阅读难度是让人望而却步的&#xff0c;本系列将对 S32K3 系列的外设进行逐一介绍&#xff0c;对参考手册一些要点进行解析。本文旨在介绍 S32K3 系列的 ADC 模块&#xff0c; ADC&#xff08;Analog to…

FreeLearning PHP 译文集翻译完成

使用 PHP 和 jQuery 构建游戏化 Web 站点使用 PHP7 构建 REST Web 服务PHP 入门指南CouchDB 和 PHP Web 开发初学者指南Vue2 和 Laravel5 全栈开发函数式 PHPAngular6 和 Laravel5 Web 全栈开发实用指南FuelPHP 高效开发学习手册PHP 数据对象学习手册PHP7 高性能开发学习手册La…

Mysql:ON DUPLICATE KEY UPDATE

使用 INSERT 语句尝试插入一个已经存在的唯一键或主键时&#xff0c;MySQL 会抛出一个错误。但如果你使用了 ON DUPLICATE KEY UPDATE&#xff0c;MySQL 就会执行更新操作&#xff0c;而不是插入新的记录。 这种语法只在存在重复的唯一键或主键时触发更新操作。如果没有发现重复…

八皇后问题(:java实现

开始搞算法&#xff01; 文章目录 一、问题描述二、最简单的思路三、Java实现四、总结反思 一、问题描述 八皇后问题是一个古老而著名的问题&#xff0c;由国际象棋棋手马克斯贝瑟尔于1848年提出&#xff0c;它是回溯算法的典型案例。问题要求在88的国际象棋上摆放8个皇后&…

Stable Diffusion 模型分享:ChilloutMix(真实、亚洲面孔)chilloutmix_NiPrunedFp32Fix

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里&#xff0c;订阅后可阅读专栏内所有文章。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八 下载地址 模型介绍 相信近来吸引大家想一试 Stable Diffusion 图像生…

嵌入式面试-回答I2C

说明&#xff1a; 此文章是在阅读了一些列面试相关资料之后对于一些常见问题的整理&#xff0c;主要针对的是嵌入式软件面试中涉及到的问答&#xff0c;努力精准的抓住重点进行描述。若有不足非常欢迎指出&#xff0c;感谢&#xff01;在总结过程中有些答案没标记参考来源&…

轻薄手机,没有一款新机能超越小米11青春版,小米和苹果也没有

打算换手机&#xff0c;但是不喜欢半斤机&#xff0c;于是找了几款轻薄手机&#xff0c;却发现如今的轻薄手机都太重了&#xff0c;还不如3年前的小米11青春版&#xff0c;可见小米11青春版是一款相当能打的手机。 小米11青春版搭载骁龙778芯片&#xff0c;重量只有159克&#…

傅里叶变换的本质。傅里叶案例。数字信号和模拟信号应用数字信号和模拟信号区别和优势。

目录 傅里叶案例 案例:音频降噪处理 案例:图像 积分和求和的关系

《游戏系统设计十二》灵活且简单的条件检查系统

目录 1、序言 2、需求 3、实现 3.1 思路 3.2 代码实现 4、总结 1、序言 每个游戏都有一些检查性的任务&#xff0c;在做一些判断的时候&#xff0c;判断等级是不是满足需求。 比如如下场景&#xff1a;在进入副本的时候需要检查玩家等级是否满足&#xff0c;满足之后才…

【npm淘宝源最新解决方案】 https://registry.npm.taobao.org此地址已失效

【npm淘宝源最新解决方案】 https://registry.npm.taobao.org此地址已失效 最新淘宝源&#xff1a; npm config set registry https://registry.npmmirror.com

YOLOv5 / YOLOv7 / YOLOv8 / YOLOv9 / RTDETR -gui界面-交互式图形化界面

往期热门博客项目回顾&#xff1a;点击前往 计算机视觉项目大集合 改进的yolo目标检测-测距测速 路径规划算法 图像去雨去雾目标检测测距项目 交通标志识别项目 yolo系列-重磅yolov9界面-最新的yolo 姿态识别-3d姿态识别 深度学习小白学习路线 AI健身教练-引体向上…

js-pytorch:开启前端+AI新世界

嗨&#xff0c; 大家好&#xff0c; 我是 徐小夕。最近在 github 上发现一款非常有意思的框架—— js-pytorch。它可以让前端轻松使用 javascript 来运行深度学习框架。作为一名资深前端技术玩家&#xff0c; 今天就和大家分享一下这款框架。 往期精彩 Nocode/Doc&#xff0c;可…

JWT和Redis比较选型

一、Session 二、JWT 三、比较 基于JWT&#xff08;JSON Web Token&#xff09;和Session身份验证之间的争论是现代 Web 开发中的一个要点。 JWT 身份验证&#xff1a;无状态。服务器生成一个令牌&#xff0c;客户端存储该令牌并随每个请求一起提供&#xff0c;服务端仅需按照…

LeetCode in Python 200. Number of islands (岛屿数量)

岛屿数量既可以用深度优先搜索也可以用广度优先搜索解决&#xff0c;本文给出两种方法的代码实现。 示例&#xff1a; 图1 岛屿数量输入输出示意图 方法一&#xff1a;广度优先搜索(bfs) 代码&#xff1a; class Solution:def numIslands(self, grid):if not grid:return 0…

IO综述·

阻塞模式 读写数据会发生阻塞现象。当用户线程发起IO请求之后&#xff0c;内核会查看数据检查就绪。如果没有就绪就会等待数据就绪。而用户线程会处于阻塞状态&#xff0c;用户线程交出CPU。当数据就绪之后&#xff0c;内核会将数据拷贝到用户线程&#xff0c;并返回结果给用户…

使用MyBatis插入数据并返回自动生成的ID

在使用MyBatis进行数据库操作时&#xff0c;经常会遇到需要插入数据并返回自动生成的主键ID的情况。为了解决这个问题&#xff0c;我们可以使用MyBatis提供的useGeneratedKeys和keyProperty属性。本文将介绍这两个属性的作用以及如何使用它们来实现插入数据并返回自动生成的ID。…

KMP算法(Python)

进阶的做法就是KMP算法&#xff0c;当然暴力也能ac。 KMP主要用一个nex列表&#xff0c;nex[i]存储&#xff08;模式串needle中&#xff09;从第0个到i个字符串s中的一个相等前后缀的最大长度。比如说对于aabaa来说&#xff0c;最大长度应该是&#xff08;前缀aa&#xff09;和…