Transformer——多头注意力机制(Pytorch)

1. 原理图

2. 代码

import torch
import torch.nn as nnclass Multi_Head_Self_Attention(nn.Module):def __init__(self, embed_size, heads):super(Multi_Head_Self_Attention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsself.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)def forward(self,queries, keys, values, mask):N = queries.shape[0]  # batch_sizequery_len = queries.shape[1]  # sequence_lengthkey_len = keys.shape[1]  # sequence_length value_len = values.shape[1]  # sequence_lengthqueries = self.queries(queries)keys = self.keys(keys)values = self.values(values)# Split the embedding into self.heads pieces# batch_size, sequence_length, embed_size(512) --> # batch_size, sequence_length, heads(8), head_dim(64)queries = queries.reshape(N, query_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)values = values.reshape(N, value_len, self.heads, self.head_dim)# batch_size, sequence_length, heads(8), head_dim(64) --> # batch_size, heads(8), sequence_length, head_dim(64)queries = queries.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# Scaled dot-product attentionscore = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** (1/2))if mask is not None:score = score.masked_fill(mask == 0, float("-inf"))# batch_size, heads(8), sequence_length, sequence_lengthattention = torch.softmax(score, dim=-1)out = torch.matmul(attention, values)# batch_size, heads(8), sequence_length, head_dim(64) --># batch_size, sequence_length, heads(8), head_dim(64) --># batch_size, sequence_length, embed_size(512)# 为了方便送入后面的网络out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)out = self.fc_out(out)return outbatch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = NoneQ = torch.randn(batch_size, sequence_length, embed_size)  
K = torch.randn(batch_size, sequence_length, embed_size)  
V = torch.randn(batch_size, sequence_length, embed_size)  model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【VSCode】设置背景图片

1.下载background插件:拓展程序→background→install安装→设置: 2.点击在 settings.json 中编辑: 3.将settings.json文件中所有代码注释,添加以下代码: {// 是否开启背景图显示"background.enabled": t…

【Linux杂货铺】1.环境变量

1.环境变量基本概念 环境变量( environment variables )一般是指在操作系统中用来指定操作系统运行环境的一些参数。如:我们在编写 C / C +代码的时候,在链接的时候,从来不知道我们的所链接的动态静态库在哪…

什么是Web3D?国内有哪些公司可以做?

Web3D 是一种基于网页的三维立体虚拟现实技术。利用计算机图形学、图像处理、人机交互等技术,将现实世界中的物体、场景或概念以三维立体的方式呈现在网页里。Web3D 技术可以让用户在任何时间、任何地点,通过互联网与虚拟世界进行互动,获得身…

昇思25天学习打卡营第1天|初步了解

1在昇思平台上申请过相关资源之后,将示例代码粘贴到输入框内。可以在下图中创建一个新的文档。 2不过初次运行的时候会遇到一个问题,点击运行的时候会出现新的输入框,而不是直接运行。遇到此问题等待就可以了,或者稍微等一下再运…

【JVM】对象的生命周期一 | 对象的创建与存储

Java | 对象的生命周期1-对象的创建与存储 文章目录 前言对象的创建过程内存空间的分配方式方式1 | 指针碰撞方式2 | 空闲列表 线程安全问题 | 避免空间冲突的方式方式1 | 同步处理(加锁)方式2 | 本地线程分配缓存 对象的内存布局Part1 | 对象头Mark Word类型指针 P…

内网安全:域内信息探测

1.域内基本信息收集 2.NET命令详解 3.内网主要使用的域收集方法 4.查找域控制器的方法 5.查询域内用户的基本信息 6.定位域管 7.powershell命令和定位敏感信息 1.域内基本信息收集: 四种情况: 1.本地用户:user 2.本地管理员用户&#x…

solidity实战练习1

//SPDX-License-Identifier:MIT pragma solidity ^0.8.24; contract PiggyBank{constructor()payable{emit Deposit(msg.value);//触发事件1//意味着在部署合约的时候,可以向合约发送以太币(不是通过调用函数,而是直接在部署合约时发送&#…

外泌体相关基因肝癌临床模型预测——2-3分纯生信文章复现——4.预后相关外泌体基因确定之生存曲线(4)

内容如下: 1.外泌体和肝癌TCGA数据下载 2.数据格式整理 3.差异表达基因筛选 4.预后相关外泌体基因确定 5.拷贝数变异及突变图谱 6.外泌体基因功能注释 7.LASSO回归筛选外泌体预后模型 8.预后模型验证 9.预后模型鲁棒性分析 10.独立预后因素分析及与临床的相关性分析…

[Linux]对Linux中的命令的本质

上回我们讲了Linux的指令,本篇是一个短篇,主要是对命令本质的讲解。 我们知道命令一般都是直接使用的 而可执行程序需要加上当前的路径 (这个mytest是我们上上回写的,作用实际是打印Hello world!) 我们很直观的可以发…

git为文件添加可执行权限

查看文件权限 git ls-files --stage .\SecretFinder.py100644 表示文件的所有者有读取和写入权限 添加可执行权限 git update-index --chmod x .\SecretFinder.py再次查看文件权限 git ls-files --stage .\SecretFinder.py100755 表示文件的所有者有读取、写入和执行权限

git查看版本,查看安装路径、更新版本

一、查看安装路径 where git查看安装路径 二、更新版本 git update-git-for-windows 更新版本 三、查看版本 git version 查看版本

【鸿蒙学习笔记】文件管理

官方文档:Core File Kit简介 目录标题 文件分类什么是应用沙箱? 文件分类 应用文件,比如应用的安装包,自己的资源文件等。用户文件,比如用户自己的照片,录制的音视频等。 什么是应用沙箱? 应…

maven高级1——一个项目拆成多个

把原来一个项目,拆成多个项目。 !!他们之间,靠接口通信。 以ssm整合好的项目为例: 如何看拆的ok不ok 只要compile通过就ok。 拆分pojo 先新建一个项目模块,再把内容复制进去。 拆分dao 1.和上面一样…

NI VST 毫米波测试仪器创新

目录 概览​从UHF至V频段的频率覆盖范围:54 GHz远程测量模块​PXIe-5842:VST架构的扩展54 GHz扩频PXIe-5842功能​​宽频覆盖范围​IF和毫米波测试端口可满足多频带需求​高达2 GHz瞬时带宽误差矢量幅度测量性能相位相干同步基于PXI平台集成多种仪器 互补…

maven6——生命周期与插件

生命周期 生命周期:指运行的阶段(比如几岁) maven有三个生命周期如下,每个生命周期大概做的事情如下: 注意:每次执行某个,他会把上面的都执行一遍 插件: 每一个插件&#xf…

【Git基本操作】创建本地仓库 | 配置本地仓库 | 认识工作区、暂存区、版本库、对象库 | add和commit操作

目录 1.创建Git本地仓库 1.1创建仓库 1.2创建和初始化Git本地仓库 1.3查看隐藏目录.git 2.配置本地仓库 2.1新增配置 2.2删除重置配置 2.3查看配置选项 2.4全局范围的新增和删除配置 3.工作区、暂存区、版本库、对象库 ​4.add操作和commit操作 4.1add操作 4.2com…

labelme 标注检查经验

1. python labelImg.py D:\BaiduNetdiskDownload\yoloDt_qiuyi_num\yoloDt_qiuye_num\train\images D:\BaiduNetdiskDownload\yoloDt_qiuyi_num\yoloDt_qiuye_num\train\labels\classes.txt 2. 目录另存为会找到classes.txt的类,然后标注起来。

idm站点抓取可以用来做什么 idm站点抓取能抓取本地网页吗 idm站点抓取怎么用 网络下载加速器

在下载工具众多且竞争激烈的市场中,Internet Download Manager(简称IDM)作为一款专业的下载加速软件,仍然能够赢得众多用户的青睐,这都要得益于它的强大的下载功能。我们在开始使用IDM的时候总是有很多疑问&#xff0c…

链接服务器“XX”的OLEDB访问接口“MSOLEDBSQL”返回了消息“登录超时已过期” 解决方法

目录 1. 问题所示2. 原理分析3. 解决方法1. 问题所示 出现如下问题: 与链接服务器的测试连接失败。执行Transact-SQL 语句或批处理时发生了异常。命名管道提供程序:无法打开与SQL SERVER的链接[53]链接服务器“XX”的OLEDB访问接口“MSOLEDBSQL”返回了消息“登录超时已过期…

IntelliJ IDEA中刷新Git分支数据:操作指南与命令详解

前言 在软件开发过程中,频繁地与Git仓库交互是常态,确保本地分支信息与远程仓库保持同步至关重要。IntelliJ IDEA作为一款强大的集成开发环境,提供了直观的图形界面和终端命令行两种方式来帮助开发者高效地管理Git分支。本文将详细介绍如何在…