大模型基础——从零实现一个Transformer(5)

大模型基础——从零实现一个Transformer(1)-CSDN博客

大模型基础——从零实现一个Transformer(2)-CSDN博客

大模型基础——从零实现一个Transformer(3)-CSDN博客

大模型基础——从零实现一个Transformer(4)-CSDN博客


一、前言

上一篇文章已经把Encoder模块和Decoder模块都已经实现了,

接下来来实现完整的Transformer

二、Transformer

Transformer整体架构如上,直接把我们实现的Encoder 和Decoder模块引入,开始堆叠

import torch
from torch import  nn,Tensor
from torch.nn import  Embedding#引入自己实现的模块
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_maskclass Transformer(nn.Module):def __init__(self,source_vocab_size:int,target_vocab_size:int,d_model: int = 512,n_heads: int = 8,num_encoder_layers: int = 6,num_decoder_layers: int = 6,d_ff: int = 2048,dropout: float = 0.1,max_positions:int = 5000,pad_idx: int = 0,norm_first: bool=False) -> None:''':param source_vocab_size: size of the source vocabulary.:param target_vocab_size: size of the target vocabulary.:param d_model: dimension of embeddings. Defaults to 512.:param n_heads: number of heads. Defaults to 8.:param num_encoder_layers: number of encoder blocks. Defaults to 6.:param num_decoder_layers: number of decoder blocks. Defaults to 6.:param d_ff: dimension of inner feed-forward network. Defaults to 2048.:param dropout: dropout ratio. Defaults to 0.1.:param max_positions: maximum sequence length for positional encoding. Defaults to 5000.:param pad_idx: pad index. Defaults to 0.:param norm_first: if True, layer norm is done prior to attention and feedforward operations(Pre-Norm).Otherwise it's done after(Post-Norm). Default to False.'''super().__init__()# Token embeddingself.src_embeddings = Embedding(source_vocab_size,d_model)self.target_embeddings = Embedding(target_vocab_size,d_model)# Position embeddingself.encoder_pos = PositionalEmbedding(d_model,dropout,max_positions)self.decoder_pos = PositionalEmbedding(d_model,dropout,max_positions)# 编码层定义self.encoder = Encoder(d_model,num_encoder_layers,n_heads,d_ff,dropout,norm_first)# 解码层定义self.decoder = Decoder(d_model,num_decoder_layers,n_heads,d_ff,dropout,norm_first)self.pad_idx = pad_idxdef encode(self,src:Tensor,src_mask: Tensor=None,keep_attentions: bool=False) -> Tensor:'''编码过程:param src: (batch_size, src_seq_length) the sequence to the encoder:param src_mask: (batch_size, 1, src_seq_length) the mask for the sequence:param keep_attentions:  whether keep attention weigths or not. Defaults to False.:return: (batch_size, seq_length, d_model) encoder output'''src_embedding_tensor = self.src_embeddings(src)src_embedded = self.encoder_pos(src_embedding_tensor)return self.encoder(src_embedded,src_mask,keep_attentions)def decode(self,target_tensor: Tensor,memory: Tensor,target_mask: Tensor = None,memory_mask: Tensor = None,keep_attentions: bool = False) ->Tensor:''':param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.:param memory: (batch_size, src_seq_length, d_model) the  sequence from the last layer of the encoder.:param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.:param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.:param keep_attentions:  whether keep attention weigths or not. Defaults to False.:return: output (batch_size, tgt_seq_length, tgt_vocab_size)'''target_embedding_tensor = self.target_embeddings(target_tensor)target_embedded = self.decoder_pos(target_embedding_tensor)# logits (batch_size, target_seq_length, d_model)logits = self.decoder(target_embedded,memory,target_mask,memory_mask,keep_attentions)return  logitsdef forward(self,src: Tensor,target: Tensor,src_mask: Tensor=None,target_mask: Tensor=None,keep_attention:bool=False)->Tensor:''':param src: (batch_size, src_seq_length) the sequence to the encoder:param target:  (batch_size, tgt_seq_length) the sequence to the decoder:param src_mask::param target_mask::param keep_attention: whether keep attention weigths or not. Defaults to False.:return: (batch_size, tgt_seq_length, tgt_vocab_size)'''memory = self.encode(src,src_mask,keep_attention)return  self.decode(target,memory,target_mask,src_mask,keep_attention)

三、测试

写个简单的main函数,测试一下整体网络是否正常

if __name__ == '__main__':source_vocab_size = 300target_vocab_size = 300# padding对应的index,一般都是0pad_idx = 0batch_size = 1max_positions = 20model = Transformer(source_vocab_size=source_vocab_size,target_vocab_size=target_vocab_size)src_tensor = torch.randint(0,source_vocab_size,(batch_size,max_positions))target_tensor = torch.randint(0,source_vocab_size,(batch_size,max_positions))## 最后5位置是paddingsrc_tensor[:,-5:] = 0## 最后10位置是paddingtarget_tensor[:, -10:] = 0src_mask = (src_tensor != pad_idx).unsqueeze(1)targe_mask =  make_target_mask(target_tensor)logits = model(src_tensor,target_tensor,src_mask,targe_mask)print(logits.shape) #torch.Size([1, 20, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/29219.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度學習筆記12-優化器對比(Tensorflow)

🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客🍖 原作者:K同学啊 | 接輔導、項目定制 一、我的環境 電腦系統:Windows 10 顯卡:NVIDIA Quadro P620 語言環境:Python 3.7.0 開發工具&…

基于GTX的64B66B编码IP生成(高速收发器二十)

点击进入高速收发器系列文章导航界面 1、配置GTX IP 相关参数 前文讲解了64B66B编码解码原理,以及GTX IP实现64B66B编解码的相关信号组成,本文生成64B66B编码的GTX IP。 首先如下图所示,需要对GTX共享逻辑进行设置,为了便于扩展&a…

【开发工具】git服务器端安装部署+客户端配置

自己安装一个轻量级的git服务端,仅仅作为代码维护,尤其适合个人代码管理。毕竟代码的版本管理是很有必要的。 这里把git服务端部署在centos系统里,部署完成后可以通过命令行推拉代码,进行版本和用户管理。 一、服务端安装配置 …

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 内存访问热度分析(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 &#x1f…

windows环境下,怎么查看本机的IP、MAC地址和端口占用情况

1.输入ipconfig,按回车。即查看了IP地址,子码掩码,网关信息。 2.输入ipconfig/all,按回车。即查看了包含IP地址,子码掩码,网关信息以及MAC地址 3.我们有时在启动应用程序的时候提示端口被占用,如何知道谁占有了我们需要…

Vue57-组件的自定义事件_解绑

给谁绑的自定义事件,就找谁去触发;给谁绑的自定义事件,就找谁去解绑; 一、解绑自定义事件 1-1、解绑一个自定义事件 到student.vue组件中去解绑。 1-2、解绑多个自定义事件 使用数组来解绑多个。 1-3、解绑所有的自定义事件 二、…

Android Studio无法连接夜神模拟器的解决方案

一、AS检测不到夜神模拟器 1、问题描述 在按照教程【如何安装和使用Android夜神模拟器】进入夜神的bin目录,输入连接命令回车后,终端显示的already connected to 127.0.0.1:62001,但是AS的Running Devices并没有显示夜神模拟器。 2、解决方…

Arm和高通的法律之争将扰乱人工智能驱动的PC浪潮

Arm和高通的法律之争将扰乱人工智能驱动的PC浪潮 科技行业高管和专家表示,两大科技巨头之间长达两年的法律大战可能会扰乱人工智能驱动的新一代个人电脑浪潮。 上周,来自微软(Microsoft)、华硕(Asus)、宏碁(Acer)、高通(Qualcomm)等公司的高管在台北举行…

计算机毕业设计Python+Vue.js知识图谱音乐推荐系统 音乐爬虫可视化 音乐数据分析 大数据毕设 大数据毕业设计 机器学习 深度学习 人工智能

开发技术 协同过滤算法、机器学习、LSTM、vue.js、echarts、django、Python、MySQL 创新点协同过滤推荐算法、爬虫、数据可视化、LSTM情感分析、短信、身份证识别 补充说明 适合大数据毕业设计、数据分析、爬虫类计算机毕业设计 介绍 音乐数据的爬取:爬取歌曲、…

深度学习推理显卡设置

深度学习推理显卡设置 进入NVIDIA控制面板,选择 “管理3D设置”设置 "低延时模式"为 "“超高”"设置 “电源管理模式” 为 “最高性能优先” 使用锁频来获得稳定的推理 法一:命令行操作 以管理员身份打开CMD查看GPU核心可用频率&…

云计算 | (四)基本云安全

文章目录 📚基本云安全🐇云安全背景🐇基本术语和概念⭐️风险(risk)⭐️安全需求🐇威胁作用者⭐️威胁作用者(threat agent)⭐️匿名攻击者(anonymous attacker)⭐️恶意服务作用者(malicious service agent)⭐️授信的攻击者(trusted attacker)⭐️恶意的内部人员(mal…

有趣且重要的JS知识合集(22)树相关的算法

0、举例&#xff1a;树形结构原始数据 1、序列化树形结构 /*** 平铺序列化树形结构* param tree 树形结构* param result 转化后一维数组* returns Array<TreeNode>*/ export function flattenTree(tree, result []) {if (tree.length 0) {return result}for (const …

开发一个python工具,pdf转图片,并且截成单个图片,然后修整没用的白边

今天推荐一键款本人开发的pdf转单张图片并截取没有用的白边工具 一、开发背景&#xff1a; 业务需要将一个pdf文件展示在前端显示&#xff0c;但是基于各种原因&#xff0c;放弃了h5使用插件展示 原因有多个&#xff0c;文件资源太大加载太慢、pdf展示兼容性问题、pdf展示效果…

CSDN 自动上传图片并优化Markdown的图片显示

文章目录 完整代码一、上传资源二、替换 MD 中的引用文件为在线链接参考 完整代码 完整代码由两个文件组成&#xff0c;upload.py 和 main.py&#xff0c;放在同一目录下运行 main.py 就好&#xff01; # upload.py import requests class UploadPic: def __init__(self, c…

力扣每日一题 6/17 枚举+双指针

博客主页&#xff1a;誓则盟约系列专栏&#xff1a;IT竞赛 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ 522.最长特殊序列II【中等】 题目&#xff1a; 给定字符串列表 strs &…

【Ubuntu通用压力测试】Ubuntu16.04 CPU压力测试

使用 stress 对CPU进行压力测试 我也是一个ubuntu初学者&#xff0c;分享是Linux的优良美德。写的不好请大佬不要喷&#xff0c;多谢支持。 sudo apt-get update 日常先更新再安装东西不容易出错 sudo apt-get upgrade -y 继续升级一波 sudo apt-get install -y linux-tools…

Stable Diffusion文生图模型训练入门实战(完整代码)

Stable Diffusion 1.5&#xff08;SD1.5&#xff09;是由Stability AI在2022年8月22日开源的文生图模型&#xff0c;是SD最经典也是社区最活跃的模型之一。 以SD1.5作为预训练模型&#xff0c;在火影忍者数据集上微调一个火影风格的文生图模型&#xff08;非Lora方式&#xff…

Python | Leetcode Python题解之第162题寻找峰值

题目&#xff1a; 题解&#xff1a; class Solution:def findPeakElement(self, nums: List[int]) -> int:n len(nums)# 辅助函数&#xff0c;输入下标 i&#xff0c;返回 nums[i] 的值# 方便处理 nums[-1] 以及 nums[n] 的边界情况def get(i: int) -> int:if i -1 or…

STM32单片机DMA存储器详解

文章目录 1. DMA概述 2. 存储器映像 3. DMA框架图 4. DMA请求 5. 数据宽度与对齐 6. DMA数据转运 7. ADC扫描模式和DMA 8. 代码示例 1. DMA概述 DMA&#xff08;Direct Memory Access&#xff09;可以直接访问STM32内部的存储器&#xff0c;DMA是一种技术&#xff0c;…

【 ARMv8/ARMv9 硬件加速系列 3.5.1 -- SVE 谓词寄存器有多少位?】

文章目录 SVE 谓词寄存器(predicate registers)简介SVE 谓词寄存器的位数SVE 谓词寄存器对向量寄存器的控制SVE 谓词寄存器位数计算SVE 谓词寄存器小结SVE 谓词寄存器(predicate registers)简介 ARMv9的Scalable Vector Extension (SVE) 引入了谓词寄存器(Predicate Register…