Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。

1.自定义数据中enc_input、dec_input及dec_output的区别

博文中给出了两对德语翻译成英语的例子:

# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [# enc_input   dec_input    dec_output['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]

初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。

我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。

实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。

2.Transformer的训练流程

此处给出博文中附带的非常简洁的Transformer训练代码:

from torch import optim
from model import *model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练开始
for epoch in range(1000):for enc_inputs, dec_inputs, dec_outputs in loader:'''enc_inputs: [batch_size, src_len] [2,5]dec_inputs: [batch_size, tgt_len] [2,6]dec_outputs: [batch_size, tgt_len] [2,6]'''enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]loss = criterion(outputs, dec_outputs.view(-1))  # 将dec_outputs展平成一维张量# 更新权重optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')

这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:

ich mochte ein bier

在进行训练时,当我们给出起始符S,接下来应该预测出:

I

那训练时,有了SI后,则应该预测出

want

那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。

那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/667859.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

编译原理本科课程 专题5 基于 SLR(1)分析的语义分析及中间代码生成程序设计

一、程序功能描述 本程序由C/C编写,实现了赋值语句语法制导生成四元式,并完成了语法分析和语义分析过程。 以专题 1 词法分析程序的输出为语法分析的输入,完成以下描述赋值语句 SLR(1)文法的语义分析及中间代码四元式的过程,实现…

开源节点框架STNodeEditor使用

节点,一般都为树形Tree结构,如TreeNode,XmlNode。 树形结构有其关键属性Parent【父节点】,Children【子节点】 LinkedListNode为链表线性结构,有其关键属性Next【下一个】,Previous【上一个】&#xff0c…

1978-2022年人民币汇率(年平均价)数据

1978-2022年人民币汇率(年平均价)数据 1、时间:1978-2022年,其中人民币对欧元汇率时间为2002-2022年 2、指标:人民币对美元汇率(美元100)(元)、人民币对日元汇率(日元100)(元)、人民币对港元汇率(港元100)(元)、人民…

华为突然官宣:新版鸿蒙系统,正式发布

华为,一家始终引领科技创新潮流的全球性企业,近日再次引发行业震动——全新HarmonyOS NEXT,被誉为“纯血版鸿蒙”的操作系统正式官宣。这是华为在操作系统领域迈出的坚实且具有突破性的一步,标志着华为正逐步摆脱对安卓生态系统的…

3D力导向树插件-3d-force-graph学习002

一、实现效果:节点文字同时展示 节点显示不同颜色节点盒label文字并存节点上添加点击事件 二、利用插件:CSS2DRenderer 提示:以下引入文件均可在安装完3d-force-graph的安装包里找到 三、关键代码 提示:模拟数据可按如下格式填…

Node.js 包管理工具

一、概念介绍 1.1 包是什么 『包』英文单词是 package ,代表了一组特定功能的源码集合 1.2 包管理工具 管理『包』的应用软件,可以对「包」进行 下载安装 , 更新 , 删除 , 上传 等操作。 借助包管理工具&#xff0…

深度学习本科课程 实验5 循环神经网络

循环神经网络实验 任务内容 理解序列数据处理方法,补全面向对象编程中的缺失代码,并使用torch自带数据工具将数据封装为dataloader分别采用手动方式以及调用接口方式实现RNN、LSTM和GRU,并在至少一种数据集上进行实验从训练时间、预测精度、…

android tv开发-1,leanback

目录 1.leanback库的一些事 2.leanback在使用时遇到的一些麻烦 视频卡片 页面空白 关于左侧菜单的一些设置 数据加载异常与加载中的一些操作 如果页面无数据,如何显示错误的页面.

视频美颜SDK开发指南:从入门到精通的技术实践

美颜SDK是一种强大的工具,它不仅仅可以让用户在实时视频中获得光滑的肌肤和自然的妆容,从简单的滤镜到复杂的人脸识别,美颜SDK涵盖了广泛的技术领域。 一、美颜SDK的基本原理 美颜SDK包括图像处理、人脸检测和识别、滤镜应用等方面。掌握这些…

2024日常训练

#一些简单的训练 第一题:带余除法 来源于dotcpp、 习题一 (题目可以自己去看一看,这里就直接开始写题解了) 题解: 1.需要用到map()函数来接受一行整数输入 2.需要用到---end 这里就可以实现输出结果在同一行 …

mysql 使用join进行多表关联查询

join类型 在一些报表统计或数据展示时候需要提取的数据分布在多个表中,这个时候需要进行join连表操作。join将两个或多个表当成不同的数据集合,然后进行集合取交集运算。比如有订单Order表记录用户id,如果像查询订单对应的用户信息&#xff…

uniapp中使用EelementPlus

uniapp的强大是非常震撼的,一套代码可以编写到十几个平台。这个可以在官网上进行查询uni-app官网。主要还是开发小型的软件系统,使用起来非常的方便、快捷、高效。 uniapp中有很多自带的UI,在创建项目的时候,就可以自由选择。而E…

网络编程面试系列-01

1. 应用层中常见的协议都有哪些? 应用层协议(application layer protocol)定义了运行在不同端系统上的应用程序进程如何相互传递报文。 应用层协议 1)DNS:一种用以将域名转换为IP地址的Internet服务,域名系统DNS是因特网使用的命名系统,用来把便于人们使用的机器名字…

Unity类银河恶魔城学习记录1-11 PlayerPrimaryAttack P38

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释,可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili Player.cs using System.Collections; using System.Collections.Generic…

苹果公司宣布,为Apple Vision Pro打造了超过600款新应用

深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领域的领跑者。点击订阅,与未来同行! 订阅:https://rengongzhineng.io/ 。 2月…

JavaWeb之HTML-CSS --黑马笔记

什么是HTML ? 标记语言:由标签构成的语言。 注意:HTML标签都是预定义好的,HTML代码直接在浏览器中运行,HTML标签由浏览器解析。 什么是CSS ? 开发工具 VS Code --安装文档和安装包都在网盘中 链接:https://p…

FreeRtos的静态方法创建任务和删除示例

需要使用静态方法需要将宏configSUPPORT_STATIC_ALLOCATION1 步骤 1.修改宏configSUPPORT_STATIC_ALLOCATION1运行时候会显示两个函数未定义 vApplicationGetIdleTaskMemory()vApplicationGetTimerTaskMemory() #include &quo…

git整合分支的两种方法——合并(Merge)、变基(Rebase)

问题描述: 初次向git上传本地代码或者更新代码时,总是会遇到以下两个选项。有时候,只是想更新一下代码,没想到,直接更新了最新的代码,但是自己本地的代码并没有和git上的代码融合,反而被覆盖了…

机器学习系列——(六)数据降维

引言 在机器学习领域,数据降维是一种常用的技术,旨在减少数据集的维度,同时保留尽可能多的有用信息。数据降维可以帮助我们解决高维数据带来的问题,提高模型的效率和准确性。本文将详细介绍机器学习中的数据降维方法和技术&#…

浅谈——开源软件的影响力

✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 ✨特色专栏&#xff1a…