【Transformer从零开始代码实现 pytoch版】(五)总架构类的实现

Transformer总架构

在这里插入图片描述
在实现完输入部分、编码器、解码器和输出部分之后,就可以封装各个部件为一个完整的实体类了。

【Transformer从零开始代码实现 pytoch版】(一)输入部件:embedding+positionalEncoding

【Transformer从零开始代码实现 pytoch版】(二)Encoder编码器组件:mask + attention + feed forward + add&norm

【Transformer从零开始代码实现 pytoch版】(三)Decoder编码器组件:多头自注意力+多头注意力+全连接层+规范化层

【Transformer从零开始代码实现 pytoch版】(四)输出部件:Linear+softmax

编码器-解码器总结构代码实现

class EncoderDecoder(nn.Module):""" 编码器解码器架构实现、定义了初始化、forward、encode和decode部件"""def __init__(self, encoder, decoder, source_embed, target_embed, generator):""" 传入五大部件参数:param encoder: 编码器:param decoder: 解码器:param source_embed: 源数据embedding函数:param target_embed: 目标数据embedding函数:param generator: 输出部分类被生成器对象"""super(EncoderDecoder, self).__init__()self.encoder = encoderself.decoder = decoderself.src_embed = source_embedself.tgt_embed = target_embedself.generator = generator					# 生成器后面会专门用到def forward(self, source, target, source_mask, target_mask):""" 构建数据流入流出:param source: 源数据:param target: 目标数据:param source_mask: 源数据掩码张量:param target_mask: 目标数据掩码张量:return:"""# 注意这里先用的encode和decode函数,又才在其函数里面,再用了encoder和decoderreturn self.decode(self.encode(source, source_mask), source_mask, target, target_mask)def encode(self, source, source_mask):""" 编码函数,编码部件:param source: 源数据张量:param source_mask: 源数据的掩码张量:return: 经过解码器的输出"""return self.encoder(self.src_embed(source), source_mask)def decode(self, memory, source_mask, target, target_mask):""" 解码函数,解码部件:param memory:编码器的输出QV:param source_mask:源数据的掩码张量:param target:目标数据:param target_mask:目标数据的掩码张量:return:"""return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)

示例

# 输入参数
vocab_size = 1000
size = d_model = 512# 编码器部分
dropout = 0.2
d_ff = 64				# 隐藏层参数
head = 8				# 注意力头数
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
encoder_layer = EncoderLayer(size, c(attn), c(ff), dropout)
encoder_N = 8
encoder = Encoder(encoder_layer, encoder_N)# 解码器部分
dropout = 0.2
d_ff = 64
head = 8
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
decoder_layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)
decoder_N = 8
decoder = Decoder(decoder_layer, decoder_N)# 用了nn的embedding作为输入示意
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = Generator(d_model, vocab_size)# 输入张量和掩码张量
source = target = torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]])
source_mask = target_mask = torch.zeros(2, 4, 4)# 实例化编码器-解码器,再带入参数实现
ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_res = ed(source, target, source_mask, target_mask)
print(f"ed_res: {ed_res}\n shape:{ed_res.shape}")ed_res: tensor([[[-0.1861,  0.0849, -0.3015,  ...,  1.1753, -1.4933,  0.2484],[-0.3626,  1.3383,  0.1739,  ...,  1.1304,  2.0266, -0.5929],[ 0.0785,  1.4932,  0.3184,  ..., -0.2021, -0.2330,  0.1539],[-0.9703,  1.1944,  0.1763,  ...,  0.1586, -0.6066, -0.6147]],[[-0.9216, -0.0309, -0.6490,  ...,  1.0177,  0.5574,  0.4873],[-1.4097,  0.6678, -0.6708,  ...,  1.1176,  0.1959, -1.2494],[-0.3204,  1.2794, -0.4022,  ...,  0.6319, -0.4709,  1.0520],[-1.3238,  1.1470, -0.9943,  ...,  0.4026,  1.0911,  0.1327]]],grad_fn=<AddBackward0>)shape:torch.Size([2, 4, 512])

编码器-解码器模型构建函数

def make_model(source_vocab, target_vocab, N=6, d_model=512, d_ff=2048, head=8, dropout=0.1):""" 用于构建模型:param source_vocab: 源数据词汇总数:param target_vocab: 目标词汇总数:param N: 解码器/解码器堆叠层数:param d_model: 词嵌入维度:param d_ff: 前馈全连接层隐藏层维度:param dropout: 置0比率:return: 返回构建编码器-解码器模型"""# 拷贝函数,来保证拷贝的函数彼此之间相互独立,不受干扰c = copy.deepcopy# 实例化多头注意力attn = MultiHeadedAttention(head, d_model)# 实例化全连接层ff = PositionwiseFeedForward(d_model, d_ff, dropout)# 实例化位置编码类,得到对象positionposition = PositionalEncoding(d_model, dropout)model = EncoderDecoder(Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),nn.Sequential(Embedding(d_model, source_vocab), c(position)),nn.Sequential(Embedding(d_model, source_vocab), c(position)),Generator(d_model, target_vocab))# 模型结构构建完成后,初始化模型中的参数for p in model.parameters():# 这里判定当参数维度大于1的时候,则会将其初始化成一个服从均匀分布的矩阵if p.dim() > 1:nn.init.xavier_normal(p)        # 生成服从正态分布的数,默认为U(-1, 1),更改第二个参数可以改值return model

示例

source_vocab = target_vocab = 11
N = 6
res = make_model(source_vocab, target_vocab, N)
print(res)EncoderDecoder((encoder): Encoder((layers): ModuleList((0-5): 6 x EncoderLayer((self_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(feed_forward): PositionwiseFeedForward((w1): Linear(in_features=512, out_features=2048, bias=True)(w2): Linear(in_features=2048, out_features=512, bias=True)(dropout): Dropout(p=0.1, inplace=False))(sublayer): ModuleList((0-1): 2 x SublayerConnection((norm): LayerNorm()(dropout): Dropout(p=0.1, inplace=False)))))(norm): LayerNorm())(decoder): Decoder((layers): ModuleList((0-5): 6 x DecoderLayer((self_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(src_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(feed_forward): PositionwiseFeedForward((w1): Linear(in_features=512, out_features=2048, bias=True)(w2): Linear(in_features=2048, out_features=512, bias=True)(dropout): Dropout(p=0.1, inplace=False))(sublayer): ModuleList((0-2): 3 x SublayerConnection((norm): LayerNorm()(dropout): Dropout(p=0.1, inplace=False)))))(norm): LayerNorm())(src_embed): Sequential((0): Embedding((lut): Embedding(512, 11))(1): PositionalEncoding((dropout): Dropout(p=0.1, inplace=False)))(tgt_embed): Sequential((0): Embedding((lut): Embedding(512, 11))(1): PositionalEncoding((dropout): Dropout(p=0.1, inplace=False)))(generator): Generator((project): Linear(in_features=512, out_features=11, bias=True))
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/140720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CMT2300A超低功耗127-1020MHz Sub-1GHz全频段SUB-1G 射频收发芯片

CMT2300A超低功耗127-1020MHz Sub-1GHz全频段SUB-1G 射频收发芯片 Sub-1GHz&#xff0c;是指小于1GHz频率的统称。Sub-1GHz无线电频段应用的主要特点&#xff1a;&#xff08;1&#xff09;频率较低波长较长&#xff0c;传输距离远&#xff0c;穿透性强&#xff1b;&#xff0…

xinput1_3.dll丢失的详细解决步骤办法和比较,五种有效的解决办法

今天想和大家分享一个电脑中经常出现的问题——xinput1_3.dll丢失。这个文件丢失是一件常见的问题。不知道小伙伴们有没有遇到过这样的问题&#xff0c;如果你遇到这样的问题今天就教大家xinput1_3.dll丢失的详细解决步骤办法和比较&#xff0c;五种有效的解决办法。 一.xinput…

【Axure高保真原型】树切换动态面板案例

今天和大家分享树切换动态面板的原型模板&#xff0c;点击树的箭头可以打开或者收起子节点&#xff0c;点击最后一级人物节点&#xff0c;可以切换右侧面板的状态到对应的页面&#xff0c;左侧的树是通过中继器制作的&#xff0c;使用简单&#xff0c;只需要按要求填写中继器表…

浅谈jvm

前置知识补充 JDK、JRE、JVM是什么&#xff1f;区别与联系&#xff1f; 区别&#xff1a; JDK&#xff08;Java Development Kit&#xff09;&#xff1a;Java开发工具包 主要包括 Java运行环境、Java基础库及 Java工具。 JRE&#xff08;Java Runtime Environment&#xf…

使用迁移学习在线校准深度学习模型

使用迁移学习在线校准深度学习模型 本文参考的是2023年发表于Engineering Applications of Artificial Intelligence, EAAI的Deep Gaussian mixture adaptive network for robust soft sensor modeling with a closed-loop calibration mechanism 1. 动机 概念漂移导致历史训…

完全免费!超好用的IDEA插件推荐:Apipost-Helper

Idea 是一款功能强大的集成开发环境&#xff08;IDE&#xff09;&#xff0c;它可以帮助开发人员更加高效地编写、调试和部署软件应用程序,Idea 还具有许多插件和扩展&#xff0c;可以根据开发人员的需要进行定制和扩展&#xff0c;从而提高开发效率,今天我们就来介绍一款国产的…

使用Java实现一个简单的贪吃蛇小游戏

一. 准备工作 首先获取贪吃蛇小游戏所需要的头部、身体、食物以及贪吃蛇标题等图片。 然后&#xff0c;创建贪吃蛇游戏的Java项目命名为snake_game&#xff0c;并在这个项目里创建一个文件夹命名为images&#xff0c;将图片素材导入文件夹。 再在src文件下创建两个包&#xff0…

正点原子嵌入式linux驱动开发——Linux DAC驱动

上一篇笔记中学习了ADC驱动&#xff0c;STM32MP157 也有DAC外设&#xff0c;DAC也使用的IIO驱动框架。本章就来学习一下如下在Linux下使用STM32MP157上的DAC。 DAC简介 ADC是模数转换器&#xff0c;负责将外界的模拟信号转换为数字信号。DAC刚好相反&#xff0c;是数模转换器…

Mac电脑Visio文件编辑查看软件推荐Visio Viewer for Mac

mac版Visio Viewer功能特色 在Mac OS X上查看Visio绘图和图表 在Mac OS X上轻松查看MS Visio文件 在Mac上快速方便地打开并阅读Visio文件&#xff08;.vsd&#xff0c;.vsdx&#xff09;。 支持通过放大&#xff0c;缩小&#xff0c;旋转&#xff0c;文本选择和复制&#xff0…

群晖Docker(Container Manager)中安装Home Assistant Container

群晖Docker&#xff08;Container Manager&#xff09;中安装Home Assistant Container 不要使用 套件里面的 Home Assistant&#xff0c;不利于后期拓展 方式一&#xff1a; docker run -d --name"home-assistant-1" -v /volume1/docker/homeassistant/config:/c…

华东“启明”青少年音乐艺术实践中心揭幕暨中国“启明”巴洛克合奏团首演音乐会

2023年11月11日&#xff0c;华东“启明”青少年音乐艺术实践中心在上海揭幕&#xff0c;中国“启明”巴洛克合奏团开启了首场音乐会。 华东“启明”青少年音乐艺术实践中心由中共宁波市江北区委宣传部与上音管风琴艺术中心联合指导&#xff0c;宁波音乐港、宁波市江北区洛奇音乐…

Apache APISIX 的 Admin API 默认访问令牌漏洞(CVE-2020-13945)漏洞复现

漏洞描述 Apache APISIX 是一个动态、实时、高性能的 API 网关。Apache APISIX 有一个默认的内置 API 令牌&#xff0c;可用于访问所有 admin API&#xff0c;通过 2.x 版本中添加的参数导致远程执行 LUA 代码。 漏洞环境及利用 启动docker环境 访问9080端口 通过 admin api…

利用LangChain实现RAG

检索增强生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;结合了搜寻检索生成能力和自然语言处理架构&#xff0c;透过这个架构&#xff0c;模型可以从外部知识库搜寻相关信息&#xff0c;然后使用这些信息来生成response。要完成检索增强生成主要包含四个步骤…

Java进阶(垃圾回收GC)——理论篇:JVM内存模型 垃圾回收定位清除算法 JVM中的垃圾回收器

前言 JVM作为Java进阶的知识&#xff0c;是需要Java程序员不断深度和理解的。 本篇博客介绍JVM的内存模型&#xff0c;对比了1.7和1.8的内存模型的变化&#xff1b;介绍了垃圾回收的语言发展&#xff1b;阐述了定位垃圾的方法&#xff0c;引用计数法和可达性分析发以及垃圾清…

如何实现Debian工控电脑USB接口安全管控

Debian 作为工控电脑操作系统具有稳定性、安全性、自定义性和丰富的软件包等优势&#xff0c;适用于要求高度可靠性和安全性的工控应用。 Debian 作为工控电脑操作系统在工业控制领域有很大优势&#xff0c;包括&#xff1a; 稳定性&#xff1a;Debian 的发布版以其稳定性而闻…

find和grep命令的简单使用

find和grep命令的简单使用 一、find例子--不同条件查找 二、grep正则表达式的简单说明例子--简单文本查找例子--结合管道进行查找 一、find find 命令在指定的目录下查找对应的文件。 find [path] [expression]● path 是要查找的目录路径&#xff0c;可以是一个目录或文件名…

MS321V/358V/324V低压、轨到轨输入输出运放

MS321V/MS358V/MS324V 是单个、两个和四个低压轨到轨输 入输出运放&#xff0c;可工作在幅度为 2.7V 到 5V 的单电源或者双电源条件 下。在低电源、空间节省和低成本应用方面是最有效的解决方案。 这些放大器专门设计为低压工作&#xff08; 2.7V 到 5V &#xff09;…

Electron-vue出现GET http://localhost:9080/__webpack_hmr net::ERR_ABORTED解决方案

GET http://localhost:9080/__webpack_hmr net::ERR_ABORTED解决方案 使用版本解决方案解决总结 使用版本 以下是我解决此问题时使用的electron和vue等的一些版本信息 【附】经过测试 electron 的版本为 13.1.4 时也能解决 解决方案 将项目下的 .electron-vue/dev-runner.js…

为什么要用“交叉熵”做损失函数

大家好啊&#xff0c;我是董董灿。 今天看一个在深度学习中很枯燥但很重要的概念——交叉熵损失函数。 作为一种损失函数&#xff0c;它的重要作用便是可以将“预测值”和“真实值(标签)”进行对比&#xff0c;从而输出 loss 值&#xff0c;直到 loss 值收敛&#xff0c;可以…

Python数据结构:元组(Tuple)详解

1.介绍和基础操作 Python中的元组&#xff08;Tuple&#xff09;是不可变有序序列&#xff0c;可以容纳任意数据类型&#xff08;包括数字、字符串、布尔型、列表、字典等&#xff09;的元素&#xff0c;通常用圆括号() 包裹。与列表&#xff08;List&#xff09;类似&#xff…