002 self-attention自注意力

目录

一、环境

二、self-attention原理

三、完整代码


一、环境

本文使用环境为:

  • Windows10
  • Python 3.9.17
  • torch 1.13.1+cu117
  • torchvision 0.14.1+cu117

二、self-attention原理

自注意力(Self-Attention)操作是基于 Transformer 的机器翻译模型的基本操作,在源语言的编
码和目标语言的生成中频繁地被使用以建模源语言、目标语言任意两个单词之间的依赖关系。给
定由单词语义嵌入及其位置编码叠加得到的输入表示 {xi ∈ Rd},为了实现对上下文语义依赖的建模,进一步引入在自注意力机制中涉及到的三个元素:查询 qi(Query),键 ki(Key),值 vi (Value)。在编码输入序列中每一个单词的表示的过程中,这三个元素用于计算上下文单词所对应的权重得分。直观地说,这些权重反映了在编码当前单词的表示时,对于上下文不同部分所需要的关注程度。具体来说,如图所示,通过三个线性变换 WQ,WK ,WV 将输入序列中的每一个单词表示 xi 转换为其对应的 qi,ki ,vi  向量。

为了得到编码单词 xi 时所需要关注的上下文信息,通过位置 i 查询向量与其他位置的键向量做点积得到匹配分数 qi · k1, qi · k2, ..., qi · kt。为了防止过大的匹配分数在后续 Softmax 计算过程中导致的梯度爆炸以及收敛效率差的问题,这些得分会除放缩因子 √d 以稳定优化。放缩后的得分经过 Softmax 归一化为概率之后,与其他位置的值向量相乘来聚合希望关注的上下文信息,并最小化不相关信息的干扰。上述计算过程可以被形式化地表述如下:

其中 Q  , K  ,V  分别表示输入序列中的不同单词的 q, k, v 向量拼接组成的矩阵,L 表示序列长度,Z 表示自注意力操作的输出。为了进一步增强自注意力机制聚合上下文信息的能力,提出了多头自注意力(Multi-head Attention)的机制,以关注上下文的不同侧面。具体来说,上下文中每一个单词的表示 xi 经过多组线性 {WQ*WK*WV } 映射到不同的表示子空间中。公式会在不同的子空间中分别计算并得到不同的上下文相关的单词序列表示{Zj}。最终,线性变换 WO 用于综合不同子空间中的上下文表示并形成自注意力层最终的输出 xi 。

三、完整代码

import torch.nn as nn
import torch
import math
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, heads, d_model, dropout = 0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // heads # 512 / 8 self.h = headsself.q_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)self.out = nn.Linear(d_model, d_model)def attention(self, q, k, v, d_k, mask=None, dropout=None):scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # self-attention公式# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1) # self-attention公式if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v) # self-attention公式return outputdef forward(self, q, k, v, mask=None):bs = q.size(0) # 进行线性操作划分为成 h 个头k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.q_linear(q).view(bs, -1, self.h, self.d_k)v = self.v_linear(v).view(bs, -1, self.h, self.d_k)# 矩阵转置k = k.transpose(1,2) q = q.transpose(1,2) v = v.transpose(1,2) # 计算 attentionscores = self.attention(q, k, v, self.d_k, mask, self.dropout)# 连接多个头并输入到最后的线性层concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output# 准备q、k、v张量
d_model = 512
num_heads = 8
batch_size = 32
seq_len = 64q = torch.randn(batch_size, seq_len, d_model) # 64 x 512
k = torch.randn(batch_size, seq_len, d_model) # 64 x 512
v = torch.randn(batch_size, seq_len, d_model) # 64 x 512sa = MultiHeadAttention(heads = num_heads, d_model=d_model)
print(sa(q, k, v).shape) # torch.Size([32, 64, 512])
print('')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/208826.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华清远见嵌入式学习——QT——作业2

作业要求&#xff1a; 代码运行效果图&#xff1a; 登录失败 和 最小化 和 取消登录 登录成功 和 X号退出 代码&#xff1a; ①&#xff1a;头文件 #ifndef LOGIN_H #define LOGIN_H#include <QMainWindow> #include <QLineEdit> //行编辑器类 #include…

Java Spring + SpringMVC + MyBatis(SSM)期末作业项目

本系统是一个图书管理系统&#xff0c;比较适合当作期末作业主要技术栈如下&#xff1a; - 数据库&#xff1a;MySQL - 开发工具&#xff1a;IDEA - 数据连接池&#xff1a;Druid - Web容器&#xff1a;Apache Tomcat - 项目管理工具&#xff1a;Maven - 版本控制工具&#xf…

探索人工智能领域——每日20个名词详解【day12】

目录 前言 正文 总结 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助。 &#x1f4a1;本文由Filotimo__✍️原创&#xff0c;首发于CSDN&#x1f4da;。 &#x1f4e3;如需转载&#xff0c;请事先与我联系以…

进程、线程、线程池状态

线程几种状态和状态转换 进程主要写明三种基本状态&#xff1a; 线程池的几种状态&#xff1a;

STM32的BKP与RTC简介

芯片的供电引脚 引脚表橙色的是芯片的供电引脚&#xff0c;其中VSS/VDD是芯片内部数字部分的供电&#xff0c;VSSA/VDDA是芯片内部模拟部分的供电&#xff0c;这4组以VDD开头的供电都是系统的主电源&#xff0c;正常使用时&#xff0c;全部都要接3.3V的电源上&#xff0c;VBAT是…

Leetcode2477. 到达首都的最少油耗

Every day a Leetcode 题目来源&#xff1a;2477. 到达首都的最少油耗 解法1&#xff1a;贪心 深度优先搜索 题目等价于给出了一棵以节点 0 为根结点的树&#xff0c;并且初始树上的每一个节点上都有一个人&#xff0c;现在所有人都需要通过「车子」向结点 0 移动。 对于…

从阻抗匹配看拥塞控制

先来理解阻抗匹配&#xff0c;但我不按传统方式解释&#xff0c;因为传统方案你要先理解如何定义阻抗&#xff0c;然后再学习什么是输入阻抗和输出阻抗&#xff0c;最后再看如何让它们匹配&#xff0c;而让它们匹配的目标仅仅是信号不反射&#xff0c;以最大能效被负载接收。 …

Amazon CodeWhisperer 开箱初体验

文章作者&#xff1a;Coder9527 科技的进步日新月异&#xff0c;正当人工智能发展如火如荼的时候&#xff0c;各大厂商在“解放”码农的道路上不断创造出各种 Coding 利器&#xff0c;今天在下就带大家开箱体验一个 Coding 利器&#xff1a; Amazon CodeWhisperer。 亚马逊云科…

99基于matlab的小波分解和小波能量熵函数

基于matlab的小波分解和小波能量熵函数&#xff0c;通过GUI界面导入西储大学轴承故障数据&#xff0c;以可视化的图对结果进行展现。数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。 99小波分解和小波能量熵函数 (xiaohongshu.com)https://www.xiaohongshu.co…

【LeetCode每日一题合集】2023.11.27-2023.12.3 (⭐)

文章目录 907. 子数组的最小值之和&#xff08;单调栈贡献法&#xff09;1670. 设计前中后队列⭐&#xff08;设计数据结构&#xff09;解法1——双向链表解法2——两个双端队列 2336. 无限集中的最小数字解法1——维护最小变量mn 和 哈希表维护已经去掉的数字解法2——维护原本…

二分查找|前缀和|滑动窗口|2302:统计得分小于 K 的子数组数目

作者推荐 贪心算法LeetCode2071:你可以安排的最多任务数目 本文涉及的基础知识点 二分查找算法合集 题目 一个数组的 分数 定义为数组之和 乘以 数组的长度。 比方说&#xff0c;[1, 2, 3, 4, 5] 的分数为 (1 2 3 4 5) * 5 75 。 给你一个正整数数组 nums 和一个整数…

response应用及重定向和request转发

请求和转发&#xff1a; response说明一、response文件下载二、response验证码实现1.前置知识&#xff1a;2.具体实现&#xff1a;3.知识总结 三、response重定向四、request转发五、重定向和转发的区别 response说明 response是指HttpServletResponse,该响应有很多的应用&…

Kafka在微服务架构中的应用:实现高效通信与数据流动

微服务架构的兴起带来了分布式系统的复杂性&#xff0c;而Kafka作为一款强大的分布式消息系统&#xff0c;为微服务之间的通信和数据流动提供了理想的解决方案。本文将深入探讨Kafka在微服务架构中的应用&#xff0c;并通过丰富的示例代码&#xff0c;帮助大家更全面地理解和应…

PaddleClas学习3——使用PPLCNet模型对车辆朝向进行识别(c++)

使用PPLCNet模型对车辆朝向进行识别 1 准备环境2 准备模型2.1 模型导出2.2 修改配置文件3 编译3.1 使用CMake生成项目文件3.2 编译3.3 执行3.4 添加后处理程序3.4.1 postprocess.h3.4.2 postprocess.cpp3.4.3 在cls.h中添加函数声明3.4.4 在cls.cpp中添加函数定义3.4.5 在main.…

时间序列预测 — VMD-LSTM实现单变量多步光伏预测(Tensorflow):单变量转为多变量

目录 1 数据处理 1.1 导入库文件 1.2 导入数据集 1.3 缺失值分析 2 VMD经验模态分解 3 构造训练数据 4 LSTM模型训练 5 预测 1 数据处理 1.1 导入库文件 import time import datetime import pandas as pd import numpy as np import matplotlib.pyplot as plt f…

优化算法 学习记录

文章目录 相关资料 优化算法梯度下降学习率牛顿法 随机梯度下降小批量随机梯度下降动量法动量法解决上述问题 AdaGrad 算法RMSProp算法Adam学习率调度器余弦学习率调度预热 相关资料 李沐 动手学深度学习 优化算法 优化算法使我们能够继续更新模型参数&#xff0c;并使损失函…

Elasticsearch:使用 Elasticsearch 向量搜索及 RAG 来实现 Chatbot

Elasticsearch 的向量搜索为我们的语义搜索提供了可能。而在人工智能的动态格局中&#xff0c;检索增强生成&#xff08;Retrieval Augmented Generation - RAG&#xff09;已经成为游戏规则的改变者&#xff0c;彻底改变了我们生成文本和与文本交互的方式。 RAG 使用大型语言模…

MongoDB的删除文档、查询文档语句

本文主要介绍MongoDB的删除文档、查询文档命令语句。 目录 MongoDB删除文档MongoDB查询文档 MongoDB删除文档 MongoDB是一种基于文档的NoSQL数据库&#xff0c;它使用BSON格式存储文档。删除文档是MongoDB数据库中的常见操作之一。 下面是MongoDB删除文档的详细介绍和示例&am…

导入自定义模块出现红色波浪线,但是能正常执行

问题描述&#xff1a; 导入自己定义的模块时&#xff0c;出现红色波浪线&#xff0c;可以继续执行 解决&#xff1a; 在存放当前执行文件的文件夹右键&#xff0c;然后将其设置为sources root即可 结果&#xff1a;

基于深度学习yolov5实现安全帽人体识别工地安全识别系统-反光衣识别系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 实现安全帽人体识别工地安全识别系统需要使用深度学习技术&#xff0c;特别是YOLOv5算法。下面是对基于YOLOv5实现安…