Transformer——多头注意力机制(Pytorch)

1. 原理图

2. 代码

import torch
import torch.nn as nnclass Multi_Head_Self_Attention(nn.Module):def __init__(self, embed_size, heads):super(Multi_Head_Self_Attention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsself.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)def forward(self,queries, keys, values, mask):N = queries.shape[0]  # batch_sizequery_len = queries.shape[1]  # sequence_lengthkey_len = keys.shape[1]  # sequence_length value_len = values.shape[1]  # sequence_lengthqueries = self.queries(queries)keys = self.keys(keys)values = self.values(values)# Split the embedding into self.heads pieces# batch_size, sequence_length, embed_size(512) --> # batch_size, sequence_length, heads(8), head_dim(64)queries = queries.reshape(N, query_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)values = values.reshape(N, value_len, self.heads, self.head_dim)# batch_size, sequence_length, heads(8), head_dim(64) --> # batch_size, heads(8), sequence_length, head_dim(64)queries = queries.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# Scaled dot-product attentionscore = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** (1/2))if mask is not None:score = score.masked_fill(mask == 0, float("-inf"))# batch_size, heads(8), sequence_length, sequence_lengthattention = torch.softmax(score, dim=-1)out = torch.matmul(attention, values)# batch_size, heads(8), sequence_length, head_dim(64) --># batch_size, sequence_length, heads(8), head_dim(64) --># batch_size, sequence_length, embed_size(512)# 为了方便送入后面的网络out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)out = self.fc_out(out)return outbatch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = NoneQ = torch.randn(batch_size, sequence_length, embed_size)  
K = torch.randn(batch_size, sequence_length, embed_size)  
V = torch.randn(batch_size, sequence_length, embed_size)  model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【VSCode】设置背景图片

1.下载background插件:拓展程序→background→install安装→设置: 2.点击在 settings.json 中编辑: 3.将settings.json文件中所有代码注释,添加以下代码: {// 是否开启背景图显示"background.enabled": t…

【Linux杂货铺】1.环境变量

1.环境变量基本概念 环境变量( environment variables )一般是指在操作系统中用来指定操作系统运行环境的一些参数。如:我们在编写 C / C +代码的时候,在链接的时候,从来不知道我们的所链接的动态静态库在哪…

【Go系列】 Go语言数据结构

承上启下 在上一次的思维碰撞中,我们的小试牛刀是一段温馨的代码小练习——将“Hello World”这个熟悉的问候,替换成了我们自己的名字。是的,你没听错,就是这么简单!以我为例,我将“Hello World”轻轻一变&…

算法训练 | 图论Part8 | 拓扑排序、dijkstra(朴素版)

目录 117. 软件构建 拓扑排序法 47. 参加科学大会 dijkstra法 117. 软件构建 题目链接&#xff1a;117. 软件构建 文章讲解&#xff1a;代码随想录 拓扑排序法 代码一&#xff1a;拓扑排序 #include <iostream> #include <vector> #include <queue> …

什么是Web3D?国内有哪些公司可以做?

Web3D 是一种基于网页的三维立体虚拟现实技术。利用计算机图形学、图像处理、人机交互等技术&#xff0c;将现实世界中的物体、场景或概念以三维立体的方式呈现在网页里。Web3D 技术可以让用户在任何时间、任何地点&#xff0c;通过互联网与虚拟世界进行互动&#xff0c;获得身…

模型剪枝介绍

Ref&#xff1a;https://www.cnblogs.com/the-art-of-ai/p/17500399.html 1、背景介绍 深度学习模型在图像识别、自然语言处理、语音识别等领域取得了显著的成果&#xff0c;但是这些模型往往需要大量的计算资源和存储空间。尤其是在移动设备和嵌入式系统等资源受限的环境下&a…

昇思25天学习打卡营第1天|初步了解

1在昇思平台上申请过相关资源之后&#xff0c;将示例代码粘贴到输入框内。可以在下图中创建一个新的文档。 2不过初次运行的时候会遇到一个问题&#xff0c;点击运行的时候会出现新的输入框&#xff0c;而不是直接运行。遇到此问题等待就可以了&#xff0c;或者稍微等一下再运…

【JVM】对象的生命周期一 | 对象的创建与存储

Java | 对象的生命周期1-对象的创建与存储 文章目录 前言对象的创建过程内存空间的分配方式方式1 | 指针碰撞方式2 | 空闲列表 线程安全问题 | 避免空间冲突的方式方式1 | 同步处理&#xff08;加锁)方式2 | 本地线程分配缓存 对象的内存布局Part1 | 对象头Mark Word类型指针 P…

内网安全:域内信息探测

1.域内基本信息收集 2.NET命令详解 3.内网主要使用的域收集方法 4.查找域控制器的方法 5.查询域内用户的基本信息 6.定位域管 7.powershell命令和定位敏感信息 1.域内基本信息收集&#xff1a; 四种情况&#xff1a; 1.本地用户&#xff1a;user 2.本地管理员用户&#x…

solidity实战练习1

//SPDX-License-Identifier:MIT pragma solidity ^0.8.24; contract PiggyBank{constructor()payable{emit Deposit(msg.value);//触发事件1//意味着在部署合约的时候&#xff0c;可以向合约发送以太币&#xff08;不是通过调用函数&#xff0c;而是直接在部署合约时发送&#…

C++ STL for_each的用法和实现

目录 一&#xff1a;功能 二&#xff1a;用法 三&#xff1a;实现 一&#xff1a;功能 遍历元素 二&#xff1a;用法 //C 11 #include <vector> #include <algorithm> #include <iostream> #include <format>struct StatsFn {int cnt 0;int sum…

外泌体相关基因肝癌临床模型预测——2-3分纯生信文章复现——4.预后相关外泌体基因确定之生存曲线(4)

内容如下: 1.外泌体和肝癌TCGA数据下载 2.数据格式整理 3.差异表达基因筛选 4.预后相关外泌体基因确定 5.拷贝数变异及突变图谱 6.外泌体基因功能注释 7.LASSO回归筛选外泌体预后模型 8.预后模型验证 9.预后模型鲁棒性分析 10.独立预后因素分析及与临床的相关性分析…

【算法】二叉树算法基本概念及实现

目录 一、二叉树的基本概念 二、二叉树的性质 三、二叉树的算法实现 四、二叉树的应用 C# 实现 Python 实现 二叉树算法是计算机科学中常用的一种数据结构算法,主要用于处理具有层级关系的数据。以下是对二叉树算法的详细介绍: 一、二叉树的基本概念 定义:二叉树是n…

【机器翻译】基于术语词典干预的机器翻译挑战赛

赛题链接&#xff1a;https://challenge.xfyun.cn/topic/info?typemachine-translation-2024 赛题解读 安装库 spacy 1.查看本地spacy版本 pip show spacy我安装3.6.0 pip install en_core_web_sm-3.6.0.tar.gzen_core_web_sm下载链接&#xff1a;https://github.com/ex…

[Linux]对Linux中的命令的本质

上回我们讲了Linux的指令&#xff0c;本篇是一个短篇&#xff0c;主要是对命令本质的讲解。 我们知道命令一般都是直接使用的 而可执行程序需要加上当前的路径 &#xff08;这个mytest是我们上上回写的&#xff0c;作用实际是打印Hello world!&#xff09; 我们很直观的可以发…

git为文件添加可执行权限

查看文件权限 git ls-files --stage .\SecretFinder.py100644 表示文件的所有者有读取和写入权限 添加可执行权限 git update-index --chmod x .\SecretFinder.py再次查看文件权限 git ls-files --stage .\SecretFinder.py100755 表示文件的所有者有读取、写入和执行权限

git查看版本,查看安装路径、更新版本

一、查看安装路径 where git查看安装路径 二、更新版本 git update-git-for-windows 更新版本 三、查看版本 git version 查看版本

【鸿蒙学习笔记】文件管理

官方文档&#xff1a;Core File Kit简介 目录标题 文件分类什么是应用沙箱&#xff1f; 文件分类 应用文件&#xff0c;比如应用的安装包&#xff0c;自己的资源文件等。用户文件&#xff0c;比如用户自己的照片&#xff0c;录制的音视频等。 什么是应用沙箱&#xff1f; 应…

linux boost 例子 加 编译

在Linux下&#xff0c;使用Boost Asio库可以轻松实现非阻塞的网络API调用。以下是一个简单的例子&#xff0c;展示了如何设置一个非阻塞的TCP socket&#xff1a; #include <boost/asio.hpp> #include <iostream>int main() {// 创建IO服务对象&#xff0c;它负责…

第二周周三总结

题目总结 1.给你一个二进制数组 nums 。 你可以对数组执行以下操作 任意 次&#xff08;也可以 0 次&#xff09;&#xff1a; 选择数组中 任意连续 3 个元素&#xff0c;并将它们 全部反转 。 反转 一个元素指的是将它的值从 0 变 1 &#xff0c;或者从 1 变 0 。 请你…