【Transformer从零开始代码实现 pytoch版】(五)总架构类的实现

Transformer总架构

在这里插入图片描述
在实现完输入部分、编码器、解码器和输出部分之后,就可以封装各个部件为一个完整的实体类了。

【Transformer从零开始代码实现 pytoch版】(一)输入部件:embedding+positionalEncoding

【Transformer从零开始代码实现 pytoch版】(二)Encoder编码器组件:mask + attention + feed forward + add&norm

【Transformer从零开始代码实现 pytoch版】(三)Decoder编码器组件:多头自注意力+多头注意力+全连接层+规范化层

【Transformer从零开始代码实现 pytoch版】(四)输出部件:Linear+softmax

编码器-解码器总结构代码实现

class EncoderDecoder(nn.Module):""" 编码器解码器架构实现、定义了初始化、forward、encode和decode部件"""def __init__(self, encoder, decoder, source_embed, target_embed, generator):""" 传入五大部件参数:param encoder: 编码器:param decoder: 解码器:param source_embed: 源数据embedding函数:param target_embed: 目标数据embedding函数:param generator: 输出部分类被生成器对象"""super(EncoderDecoder, self).__init__()self.encoder = encoderself.decoder = decoderself.src_embed = source_embedself.tgt_embed = target_embedself.generator = generator					# 生成器后面会专门用到def forward(self, source, target, source_mask, target_mask):""" 构建数据流入流出:param source: 源数据:param target: 目标数据:param source_mask: 源数据掩码张量:param target_mask: 目标数据掩码张量:return:"""# 注意这里先用的encode和decode函数,又才在其函数里面,再用了encoder和decoderreturn self.decode(self.encode(source, source_mask), source_mask, target, target_mask)def encode(self, source, source_mask):""" 编码函数,编码部件:param source: 源数据张量:param source_mask: 源数据的掩码张量:return: 经过解码器的输出"""return self.encoder(self.src_embed(source), source_mask)def decode(self, memory, source_mask, target, target_mask):""" 解码函数,解码部件:param memory:编码器的输出QV:param source_mask:源数据的掩码张量:param target:目标数据:param target_mask:目标数据的掩码张量:return:"""return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)

示例

# 输入参数
vocab_size = 1000
size = d_model = 512# 编码器部分
dropout = 0.2
d_ff = 64				# 隐藏层参数
head = 8				# 注意力头数
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
encoder_layer = EncoderLayer(size, c(attn), c(ff), dropout)
encoder_N = 8
encoder = Encoder(encoder_layer, encoder_N)# 解码器部分
dropout = 0.2
d_ff = 64
head = 8
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
decoder_layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)
decoder_N = 8
decoder = Decoder(decoder_layer, decoder_N)# 用了nn的embedding作为输入示意
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = Generator(d_model, vocab_size)# 输入张量和掩码张量
source = target = torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]])
source_mask = target_mask = torch.zeros(2, 4, 4)# 实例化编码器-解码器,再带入参数实现
ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_res = ed(source, target, source_mask, target_mask)
print(f"ed_res: {ed_res}\n shape:{ed_res.shape}")ed_res: tensor([[[-0.1861,  0.0849, -0.3015,  ...,  1.1753, -1.4933,  0.2484],[-0.3626,  1.3383,  0.1739,  ...,  1.1304,  2.0266, -0.5929],[ 0.0785,  1.4932,  0.3184,  ..., -0.2021, -0.2330,  0.1539],[-0.9703,  1.1944,  0.1763,  ...,  0.1586, -0.6066, -0.6147]],[[-0.9216, -0.0309, -0.6490,  ...,  1.0177,  0.5574,  0.4873],[-1.4097,  0.6678, -0.6708,  ...,  1.1176,  0.1959, -1.2494],[-0.3204,  1.2794, -0.4022,  ...,  0.6319, -0.4709,  1.0520],[-1.3238,  1.1470, -0.9943,  ...,  0.4026,  1.0911,  0.1327]]],grad_fn=<AddBackward0>)shape:torch.Size([2, 4, 512])

编码器-解码器模型构建函数

def make_model(source_vocab, target_vocab, N=6, d_model=512, d_ff=2048, head=8, dropout=0.1):""" 用于构建模型:param source_vocab: 源数据词汇总数:param target_vocab: 目标词汇总数:param N: 解码器/解码器堆叠层数:param d_model: 词嵌入维度:param d_ff: 前馈全连接层隐藏层维度:param dropout: 置0比率:return: 返回构建编码器-解码器模型"""# 拷贝函数,来保证拷贝的函数彼此之间相互独立,不受干扰c = copy.deepcopy# 实例化多头注意力attn = MultiHeadedAttention(head, d_model)# 实例化全连接层ff = PositionwiseFeedForward(d_model, d_ff, dropout)# 实例化位置编码类,得到对象positionposition = PositionalEncoding(d_model, dropout)model = EncoderDecoder(Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),nn.Sequential(Embedding(d_model, source_vocab), c(position)),nn.Sequential(Embedding(d_model, source_vocab), c(position)),Generator(d_model, target_vocab))# 模型结构构建完成后,初始化模型中的参数for p in model.parameters():# 这里判定当参数维度大于1的时候,则会将其初始化成一个服从均匀分布的矩阵if p.dim() > 1:nn.init.xavier_normal(p)        # 生成服从正态分布的数,默认为U(-1, 1),更改第二个参数可以改值return model

示例

source_vocab = target_vocab = 11
N = 6
res = make_model(source_vocab, target_vocab, N)
print(res)EncoderDecoder((encoder): Encoder((layers): ModuleList((0-5): 6 x EncoderLayer((self_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(feed_forward): PositionwiseFeedForward((w1): Linear(in_features=512, out_features=2048, bias=True)(w2): Linear(in_features=2048, out_features=512, bias=True)(dropout): Dropout(p=0.1, inplace=False))(sublayer): ModuleList((0-1): 2 x SublayerConnection((norm): LayerNorm()(dropout): Dropout(p=0.1, inplace=False)))))(norm): LayerNorm())(decoder): Decoder((layers): ModuleList((0-5): 6 x DecoderLayer((self_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(src_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(feed_forward): PositionwiseFeedForward((w1): Linear(in_features=512, out_features=2048, bias=True)(w2): Linear(in_features=2048, out_features=512, bias=True)(dropout): Dropout(p=0.1, inplace=False))(sublayer): ModuleList((0-2): 3 x SublayerConnection((norm): LayerNorm()(dropout): Dropout(p=0.1, inplace=False)))))(norm): LayerNorm())(src_embed): Sequential((0): Embedding((lut): Embedding(512, 11))(1): PositionalEncoding((dropout): Dropout(p=0.1, inplace=False)))(tgt_embed): Sequential((0): Embedding((lut): Embedding(512, 11))(1): PositionalEncoding((dropout): Dropout(p=0.1, inplace=False)))(generator): Generator((project): Linear(in_features=512, out_features=11, bias=True))
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/140720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring boot 整合elasticsearch

文章目录 初始化RestClient 初始化RestClient 在elasticsearch提供的API中&#xff0c;与elasticsearch一切交互都封装在一个名为RestHighLevelClient的类中&#xff0c;必须先完成这个对象的初始化&#xff0c;建立与elasticsearch的连接。 分为三步&#xff1a; 1&#xf…

CMT2300A超低功耗127-1020MHz Sub-1GHz全频段SUB-1G 射频收发芯片

CMT2300A超低功耗127-1020MHz Sub-1GHz全频段SUB-1G 射频收发芯片 Sub-1GHz&#xff0c;是指小于1GHz频率的统称。Sub-1GHz无线电频段应用的主要特点&#xff1a;&#xff08;1&#xff09;频率较低波长较长&#xff0c;传输距离远&#xff0c;穿透性强&#xff1b;&#xff0…

xinput1_3.dll丢失的详细解决步骤办法和比较,五种有效的解决办法

今天想和大家分享一个电脑中经常出现的问题——xinput1_3.dll丢失。这个文件丢失是一件常见的问题。不知道小伙伴们有没有遇到过这样的问题&#xff0c;如果你遇到这样的问题今天就教大家xinput1_3.dll丢失的详细解决步骤办法和比较&#xff0c;五种有效的解决办法。 一.xinput…

YOLOv5 分类模型的后处理

YOLOv5 分类模型的后处理 flyfish 简化源码测试 import torch import numpy as np from torchvision import transforms import torch.nn.functional as Fdata0 np.random.random((1, 7)) data0 np.round(data0,7) print(data0.shape) print(data0) data1 torch.from_n…

力扣labuladong一刷day7共3题

力扣labuladong一刷day7共3题 文章目录 力扣labuladong一刷day7共3题一、216. 组合总和 III二、111. 二叉树的最小深度三、752. 打开转盘锁 一、216. 组合总和 III 题目链接&#xff1a;https://leetcode.cn/problems/combination-sum-iii/ 思路&#xff1a;还是组合只是既有n…

【Axure高保真原型】树切换动态面板案例

今天和大家分享树切换动态面板的原型模板&#xff0c;点击树的箭头可以打开或者收起子节点&#xff0c;点击最后一级人物节点&#xff0c;可以切换右侧面板的状态到对应的页面&#xff0c;左侧的树是通过中继器制作的&#xff0c;使用简单&#xff0c;只需要按要求填写中继器表…

各种ui框架的 form校验 validator获取不到value

// form-item 配置prop prop"user.name" // rules rules: {user.name: [message: "xxxxx",validator(rule, val, callback) {// val 就是user.name的值},] }如: 对象的sysUser.userName <n-form ref"formRefuser" :model"modelUser&qu…

浅谈jvm

前置知识补充 JDK、JRE、JVM是什么&#xff1f;区别与联系&#xff1f; 区别&#xff1a; JDK&#xff08;Java Development Kit&#xff09;&#xff1a;Java开发工具包 主要包括 Java运行环境、Java基础库及 Java工具。 JRE&#xff08;Java Runtime Environment&#xf…

selenium基本使用、无头浏览器(chrome、FireFox)、搜索标签

selenium基本使用 这个模块&#xff1a;既能发请求&#xff0c;又能解析&#xff0c;还能执行js selenium最初是一个自动化测试工具,而爬虫中使用它主要是为了解决requests无法直接执行 JavaScript代码的问题 selenium 会做web方向的自动化测试appnium 会做 app方向的自动化…

使用迁移学习在线校准深度学习模型

使用迁移学习在线校准深度学习模型 本文参考的是2023年发表于Engineering Applications of Artificial Intelligence, EAAI的Deep Gaussian mixture adaptive network for robust soft sensor modeling with a closed-loop calibration mechanism 1. 动机 概念漂移导致历史训…

SpringBoot--中间件技术-1:任务管理,异步任务,任务调度,发邮件Mail的实现,含代码

SpringBoot中的事务管理 关键注解&#xff1a; 设置事务&#xff08;声明事务管理&#xff09;&#xff0c;写在业务层的方法上&#xff1a; Transactional(isolation Isolation.DEFAULT) Transactional(propagation Propagation.REQUIRED) 开启事务&#xff0c;设置在启动…

# Oracle 库常见问题排查

Oracle 库常见问题排查 文章目录 Oracle 库常见问题排查查询数据库的相关信息查看正在执行的语句杀掉正在执行的sql查看未提交的事务查看锁表 查询数据库的相关信息 查看正在执行的语句 SELECT s.sid, s.serial#, s.username, s.status, s.sql_id, s.sql_child_number, sq.sq…

React 共享组件状态及其实践

React 是一个强大的JavaScript库&#xff0c;它提供了一种简单的方式来构建用户界面。然而&#xff0c;随着应用规模的增长&#xff0c;状态管理成为一个复杂的问题。本篇文章将深入探讨如何在React组件之间共享状态。 状态提升 首先&#xff0c;我们来谈谈"状态提升&qu…

完全免费!超好用的IDEA插件推荐:Apipost-Helper

Idea 是一款功能强大的集成开发环境&#xff08;IDE&#xff09;&#xff0c;它可以帮助开发人员更加高效地编写、调试和部署软件应用程序,Idea 还具有许多插件和扩展&#xff0c;可以根据开发人员的需要进行定制和扩展&#xff0c;从而提高开发效率,今天我们就来介绍一款国产的…

使用Java实现一个简单的贪吃蛇小游戏

一. 准备工作 首先获取贪吃蛇小游戏所需要的头部、身体、食物以及贪吃蛇标题等图片。 然后&#xff0c;创建贪吃蛇游戏的Java项目命名为snake_game&#xff0c;并在这个项目里创建一个文件夹命名为images&#xff0c;将图片素材导入文件夹。 再在src文件下创建两个包&#xff0…

正点原子嵌入式linux驱动开发——Linux DAC驱动

上一篇笔记中学习了ADC驱动&#xff0c;STM32MP157 也有DAC外设&#xff0c;DAC也使用的IIO驱动框架。本章就来学习一下如下在Linux下使用STM32MP157上的DAC。 DAC简介 ADC是模数转换器&#xff0c;负责将外界的模拟信号转换为数字信号。DAC刚好相反&#xff0c;是数模转换器…

在node中实现高效率、低内存的excel/JSON转换

在node中实现高效率、低内存的excel/JSON转换 在nodejs中不使用过多内存的情况下&#xff0c;将大型excel文件转换为json格式是一个非常常见的需求&#xff0c;因为它可以更容易地处理和共享数据。在这篇文章中&#xff0c;我们将探讨如何完成这个需求&#xff0c;提供一个逐步…

Mac电脑Visio文件编辑查看软件推荐Visio Viewer for Mac

mac版Visio Viewer功能特色 在Mac OS X上查看Visio绘图和图表 在Mac OS X上轻松查看MS Visio文件 在Mac上快速方便地打开并阅读Visio文件&#xff08;.vsd&#xff0c;.vsdx&#xff09;。 支持通过放大&#xff0c;缩小&#xff0c;旋转&#xff0c;文本选择和复制&#xff0…

群晖Docker(Container Manager)中安装Home Assistant Container

群晖Docker&#xff08;Container Manager&#xff09;中安装Home Assistant Container 不要使用 套件里面的 Home Assistant&#xff0c;不利于后期拓展 方式一&#xff1a; docker run -d --name"home-assistant-1" -v /volume1/docker/homeassistant/config:/c…

秋招求职经验分享

0.个人简介 2023年10月底&#xff0c;最终拿到了海康威视、汇川技术等十余家公司的Offer&#xff0c;最终签了自己心仪的Offer&#xff0c;秋招对我来说算是正式结束了&#xff0c;写个博客纪念一下&#xff0c;顺便分享以下秋招的经验&#xff0c;为后来人求职提供一些参考。…