逐行解析多头注意力机制

多头注意力机制是NLP算法岗常考的代码题,本篇文章将逐行梳理多头注意力机制的代码。

全部代码

import math
import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, d_model, nums_head):super(MultiHeadAttention, self).__init__()self.d_model = d_modelself.nums_head = nums_headassert self.d_model % self.nums_head == 0self.split = d_model // nums_headself.q_linear = nn.linear(self.d_model, self.d_model)self.k_linear = nn.linear(self.d_model, self.d_model)self.v_linear = nn.linear(self.d_model, self.d_model)def split_head(self, x):batch_size, seq_length, hidden_size = x.size()return x.view(batch_size, seq_length, self.nums_head, self.split).transpose(1, 2)def forward(self, q, k, v, mask=None):query = self.q_linear(q)key = self.k_linear(k)value = self.v_linear(v)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.split)if mask is not None:score = score + score.masked_fill(mask == 0, -float('inf'))weight = torch.softmax(score, dim=-1)att_output = torch.matmul(weight, value)batch_size, _, seq_length, hidden_size = att_output.size()att_output = att_output.transpose(1, 2).view(batch_size, seq_length, self.d_model)return att_output

细节解析

设置断言,保证原始维度可以整除注意力头数。

assert self.d_model % self.nums_head == 0

将x的维度进行转化,并且将seq_length和nums_head的维度交换(允许对每个序列位置和每个注意力头独立地处理数据。这样,模型就可以并行地关注输入序列的不同部分,从而提高处理复杂输入序列的能力)。

def split_head(self, x):batch_size, seq_length, hidden_size = x.size()return x.view(batch_size, seq_length, self.nums_head, self.split).transpose(1, 2)

mask == 0:这部分代码会生成一个与mask形状相同的布尔张量(Boolean Tensor),其中mask中值为0的位置在布尔张量中对应为True,其他位置为False。
scores.masked_fill(mask == 0, -1e9):masked_fill函数接受一个布尔张量和一个标量值作为输入。它会遍历布尔张量,并将原始张量(这里是scores)中对应为True的位置填充为指定的标量值(这里是-1e9)。因此,如果mask中的某个位置是0,那么scores中对应位置的值将被替换为-float(‘inf’)。
scores += …:最后,这个替换后的张量(其中mask为0的位置被替换为-float(‘inf’))被加回到原始的scores张量上。然而,由于masked_fill实际上返回了一个新张量(即原始张量的一个副本,但其中某些位置被修改了),这里的加法操作实际上是将原始scores张量中的对应位置更新为-1e9加上它们原来的值。但是,由于-float(‘inf’)是一个非常大的负数,并且softmax函数在计算时会将输入值转换为概率分布,所以任何接近-float(‘inf’)的值在softmax中都会接近于0。因此,这个操作实际上是在告诉模型在计算注意力权重时忽略mask为0的位置。

if mask is not None:score = score + score.masked_fill(mask == 0, -float('inf'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/52471.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT 自定义组件 界面跳转

一、引用组件需要的类(头文件) 1、按钮类 QPushButton: 普通按钮; QToolButton: 工具按钮; QRadioButton: 单选按钮; QCheckBox: 复选按钮; QCommandLinkButton: 命令连接按钮; 2、布局类 QHBoxLayout水平 QVBoxLayout垂直 QGridLayout网格 QFormLayout…

存储芯片行业的封装类型

存储芯片行业的封装类型 存储芯片分类: 随机存储器(RAM):这是易失性存储器,断电后存储的数据会丢失。它包括: 动态随机存储器(DRAM):这是最常见的系统内存类型&#xf…

智能头盔语音识别声控芯片,AI离线语音识别ic方案,NRK3301

头盔是交通事故中保护电动车车主安全的最后一道屏障。为了增加骑行用户的安全保护,改善骑行用户的出行体验,让用户从被动使用头盔到主动佩戴头盔,头盔厂家与九芯电子合作,推出了语音智能头盔,它具备首家骑行专用的智能…

【网络安全】-xss跨站脚本攻击实战-xss-labs(1~10)

Level1: 检查页面源代码: function函数: (function(){try{let tn ;if(tn.includes(oem)){Object.defineProperty(document, referrer, {get: function(){return ;}});}else if(tn.includes(hao_pg)){if(!document.referrer.match(tn)){Object.definePro…

【python】python 安装和 pycharm 安装

1 python 安装 1.1 下载 下载地址:python 官网 1.2 安装 windows 安装为例。 双击.exe文件打开 安装界面 安装完成 1.3 检查安装是否成功 win/start 键r 键 运行窗口输入 cmd 回车 3 输入 python查看 显示版本信息,表示已经安装成功。 …

协议头,wireshark,http

目录 协议头 ip头 udp头 mac层 网络工具 telnet wireshark Http 一、HTTP 协议介绍 二、HTTP 协议的工作过程 三、使用抓包工具抓取报文 四、获取到http请求报文: 五、http请求(request) (一)、认识URL 项…

如果 Android 手机出现数据丢失,如何在Android上恢复丢失的数据

当您的 Android 手机发生数据丢失时,您可能需要检索丢失的文件。为了帮助您完成此过程,以下是执行 Android 数据恢复的一些有效方法: 如何在Android上检索数据 如果您的 Android 手机出现数据丢失,您可能需要检索丢失的文件。为了…

OpenWRT有三个地方设置DNS,究竟设置哪个地方会更好?

前言 刚上手OpenWRT软路由系统的小伙伴或许都会有这样的疑问:OpenWRT这个系统有三个地方是设置DNS的,究竟设置哪一个才是正确的? 这个还得从实际应用说起。 一般来说,咱们在使用路由器的时候,DNS都是默认运营商的DN…

前端框架大观:探索现代Web开发的基石

目录 引言 一、前端框架概述 二、主流前端框架介绍 2.1 React 2.1.1 简介 2.1.2 特点 2.1.3 代码示例 2.2 Vue.js 2.2.1 简介 2.2.2 特点 2.2.3 代码示例 2.3 Angular 2.3.1 简介 2.3.2 特点 2.3.3 代码示例 三、其他前端框架与库 四、前端框架的选择 五、结…

计算机毕业设计选题推荐-自驾游攻略管理系统-Java/Python项目实战

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

CentOs7 解决yum更新源报错:[Errno 14] HTTP Error 404 - Not Found 正在尝试其它镜像。

CentOs7 解决yum更新源报错:[Errno 14] HTTP Error 404 - Not Found 正在尝试其它镜像。 前言问题解决方法: 前言 遇到这个问题大概率是镜像源的问题可以参照这篇文章的内容试一下 镜像源问题相关解决方法 根据自己的情况对症下药,如果还不…

LAMP环境下项目部署

目录 1、创建一台虚拟机 centos 源的配置 备份源 修改源 重新加载缓存 安装软件 2、关闭防火墙和selinux 查看防火墙状态 关闭防火墙 查看SELinux的状态 临时关闭防火墙 永久关闭SELinux:编辑SELinux的配置文件 配置文件的修改内容 3、检查系统中是否…

计算机毕业设计 家校互联管理系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

单片机-STM32 看门狗(八)

目录 一、看门狗概念 1、定义: 二、单片机中的看门狗 1、功能描述: 2、看门狗设置部分 预分频寄存器(IWDG_PR) 3、窗口看门狗 特性: 4、看门狗配置: 一、看门狗概念 看门狗--定时器(不属于基本定时器、通用定…

svg图标的使用

图片的格式有很多,前端经常使用的有以下类型:jpg,jpeg,png,gif,svg,这篇文章将简单svg的情况,以及项目中如何使用和配置svg图标 目录 什么是svg图标 SVG图标的优缺点 优点 缺点 svg前端使用场景 SVG在代码中的使用 简单使用创建svg 作为图标引入…

udp网络通信 socket

套接字是实现进程间通信的编程。IP可以标定主机在全网的唯一性,端口可以标定进程在主机的唯一性,那么socket通过IP端口号就可以让两个在全网唯一标定的进程进行通信。 套接字有三种: 域间套接字:实现主机内部的进程通信的编程 …

yolov5 +gui界面+单目测距 实现对图片视频摄像头的测距

可实现对图片,视频,摄像头的检测 项目概述 本项目旨在实现一个集成了YOLOv5目标检测算法、图形用户界面(GUI)以及单目测距功能的系统。该系统能够对图片、视频或实时摄像头输入进行目标检测,并估算目标的距离。通过…

Linux shell编程学习笔记78:cpio命令——文件和目录归档工具

0 前言 在Linux系统中,除了tar命令,我们还可以使用cpio命令来进行文件和目录的归档。 1 cpio命令的功能,帮助信息,格式,选项和参数说明 1.1 cpio命令的功能 cpio 名字来自 "copy in, copy out"&#xf…

具有RC反馈电路的正弦波振荡器(文氏桥振荡器+相移振荡器+双T振荡器)

2024-9-10,星期二,22:13,天气:雨,心情:晴。今天从下午开始淅淅沥沥一直在下雨,还好我有先见之明没骑自行车,但是我忘带伞了,属于说是有点脑子但是不多了,2333…

如何注册谷歌账号(“此电话号码无法验证”问题)

如何注册谷歌账号(“此电话号码无法验证”问题) 以下注册账号的步骤于 2024.9.10 20:00 成功实施。 文章目录 如何注册谷歌账号(“此电话号码无法验证”问题)1)打开谷歌浏览器2)设置浏览器语言【英语&…