大模型基础——从零实现一个Transformer(5)

大模型基础——从零实现一个Transformer(1)-CSDN博客

大模型基础——从零实现一个Transformer(2)-CSDN博客

大模型基础——从零实现一个Transformer(3)-CSDN博客

大模型基础——从零实现一个Transformer(4)-CSDN博客


一、前言

上一篇文章已经把Encoder模块和Decoder模块都已经实现了,

接下来来实现完整的Transformer

二、Transformer

Transformer整体架构如上,直接把我们实现的Encoder 和Decoder模块引入,开始堆叠

import torch
from torch import  nn,Tensor
from torch.nn import  Embedding#引入自己实现的模块
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_maskclass Transformer(nn.Module):def __init__(self,source_vocab_size:int,target_vocab_size:int,d_model: int = 512,n_heads: int = 8,num_encoder_layers: int = 6,num_decoder_layers: int = 6,d_ff: int = 2048,dropout: float = 0.1,max_positions:int = 5000,pad_idx: int = 0,norm_first: bool=False) -> None:''':param source_vocab_size: size of the source vocabulary.:param target_vocab_size: size of the target vocabulary.:param d_model: dimension of embeddings. Defaults to 512.:param n_heads: number of heads. Defaults to 8.:param num_encoder_layers: number of encoder blocks. Defaults to 6.:param num_decoder_layers: number of decoder blocks. Defaults to 6.:param d_ff: dimension of inner feed-forward network. Defaults to 2048.:param dropout: dropout ratio. Defaults to 0.1.:param max_positions: maximum sequence length for positional encoding. Defaults to 5000.:param pad_idx: pad index. Defaults to 0.:param norm_first: if True, layer norm is done prior to attention and feedforward operations(Pre-Norm).Otherwise it's done after(Post-Norm). Default to False.'''super().__init__()# Token embeddingself.src_embeddings = Embedding(source_vocab_size,d_model)self.target_embeddings = Embedding(target_vocab_size,d_model)# Position embeddingself.encoder_pos = PositionalEmbedding(d_model,dropout,max_positions)self.decoder_pos = PositionalEmbedding(d_model,dropout,max_positions)# 编码层定义self.encoder = Encoder(d_model,num_encoder_layers,n_heads,d_ff,dropout,norm_first)# 解码层定义self.decoder = Decoder(d_model,num_decoder_layers,n_heads,d_ff,dropout,norm_first)self.pad_idx = pad_idxdef encode(self,src:Tensor,src_mask: Tensor=None,keep_attentions: bool=False) -> Tensor:'''编码过程:param src: (batch_size, src_seq_length) the sequence to the encoder:param src_mask: (batch_size, 1, src_seq_length) the mask for the sequence:param keep_attentions:  whether keep attention weigths or not. Defaults to False.:return: (batch_size, seq_length, d_model) encoder output'''src_embedding_tensor = self.src_embeddings(src)src_embedded = self.encoder_pos(src_embedding_tensor)return self.encoder(src_embedded,src_mask,keep_attentions)def decode(self,target_tensor: Tensor,memory: Tensor,target_mask: Tensor = None,memory_mask: Tensor = None,keep_attentions: bool = False) ->Tensor:''':param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.:param memory: (batch_size, src_seq_length, d_model) the  sequence from the last layer of the encoder.:param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.:param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.:param keep_attentions:  whether keep attention weigths or not. Defaults to False.:return: output (batch_size, tgt_seq_length, tgt_vocab_size)'''target_embedding_tensor = self.target_embeddings(target_tensor)target_embedded = self.decoder_pos(target_embedding_tensor)# logits (batch_size, target_seq_length, d_model)logits = self.decoder(target_embedded,memory,target_mask,memory_mask,keep_attentions)return  logitsdef forward(self,src: Tensor,target: Tensor,src_mask: Tensor=None,target_mask: Tensor=None,keep_attention:bool=False)->Tensor:''':param src: (batch_size, src_seq_length) the sequence to the encoder:param target:  (batch_size, tgt_seq_length) the sequence to the decoder:param src_mask::param target_mask::param keep_attention: whether keep attention weigths or not. Defaults to False.:return: (batch_size, tgt_seq_length, tgt_vocab_size)'''memory = self.encode(src,src_mask,keep_attention)return  self.decode(target,memory,target_mask,src_mask,keep_attention)

三、测试

写个简单的main函数,测试一下整体网络是否正常

if __name__ == '__main__':source_vocab_size = 300target_vocab_size = 300# padding对应的index,一般都是0pad_idx = 0batch_size = 1max_positions = 20model = Transformer(source_vocab_size=source_vocab_size,target_vocab_size=target_vocab_size)src_tensor = torch.randint(0,source_vocab_size,(batch_size,max_positions))target_tensor = torch.randint(0,source_vocab_size,(batch_size,max_positions))## 最后5位置是paddingsrc_tensor[:,-5:] = 0## 最后10位置是paddingtarget_tensor[:, -10:] = 0src_mask = (src_tensor != pad_idx).unsqueeze(1)targe_mask =  make_target_mask(target_tensor)logits = model(src_tensor,target_tensor,src_mask,targe_mask)print(logits.shape) #torch.Size([1, 20, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/29219.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度學習筆記12-優化器對比(Tensorflow)

🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客🍖 原作者:K同学啊 | 接輔導、項目定制 一、我的環境 電腦系統:Windows 10 顯卡:NVIDIA Quadro P620 語言環境:Python 3.7.0 開發工具&…

基于GTX的64B66B编码IP生成(高速收发器二十)

点击进入高速收发器系列文章导航界面 1、配置GTX IP 相关参数 前文讲解了64B66B编码解码原理,以及GTX IP实现64B66B编解码的相关信号组成,本文生成64B66B编码的GTX IP。 首先如下图所示,需要对GTX共享逻辑进行设置,为了便于扩展&a…

Gnu/Linux 之 C 语言函数列表初步整理

Linux为C语言编程提供了丰富的函数库,这些函数库覆盖了从基本输入输出、文件操作、字符串处理到系统调用等各个方面。以下是一些常见的Linux C函数示例: 输入输出函数 printf(): 输出格式化的字符串到标准输出。scanf(): 从标准输入读取格式化的数据。…

cuda c programming guide - 编程接口

cuda c是对c的扩充,具体参考C Language Extensions。 cuda c编写的kernel需要使用nvcc进行编译。 运行时库需要使用cuda runtime,其提供了在host上执行的c/c函数,可用于allocate/deallcote device memory,transfer data between…

MATLAB算法实战应用案例精讲-【图像处理】SLAM技术详解(基础篇)(四)

目录 几个高频面试题目 如何插入LinK3D、CSF、BALM来直接插入各个SLAM框架中 算法原理 SLAM-Open3D 1. Open3D环境安装 2. Open3D示例 3. Open3D在SLAM当中的应用 同步定位和建图 地图生成和姿态估计 机械扫描激光雷达: 固态激光雷达: 闪光激光雷达: 为 SLAM 选…

逆向学习 MFC 篇:视图分割和在 C++ 的 Windows 窗口程序中添加图标的方法

本节课在线学习视频(网盘地址,保存后即可免费观看): ​​​​https://pan.quark.cn/s/a165bd3ba6f3​​ Microsoft Foundation Class (MFC) 是用于创建基于 Windows 的应用程序的 C 库。它提供了丰富的类库来简化 Windows 编程&…

【开发工具】git服务器端安装部署+客户端配置

自己安装一个轻量级的git服务端,仅仅作为代码维护,尤其适合个人代码管理。毕竟代码的版本管理是很有必要的。 这里把git服务端部署在centos系统里,部署完成后可以通过命令行推拉代码,进行版本和用户管理。 一、服务端安装配置 …

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 内存访问热度分析(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 &#x1f…

windows环境下,怎么查看本机的IP、MAC地址和端口占用情况

1.输入ipconfig,按回车。即查看了IP地址,子码掩码,网关信息。 2.输入ipconfig/all,按回车。即查看了包含IP地址,子码掩码,网关信息以及MAC地址 3.我们有时在启动应用程序的时候提示端口被占用,如何知道谁占有了我们需要…

Vue57-组件的自定义事件_解绑

给谁绑的自定义事件,就找谁去触发;给谁绑的自定义事件,就找谁去解绑; 一、解绑自定义事件 1-1、解绑一个自定义事件 到student.vue组件中去解绑。 1-2、解绑多个自定义事件 使用数组来解绑多个。 1-3、解绑所有的自定义事件 二、…

Android Studio无法连接夜神模拟器的解决方案

一、AS检测不到夜神模拟器 1、问题描述 在按照教程【如何安装和使用Android夜神模拟器】进入夜神的bin目录,输入连接命令回车后,终端显示的already connected to 127.0.0.1:62001,但是AS的Running Devices并没有显示夜神模拟器。 2、解决方…

Node.js 入门:

Node.js 是一个开源、跨平台的 JavaScript 运行时环境,它允许开发者在浏览器之外编写命令行工具和服务器端脚本。以下是一些关于 Node.js 的基础教程: 1. **Node.js 入门**: - 了解 Node.js 的基本概念,包括它是一个基于 Chro…

Arm和高通的法律之争将扰乱人工智能驱动的PC浪潮

Arm和高通的法律之争将扰乱人工智能驱动的PC浪潮 科技行业高管和专家表示,两大科技巨头之间长达两年的法律大战可能会扰乱人工智能驱动的新一代个人电脑浪潮。 上周,来自微软(Microsoft)、华硕(Asus)、宏碁(Acer)、高通(Qualcomm)等公司的高管在台北举行…

IPV6 地址分类1

1、单播地址(ABC) 一对一 只有单播地址能作为源地址,也可作为目标地址 2、多播(组播)地址 (224——239) 一对多 作为目标地址 3、任意播地址-----一到最近 单播地址 1、AGUA 全球…

【HarmonyOS NEXT】如何通过h5拉起应用(在华为浏览器中拉起应用)

华为浏览器支持拉起外部应用 浏览器访问网页经常会遇到deeplink的场景。当前处理方案统一为使用AMS系统能力startAbility去隐式拉起。传递的want参数为 { "actions": "ohos.want.action.viewData", "uri": deeplink链接 } 网页需要给自己的应用拉…

计算机毕业设计Python+Vue.js知识图谱音乐推荐系统 音乐爬虫可视化 音乐数据分析 大数据毕设 大数据毕业设计 机器学习 深度学习 人工智能

开发技术 协同过滤算法、机器学习、LSTM、vue.js、echarts、django、Python、MySQL 创新点协同过滤推荐算法、爬虫、数据可视化、LSTM情感分析、短信、身份证识别 补充说明 适合大数据毕业设计、数据分析、爬虫类计算机毕业设计 介绍 音乐数据的爬取:爬取歌曲、…

clip_en的使用学习

代码分析 import torch import cn_clip.clip as clip from PIL import Image from cn_clip.clip import load_from_name, available_modelsprint("Torch version:", torch.__version__) device "cuda" if torch.cuda.is_available() else "cpu"…

第 10 章 监控系统 | 实战案例 - Nginx 监控

👉 本文目标:为 Nginx 安装 nginx-prometheus-exporter,实现对 Nginx 的监控。 👀 本文内容: 安装 Nginx Prometheus Exporter,暴露 Nginx 指标配置 Prometheus 抓取 Nginx Prometheus Exporter 暴露的指标数据【配置 Recording Rule,便于缓存/加速 Dashboard 频繁访问…

深度学习推理显卡设置

深度学习推理显卡设置 进入NVIDIA控制面板,选择 “管理3D设置”设置 "低延时模式"为 "“超高”"设置 “电源管理模式” 为 “最高性能优先” 使用锁频来获得稳定的推理 法一:命令行操作 以管理员身份打开CMD查看GPU核心可用频率&…

【如何使用python获取excel中sheet页的样式】

如何使用python获取excel中sheet页的样式 要获取Excel中sheet页的样式,特别是单元格的样式,如字体、颜色、边框等,你可以使用openpyxl库,但需要深入一些底层的操作,因为openpyxl的主要API不直接暴露这些样式信息。 以…