python-pytorch编写transformer模型实现翻译0.5.00-写模型

前言

在网上看了一篇文章,借用了文章的大部分代码,并对代码的预测进行修改使得可以正确的预测了,具体链接找了半天找不到

代码

import numpy as np # 导入 numpy 库
import torch # 导入 torch 库
import torch.nn as nn # 导入 torch.nn 库
d_k = 64 # K(=Q) 维度
d_v = 64 # V 维度
# 定义缩放点积注意力类
class ScaledDotProductAttention(nn.Module):def __init__(self):super(ScaledDotProductAttention, self).__init__()        def forward(self, Q, K, V, attn_mask):#------------------------- 维度信息 --------------------------------        # Q K V [batch_size, n_heads, len_q/k/v, dim_q=k/v] (dim_q=dim_k)# attn_mask [batch_size, n_heads, len_q, len_k]#----------------------------------------------------------------# 计算注意力分数(原始权重)[batch_size,n_heads,len_q,len_k]scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) #------------------------- 维度信息 --------------------------------        # scores [batch_size, n_heads, len_q, len_k]#-----------------------------------------------------------------        # 使用注意力掩码,将 attn_mask 中值为 1 的位置的权重替换为极小值#------------------------- 维度信息 -------------------------------- # attn_mask [batch_size, n_heads, len_q, len_k], 形状和 scores 相同#-----------------------------------------------------------------    scores.masked_fill_(attn_mask, -1e9) # 对注意力分数进行 softmax 归一化weights = nn.Softmax(dim=-1)(scores) #------------------------- 维度信息 -------------------------------- # weights [batch_size, n_heads, len_q, len_k], 形状和 scores 相同#-----------------------------------------------------------------         # 计算上下文向量(也就是注意力的输出), 是上下文信息的紧凑表示context = torch.matmul(weights, V) #------------------------- 维度信息 -------------------------------- # context [batch_size, n_heads, len_q, dim_v]#-----------------------------------------------------------------    return context, weights # 返回上下文向量和注意力分数# 定义多头自注意力类
d_embedding = 512  # Embedding 的维度
n_heads = 8  # Multi-Head Attention 中头的个数
batch_size = 6 # 每一批的数据大小
class MultiHeadAttention(nn.Module):def __init__(self):super(MultiHeadAttention, self).__init__()self.W_Q = nn.Linear(d_embedding, d_k * n_heads) # Q的线性变换层self.W_K = nn.Linear(d_embedding, d_k * n_heads) # K的线性变换层self.W_V = nn.Linear(d_embedding, d_v * n_heads) # V的线性变换层self.linear = nn.Linear(n_heads * d_v, d_embedding)self.layer_norm = nn.LayerNorm(d_embedding)def forward(self, Q, K, V, attn_mask): #------------------------- 维度信息 -------------------------------- # Q K V [batch_size, len_q/k/v, embedding_dim] #-----------------------------------------------------------------        residual, batch_size = Q, Q.size(0) # 保留残差连接# 将输入进行线性变换和重塑,以便后续处理q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)#------------------------- 维度信息 -------------------------------- # q_s k_s v_s: [batch_size, n_heads, len_q/k/v, d_q=k/v]#----------------------------------------------------------------- # 将注意力掩码复制到多头 attn_mask: [batch_size, n_heads, len_q, len_k]attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1<

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/19760.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vscode怎么点击路径直接跳转对应文件

在vue项目中经常要引入工具类、组件、模版等&#xff0c;想要直接去看对应文件&#xff0c;只能自己找到对应路径再去打开。 我们可用在js项目中创建一个 jsconfig.json文件&#xff0c;TS项目可以创建tsconfig.json 文件代码 {"compilerOptions": {"baseUrl&…

52-QSplitter类QDockWidget类

一 QSplitter类 Qt提供QSplitter(QSplitter)类来进行分裂布局&#xff0c;QSplitter派生于QFrame。 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow>class MainWindow : public QMainWindow {Q_OBJECTpublic:MainWindow(QWidget *parent nullptr);~…

linux /www/server/cron内log文件占用空间过大,/www/server/cron是什么内容,/www/server/cron是否可以删除

linux服务器长期使用宝塔自带计划任务&#xff0c;计划任务执行记录占用服务器空间过大&#xff0c;导致服务器根目录爆满&#xff0c;需要长期排查并删除 /www/server/cron 占用空间过大问题处理 /www/server/cron是什么内容&#xff1f;/www/server/cron是否可以删除&#xf…

vue2 bug 小白求助!!!(未解决,大概是浏览器缓存的问题或者是路由的问题)

我的vue2项目出现了一个超级恶心的bug 具体流程&#xff1a; 页面a点击a标签->到页面b->页面b用户退出刷新页面->点击浏览器的返回按钮返回上一页 返回页面后页面没有刷新导致用户名还显示这 项目中没有用keep-alive缓存 也在设置了key 尝试了window.removeEventLi…

GDI双缓冲技术绘图

C#双缓冲绘图技术&#xff1a;提升图形性能与用户体验 目录&#xff1a; 引言C#中的绘图技术 单缓冲绘图双缓冲绘图 双缓冲绘图的优势实现双缓冲绘图的步骤示例代码 创建双缓冲窗体在双缓冲窗体上绘制图形 总结 正文&#xff1a; 引言 在计算机图形编程中&#xff0c;绘图技…

vue UI组件整理

Vue2Vue3Element - The worlds most popular Vue UI frameworkOverview 组件总览 | Element Plushttps://v2.iviewui.com/docs/guide/installhttps://www.iviewui.com/view-ui-plus/guide/introduce按钮 Button - Ant Design按钮 Button - Ant DesignVuetify — A Material Des…

考试题库:华为HCIA-Datacom易错题⑪(含答案解析)

华为认证HCIA-Datacom易错题举例和答案分析。 需要更多题库资料&#xff0c;可以在文末领取 1、运行STP协议的设备端口处于Forwarding状态&#xff0c;下列说法正确的有? A.该端口端口既转发用户流量也处理BPDU报文 B.该端口会根据收到的用户流量构建MAC地址表&#xf…

算法每日一题(python,2024.05.24) day.6

题目来源&#xff08;力扣. - 力扣&#xff08;LeetCode&#xff09;&#xff0c;简单&#xff09; 解题思路&#xff1a; 排序&#xff0b;双指针 先将两个数组进行排序&#xff0c;cursor1和cursor分别指向两个数组的首位&#xff0c;比较两个指针所指的值的大小&#xff0…

swiftUI使用VideoPlayer和AVPlayer播放视频

使用VideoPlayer包播放视频&#xff1a;https://github.com/wxxsw/VideoPlayer 提供一些可供测试的视频链接&#xff0c;不保证稳定可用哦&#xff1a; https://vfx.mtime.cn/Video/2019/06/15/mp4/190615103827358781.mp4https://clips.vorwaerts-gmbh.de/big_buck_bunny.mp…

Flutter 中的 SliverFillRemaining 小部件:全面指南

Flutter 中的 SliverFillRemaining 小部件&#xff1a;全面指南 Flutter 是一个由 Google 开发的跨平台 UI 框架&#xff0c;它允许开发者使用 Dart 语言来构建高性能、美观的移动、Web 和桌面应用。在 Flutter 的丰富组件库中&#xff0c;SliverFillRemaining 是一个用于 Cus…

B端UI设计,演绎高情逸态之妙

B端UI设计&#xff0c;演绎高情逸态之妙

Unity 实现让物体渲染在最前面

演示 实现方案 1.创建一个shader脚本 2.删掉原来的内容&#xff1a;我们自己写 附上完整的shader代码&#xff1a; Shader "Custom/ZTestAlways" {Properties {_Color ("Color Tint",Color) (1,1,1,1)_MainTex("Main Tex",2D) "white&q…

神经网络与深度学习——第3章 线性模型

本文讨论的内容参考自《神经网络与深度学习》https://nndl.github.io/ 第3章 线性模型 线性模型 线性模型&#xff08;Linear Model&#xff09;是机器学习中应用最广泛的模型&#xff0c;指通过样本特征的线性组合来进行预测的模型&#xff0c;给定一个 D D D维样本 x [ x …

Java—认识异常

1. 异常的概念与体系结构 1.1 异常的概念 在生活中&#xff0c;一个人表情痛苦&#xff0c;出于关心&#xff0c;可能会问&#xff1a;你是不是生病了&#xff0c;需要我陪你去看医生吗&#xff1f; 在程序中也是一样&#xff0c;程序猿是一帮办事严谨、追求完美的高科技人才…

SpringBoot集成Quartz

一、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-quartz</artifactId> </dependency>二、示例任务 import lombok.extern.slf4j.Slf4j; import org.quartz.JobExecutionContext;…

【风控】可解释机器学习之InterpretML

【风控】可解释机器学习之InterpretML 在金融风控领域&#xff0c;机器学习模型因其强大的预测能力而备受青睐。然而&#xff0c;随着模型复杂性的增加&#xff0c;模型的可解释性逐渐成为一个挑战。监管要求、业务逻辑的透明度以及对模型决策的信任度&#xff0c;都迫切需要我…

getway整合sentinel流控降级

3. 启动sentinel控制台增加流控规则&#xff1a; 根据API分组进行流控&#xff1a; 1.设置API分组&#xff1a; 2.根据API分组进行流控&#xff1a; 自定义统一异常处理&#xff1a; nginx负载配置&#xff1a;

4.nginx反向代理、负载均衡

nginx反向代理、负载均衡 一、反向代理1、语法2、注意事项3、后端服务器记录客户端真实IP3.1 在nginx反向代理时添加x-real-ip字段3.2 后端httpd修改combined日志格式3.3 后端是nginx的情况 二、负载均衡 upstream模块1、负载均衡作用2、调度算法3、配置应用 一、反向代理 隐藏…

FinRobot:一个由大型语言模型(LLM)支持的新型开源AI Agent平台,支持多个金融专业AI Agent

财务分析一直是解读市场趋势、预测经济结果和提供投资策略的关键。这一领域传统上依赖数据&#xff0c;但随着时间的推移&#xff0c;越来越多地使用人工智能&#xff08;AI&#xff09;和算法方法来处理日益增长的复杂数据。AI在金融领域的作用显著增强&#xff0c;它自动化了…

Netty中半包粘包的产生与处理:短连接、固定长度、固定分隔符、预设长度;redis、http协议举例;网络数据的发送和接收过程

目录 粘包、半包 相关概念 网络数据发送和接收过程 Netty半包粘包解决方案 ByteBuf获取和默认大小 短链接 固定长度 固定分隔符 预设长度 常见协议代码举例 redis协议 http协议 参考链接 粘包、半包 相关概念 程序处理过程中我们会通过缓冲区接收数据&#xff0c…