自注意力机制、多头自注意力机制、填充掩码 Python实现

原理讲解

【Transformer系列(2)】注意力机制、自注意力机制、多头注意力机制、通道注意力机制、空间注意力机制超详细讲解

自注意力机制

import torch
import torch.nn as nn# 自注意力机制
class SelfAttention(nn.Module):def __init__(self, input_dim):super(SelfAttention, self).__init__()self.query = nn.Linear(input_dim, input_dim)self.key = nn.Linear(input_dim, input_dim)self.value = nn.Linear(input_dim, input_dim)        def forward(self, x, mask=None):batch_size, seq_len, input_dim = x.shapeq = self.query(x)k = self.key(x)v = self.value(x)atten_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(input_dim, dtype=torch.float))if mask is not None:mask = mask.unsqueeze(1)attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))        atten_scores = torch.softmax(atten_weights, dim=-1)attented_values = torch.matmul(atten_scores, v)return attented_values# 自动填充函数
def pad_sequences(sequences, max_len=None):batch_size = len(sequences)input_dim = sequences[0].shape[-1]lengths = torch.tensor([seq.shape[0] for seq in sequences])max_len = max_len or lengths.max().item()padded = torch.zeros(batch_size, max_len, input_dim)for i, seq in enumerate(sequences):seq_len = seq.shape[0]padded[i, :seq_len, :] = seqmask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)return padded, mask.long()if __name__ == '__main__':batch_size = 2seq_len = 3input_dim = 128seq_len_1 = 3seq_len_2 = 5x1 = torch.randn(seq_len_1, input_dim)            x2 = torch.randn(seq_len_2, input_dim)target_seq_len = 10    padded_x, mask = pad_sequences([x1, x2], target_seq_len)selfattention = SelfAttention(input_dim)    attention = selfattention(padded_x)print(attention)

多头自注意力机制

import torch
import torch.nn as nn# 定义多头自注意力模块
class MultiHeadSelfAttention(nn.Module):def __init__(self, input_dim, num_heads):super(MultiHeadSelfAttention, self).__init__()self.num_heads = num_headsself.head_dim = input_dim // num_headsself.query = nn.Linear(input_dim, input_dim)self.key = nn.Linear(input_dim, input_dim)self.value = nn.Linear(input_dim, input_dim)        def forward(self, x, mask=None):batch_size, seq_len, input_dim = x.shape# 将输入向量拆分为多个头## transpose(1,2)后变成 (batch_size, self.num_heads, seq_len, self.head_dim)形式q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# 计算注意力权重attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))# 应用 padding maskif mask is not None:# mask: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) 用于广播mask = mask.unsqueeze(1).unsqueeze(2)  # 扩展维度以便于广播attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))        attn_scores = torch.softmax(attn_weights, dim=-1)# 注意力加权求和attended_values = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(batch_size, seq_len, input_dim)return attended_values# 自动填充函数
def pad_sequences(sequences, max_len=None):batch_size = len(sequences)input_dim = sequences[0].shape[-1]lengths = torch.tensor([seq.shape[0] for seq in sequences])max_len = max_len or lengths.max().item()padded = torch.zeros(batch_size, max_len, input_dim)for i, seq in enumerate(sequences):seq_len = seq.shape[0]padded[i, :seq_len, :] = seqmask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)return padded, mask.long()if __name__ == '__main__':heads = 2batch_size = 2seq_len_1 = 3seq_len_2 = 5input_dim = 128x1 = torch.randn(seq_len_1, input_dim)            x2 = torch.randn(seq_len_2, input_dim)target_seq_len = 10    padded_x, mask = pad_sequences([x1, x2], target_seq_len)multiheadattention = MultiHeadSelfAttention(input_dim, heads)attention = multiheadattention(padded_x, mask)    print(attention)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/77796.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【大模型】Browser-Use AI驱动的浏览器自动化工具

Browser-Use AI驱动的浏览器自动化工具 1. 项目概述2. 核心架构3. 实战指南3.1 环境安装3.2 快速启动3.3 进阶功能 4. 常见问题与解决5. 项目优势与局限6. 扩展资源7. 总结 1. 项目概述 项目地址&#xff1a;browser-use Browser-Use 是一个开源工具&#xff0c;旨在通过 AI 代…

ubuntu20.04安装安装x11vnc服务基于gdm3或lightdm这两种主流的显示管理器。

前言&#xff1a;在服务端安装vnc服务&#xff0c;可以方便的远程操作服务器&#xff0c;而不用非要插上显示器才行。所以在服务器上安装vnc是很重要的。在ubuntu20中&#xff0c;默认的显示管理器已经变为gdm3&#xff0c;它可以带来与 GNOME 无缝衔接的体验&#xff0c;强调功…

用银河麒麟 LiveCD 快速查看原系统 IP 和打印机配置

原文链接&#xff1a;用银河麒麟 LiveCD 快速查看原系统 IP 和打印机配置 Hello&#xff0c;大家好啊&#xff01;今天给大家带来一篇在银河麒麟操作系统的 LiveCD 或系统试用镜像环境下&#xff0c;如何查看原系统中电脑的 IP 地址与网络打印机 IP 地址的实用教程。在系统损坏…

C++——STL——容器deque(简单介绍),适配器——stack,queue,priority_queue

目录 1.deque&#xff08;简单介绍&#xff09; 1.1 deque介绍&#xff1a; 1.2 deque迭代器底层 1.2.1 那么比如说用迭代器实现元素的遍历&#xff0c;是如何实现的呢&#xff1f; 1.2.2 头插 1.2.3 尾插 1.2.4 实现 ​编辑 1.2.5 总结 2.stack 2.1 函数介绍 2.2 模…

Java并发编程-线程池

Java并发编程-线程池 线程池运行原理线程池生命周期线程池的核心参数线程池的阻塞队列线程池的拒绝策略线程池的种类newFixedThreadPoolnewSingleThreadExecutornewCachedThreadPoolnewScheduledThreadPool 创建线程池jdk的Executors(不建议&#xff0c;会导致OOM)jdk的ThreadP…

【前沿】成像“跨界”测量——扫焦光场成像

01 背景 眼睛是人类认识世界的重要“窗口”&#xff0c;而相机作为眼睛的“延伸”&#xff0c;已经成为生产生活中最常见的工具之一&#xff0c;广泛应用于工业检测、医疗诊断与影音娱乐等领域。传统相机通常以“所见即所得”的方式记录场景&#xff0c;传感器捕捉到的二维图像…

TM1640学习手册及示例代码

数据手册 TM1640数据手册 数据手册解读 这里我们看管脚定义DIN和SCLK&#xff0c;一个数据线一个时钟线 SEG1~SEG8为段码&#xff0c;GRID1~GRID16为位码&#xff08;共阴极情况下&#xff09; 这里VDD给5V 数据指令 数据命令设置 地址命令设置 显示控制命令 共阴极硬件连接图…

uni-app 开发企业级小程序课程

课程大小&#xff1a;7.7G 课程下载&#xff1a;https://download.csdn.net/download/m0_66047725/90616393 更多资源下载&#xff1a;关注我 备注&#xff1a;缺少两个视频5-14 tabs组件进行基本的数据展示和搜索历史 处理searchData的删除操作 1-1导学.mp4 2-10小程序内…

判断点是否在多边形内

代码段解析: const intersect = ((yi > y) !== (yj > y)) && (x < (xj - xi) * (y - yi) / (yj - yi) + xi); 第一部分:(yi > y) !== (yj > y) 作用:检查点 (x,y) 的垂直位置是否跨越多边形的当前边。 yi > y 和 yj > y 分别检查边的两个端…

【redis】集群 如何搭建集群详解

文章目录 集群搭建1. 创建目录和配置2. 编写 docker-compose.yml完整配置文件 3. 启动容器4. 构建集群超时 集群搭建 基于 docker 在我们云服务器上搭建出一个 redis 集群出来 当前节点&#xff0c;主要是因为我们只有一个云服务器&#xff0c;搞分布式系统&#xff0c;就比较…

[langchain教程]langchain03——用langchain构建RAG应用

RAG RAG过程 离线过程&#xff1a; 加载文档将文档按一定条件切割成片段将切割的文本片段转为向量&#xff0c;存入检索引擎&#xff08;向量库&#xff09; 在线过程&#xff1a; 用户输入Query&#xff0c;将Query转为向量从向量库检索&#xff0c;获得相似度TopN信息将…

C语言复习笔记--字符函数和字符串函数(下)

在上篇我们了解了部分字符函数及字符串函数,下面我们来看剩下的字符串函数. strstr 的使用和模拟实现 老规矩,我们先了解一下strstr这个函数,下面看下这个函数的函数原型. char * strstr ( const char * str1, const char * str2); 如果没找到就返回NULL指针. 下面我们看下它的…

FreeRTOS中的优先级翻转问题及其解决方案:互斥信号量详解

FreeRTOS中的优先级翻转问题及其解决方案&#xff1a;互斥信号量详解 在实时操作系统中&#xff0c;任务调度是基于优先级的&#xff0c;高优先级任务应该优先于低优先级任务执行。但在实际应用中&#xff0c;有时会出现"优先级翻转"的现象&#xff0c;严重影响系统…

深度学习-全连接神经网络

四、参数初始化 神经网络的参数初始化是训练深度学习模型的关键步骤之一。初始化参数&#xff08;通常是权重和偏置&#xff09;会对模型的训练速度、收敛性以及最终的性能产生重要影响。下面是关于神经网络参数初始化的一些常见方法及其相关知识点。 官方文档参考&#xff1…

GIS开发笔记(9)结合osg及osgEarth实现三维球经纬网格绘制及显隐

一、实现效果 二、实现原理 按照5的间隔分别创建经纬线的节点,挂在到组合节点,组合节点挂接到根节点。可以根据需要设置间隔度数和线宽、线的颜色。 三、参考代码 //创建经纬线的节点 osg::Node *GlobeWidget::createGraticuleGeometry(float interval, const osg::Vec4 …

《Relay IR的基石:expr.h 中的表达式类型系统剖析》

TVM Relay源码深度解读 文章目录 TVM Relay源码深度解读一 、从Constant看Relay表达式的设计哲学1. 类定义概述2. ConstantNode 详解1. 核心成员2. 关键方法3. 类型系统注册 3. Constant 详解1. 核心功能 二. 核心内容概述(1) Relay表达式基类1. RelayExprNode 和 RelayExpr 的…

自动驾驶地图数据传输协议ADASIS v2

ADASIS&#xff08;Advanced Driver Assistance Systems Interface Specification&#xff09;直译过来就是 ADAS 接口规格&#xff0c;它要负责的东西其实很简单&#xff0c;就是为自动驾驶车辆提供前方道路交通相关的数据&#xff0c;这些数据被抽象成一个标准化的概念&#…

Flutter 状态管理 Riverpod

Android Studio版本 Flutter SDK 版本 将依赖项添加到您的应用 flutter pub add flutter_riverpod flutter pub add riverpod_annotation flutter pub add dev:riverpod_generator flutter pub add dev:build_runner flutter pub add dev:custom_lint flutter pub add dev:riv…

【EasyPan】MySQL主键与索引核心作用解析

【EasyPan】项目常见问题解答&#xff08;自用&持续更新中…&#xff09;汇总版 MySQL主键与索引核心作用解析 一、主键&#xff08;PRIMARY KEY&#xff09;核心作用 1. 数据唯一标识 -- 创建表时定义主键 CREATE TABLE users (id INT AUTO_INCREMENT PRIMARY KEY,use…

IcePlayer音乐播放器项目分析及学习指南

IcePlayer音乐播放器项目分析及学习指南 项目概述 IcePlayer是一个基于Qt5框架开发的音乐播放器应用程序&#xff0c;使用Visual Studio 2013作为开发环境。该项目实现了音乐播放、歌词显示、专辑图片获取等功能&#xff0c;展现了桌面应用程序开发的核心技术和设计思想。 技…