简单易懂的理解 PyTorch 中 Transformer 组件

目录

torch.nn子模块transformer详解

nn.Transformer

Transformer 类描述

Transformer 类的功能和作用

Transformer 类的参数

forward 方法

参数

输出

示例代码

注意事项

nn.TransformerEncoder

TransformerEncoder 类描述

TransformerEncoder 类的功能和作用

TransformerEncoder 类的参数

forward 方法

参数

返回类型

形状

示例代码

nn.TransformerDecoder

TransformerDecoder 类描述

TransformerDecoder 类的功能和作用

TransformerDecoder 类的参数

forward 方法

参数

返回类型

形状

示例代码

nn.TransformerEncoderLayer

TransformerEncoderLayer 类描述

TransformerEncoderLayer 类的功能和作用

TransformerEncoderLayer 类的参数

forward 方法

参数

返回类型

形状

示例代码

nn.TransformerDecoderLayer

TransformerDecoderLayer 类描述

TransformerDecoderLayer 类的功能和作用

TransformerDecoderLayer 类的参数

forward 方法

参数

返回类型

形状

示例代码

总结


torch.nn子模块transformer详解

nn.Transformer

Transformer 类描述

torch.nn.Transformer 类是 PyTorch 中实现 Transformer 模型的核心类。基于 2017 年的论文 “Attention Is All You Need”,该类提供了构建 Transformer 模型的完整功能,包括编码器(Encoder)和解码器(Decoder)部分。用户可以根据需要调整各种属性。

Transformer 类的功能和作用
  • 多头注意力: Transformer 使用多头自注意力机制,允许模型同时关注输入序列的不同位置。
  • 编码器和解码器: 包含多个编码器和解码器层,每层都有自注意力和前馈神经网络。
  • 适用范围广泛: 被广泛用于各种 NLP 任务,如语言翻译、文本生成等。
Transformer 类的参数
  1. d_model (int): 编码器/解码器输入的特征数(默认值为512)。
  2. nhead (int): 多头注意力模型中的头数(默认值为8)。
  3. num_encoder_layers (int): 编码器中子层的数量(默认值为6)。
  4. num_decoder_layers (int): 解码器中子层的数量(默认值为6)。
  5. dim_feedforward (int): 前馈网络模型的维度(默认值为2048)。
  6. dropout (float): Dropout 值(默认值为0.1)。
  7. activation (str 或 Callable): 编码器/解码器中间层的激活函数,默认为 ReLU。
  8. custom_encoder/decoder (可选): 自定义的编码器或解码器(默认值为None)。
  9. layer_norm_eps (float): 层归一化组件中的 eps 值(默认值为1e-5)。
  10. batch_first (bool): 如果为 True,则输入和输出张量的格式为 (batch, seq, feature)(默认值为False)。
  11. norm_first (bool): 如果为 True,则在其他注意力和前馈操作之前进行层归一化(默认值为False)。
  12. bias (bool): 如果设置为 False,则线性和层归一化层将不学习附加偏置(默认值为True)。
forward 方法

forward 方法用于处理带掩码的源/目标序列。

参数
  • src (Tensor): 编码器的输入序列。
  • tgt (Tensor): 解码器的输入序列。
  • src/tgt/memory_mask (可选): 序列掩码。
  • src/tgt/memory_key_padding_mask (可选): 键填充掩码。
  • src/tgt/memory_is_causal (可选): 指定是否应用因果掩码。
输出
  • 输出 Tensor 的形状为 (T, N, E)(N, T, E)(如果 batch_first=True),其中 T 是目标序列长度,N 是批次大小,E 是特征数。
示例代码
import torch
import torch.nn as nn# 创建 Transformer 实例
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)# 输入数据
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))# 前向传播
out = transformer_model(src, tgt)

这段代码展示了如何创建并使用 Transformer 模型。在这个例子中,srctgt 分别是随机生成的编码器和解码器的输入张量。输出 out 是模型的最终输出。

注意事项
  • 掩码生成: 可以使用 generate_square_subsequent_mask 方法来生成序列的因果掩码。
  • 配置灵活性: 由于 Transformer 类的可配置性,用户可以轻松调整模型结构以适应不同的任务需求。

nn.TransformerEncoder

TransformerEncoder 类描述

torch.nn.TransformerEncoder 类在 PyTorch 中实现了 Transformer 模型的编码器部分。它是一系列编码器层的堆叠,用户可以通过这个类构建类似于 BERT 的模型。

TransformerEncoder 类的功能和作用
  • 多层编码器结构: TransformerEncoder 由多个 Transformer 编码器层组成,每一层都包括自注意力机制和前馈网络。
  • 适用于各种 NLP 任务: 可用于语言模型、文本分类等多种自然语言处理任务。
  • 灵活性和可定制性: 用户可以自定义编码器层的数量和层参数,以适应不同的应用需求。
TransformerEncoder 类的参数
  1. encoder_layer: TransformerEncoderLayer 实例,表示单个编码器层(必需)。
  2. num_layers: 编码器中子层的数量(必需)。
  3. norm: 层归一化组件(可选)。
  4. enable_nested_tensor: 如果为 True,则输入会自动转换为嵌套张量(在输出时转换回来),当填充率较高时,这可以提高 TransformerEncoder 的整体性能。默认为 True(启用)。
  5. mask_check: 是否检查掩码。默认为 True。
forward 方法

forward 方法用于顺序通过编码器层处理输入。

参数
  • src (Tensor): 编码器的输入序列(必需)。
  • mask (可选 Tensor): 源序列的掩码(可选)。
  • src_key_padding_mask (可选 Tensor): 批次中源键的掩码(可选)。
  • is_causal (可选 bool): 如指定,应用因果掩码。默认为 None;尝试检测因果掩码。
返回类型
  • Tensor
形状
  • 请参阅 Transformer 类中的文档。
示例代码
import torch
import torch.nn as nn# 创建 TransformerEncoderLayer 实例
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)# 创建 TransformerEncoder 实例
transformer_encoder = nn.TransformeEncoder(encoder_layer, num_layers=6)# 输入数据
src = torch.rand(10, 32, 512)  # 随机输入# 前向传播
out = transformer_encoder(src)

这段代码展示了如何创建并使用 TransformerEncoder。在这个例子中,src 是随机生成的输入张量,transformer_encoder 是由 6 层编码器层组成的编码器。输出 out 是编码器的最终输出。

nn.TransformerDecoder

TransformerDecoder 类描述

torch.nn.TransformerDecoder 类实现了 Transformer 模型的解码器部分。它是由多个解码器层堆叠而成,用于处理编码器的输出并生成最终的输出序列。

TransformerDecoder 类的功能和作用
  • 多层解码器结构: TransformerDecoder 由多个 Transformer 解码器层组成,每层包括自注意力机制、交叉注意力机制和前馈网络。
  • 处理编码器输出: 解码器用于处理编码器的输出,并根据此输出和之前生成的输出序列生成新的输出。
  • 应用场景广泛: 适用于各种基于 Transformer 的生成任务,如机器翻译、文本摘要等。
TransformerDecoder 类的参数
  1. decoder_layer: TransformerDecoderLayer 实例,表示单个解码器层(必需)。
  2. num_layers: 解码器中子层的数量(必需)。
  3. norm: 层归一化组件(可选)。
forward 方法

forward 方法用于将输入(及掩码)依次通过解码器层进行处理。

参数
  • tgt (Tensor): 解码器的输入序列(必需)。
  • memory (Tensor): 编码器的最后一层输出序列(必需)。
  • tgt/memory_mask (可选 Tensor): 目标/内存序列的掩码(可选)。
  • tgt/memory_key_padding_mask (可选 Tensor): 批次中目标/内存键的掩码(可选)。
  • tgt_is_causal/memory_is_causal (可选 bool): 指定是否应用因果掩码。
返回类型
  • Tensor
形状
  • 请参阅 Transformer 类中的文档。
示例代码
import torch
import torch.nn as nn# 创建 TransformerDecoderLayer 实例
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)# 创建 TransformerDecoder 实例
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)# 输入数据
memory = torch.rand(10, 32, 512)  # 编码器的输出
tgt = torch.rand(20, 32, 512)     # 解码器的输入# 前向传播
out = transformer_decoder(tgt, memory)

这段代码展示了如何创建并使用 TransformerDecoder。在这个例子中,memory 是编码器的输出,tgt 是解码器的输入。输出 out 是解码器的最终输出。

nn.TransformerEncoderLayer

TransformerEncoderLayer 类描述

torch.nn.TransformerEncoderLayer 类构成了 Transformer 编码器的基础单元,每个编码器层包含一个自注意力机制和一个前馈网络。这种标准的编码器层基于论文 "Attention Is All You Need"。

TransformerEncoderLayer 类的功能和作用
  • 自注意力机制: 通过自注意力机制,每个编码器层能够捕获输入序列中不同位置间的关系。
  • 前馈网络: 为序列中的每个位置提供额外的转换。
  • 灵活性和可定制性: 用户可以根据应用需求修改或实现不同的编码器层。
TransformerEncoderLayer 类的参数
  1. d_model (int): 输入中预期的特征数量(必需)。
  2. nhead (int): 多头注意力模型中的头数(必需)。
  3. dim_feedforward (int): 前馈网络模型的维度(默认值=2048)。
  4. dropout (float): Dropout 值(默认值=0.1)。
  5. activation (str 或 Callable): 中间层的激活函数,可以是字符串("relu" 或 "gelu")或一元可调用对象。默认值:relu。
  6. layer_norm_eps (float): 层归一化组件中的 eps 值(默认值=1e-5)。
  7. batch_first (bool): 如果为 True,则输入和输出张量以 (batch, seq, feature) 的格式提供。默认值:False(seq, batch, feature)。
  8. norm_first (bool): 如果为 True,则在注意力和前馈操作之前进行层归一化。否则之后进行。默认值:False(之后)。
  9. bias (bool): 如果设置为 False,则线性和层归一化层将不会学习附加偏置。默认值:True。
forward 方法

forward 方法用于将输入通过编码器层进行处理。

参数
  • src (Tensor): 传递给编码器层的序列(必需)。
  • src_mask (可选 Tensor): 源序列的掩码(可选)。
  • src_key_padding_mask (可选 Tensor): 批次中源键的掩码(可选)。
  • is_causal (bool): 如果指定,则应用因果掩码作为源掩码。默认值:False。
返回类型
  • Tensor
形状
  • 请参阅 Transformer 类中的文档。
示例代码
import torch
import torch.nn as nn# 创建 TransformerEncoderLayer 实例
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)# 输入数据
src = torch.rand(10, 32, 512)  # 随机输入# 前向传播
out = encoder_layer(src)

或者在 batch_first=True 的情况下:

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
src = torch.rand(32, 10, 512)
out = encoder_layer(src)

这段代码展示了如何创建并使用 TransformerEncoderLayer。在这个例子中,src 是随机生成的输入张量。输出 out 是编码器层的输出。

nn.TransformerDecoderLayer

TransformerDecoderLayer 类描述

torch.nn.TransformerDecoderLayer 类是构成 Transformer 模型解码器的基本单元。这个标准的解码器层基于论文 "Attention Is All You Need"。它由自注意力机制、多头注意力机制和前馈网络组成。

TransformerDecoderLayer 类的功能和作用
  • 自注意力和多头注意力机制: 使解码器能够同时关注输入序列的不同部分。
  • 前馈网络: 为序列中的每个位置提供额外的转换。
  • 灵活性和可定制性: 用户可以根据应用需求修改或实现不同的解码器层。
TransformerDecoderLayer 类的参数
  1. d_model (int): 输入中预期的特征数量(必需)。
  2. nhead (int): 多头注意力模型中的头数(必需)。
  3. dim_feedforward (int): 前馈网络模型的维度(默认值=2048)。
  4. dropout (float): Dropout 值(默认值=0.1)。
  5. activation (str 或 Callable): 中间层的激活函数,可以是字符串("relu" 或 "gelu")或一元可调用对象。默认值:relu。
  6. layer_norm_eps (float): 层归一化组件中的 eps 值(默认值=1e-5)。
  7. batch_first (bool): 如果为 True,则输入和输出张量以 (batch, seq, feature) 的格式提供。默认值:False(seq, batch, feature)。
  8. norm_first (bool): 如果为 True,则在自注意力、多头注意力和前馈操作之前进行层归一化。否则之后进行。默认值:False(之后)。
  9. bias (bool): 如果设置为 False,则线性和层归一化层将不会学习附加偏置。默认值:True。
forward 方法

forward 方法用于将输入(及掩码)通过解码器层进行处理。

参数
  • tgt (Tensor): 解码器层的输入序列(必需)。
  • memory (Tensor): 编码器的最后一层输出序列(必需)。
  • tgt/memory_mask (可选 Tensor): 目标/内存序列的掩码(可选)。
  • tgt/memory_key_padding_mask (可选 Tensor): 批次中目标/内存键的掩码(可选)。
  • tgt_is_causal/memory_is_causal (bool): 指定是否应用因果掩码。
返回类型
  • Tensor
形状
  • 请参阅 Transformer 类中的文档。
示例代码
import torch
import torch.nn as nn# 创建 TransformerDecoderLayer 实例
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)# 输入数据
memory = torch.rand(10, 32, 512)  # 编码器的输出
tgt = torch.rand(20, 32, 512)     # 解码器的输入# 前向传播
out = decoder_layer(tgt, memory)

 或者在 batch_first=True 的情况下:

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
memory = torch.rand(32, 10, 512)
tgt = torch.rand(32, 20, 512)
out = decoder_layer(tgt, memory)

 这段代码展示了如何创建并使用 TransformerDecoderLayer。在这个例子中,memory 是编码器的输出,tgt 是解码器的输入。输出 out 是解码器层的输出。

总结

        本篇博客深入探讨了 PyTorch 的 torch.nn 子模块中与 Transformer 相关的核心组件。我们详细介绍了 nn.Transformer 及其构成部分 —— 编码器 (nn.TransformerEncoder) 和解码器 (nn.TransformerDecoder),以及它们的基础层 —— nn.TransformerEncoderLayernn.TransformerDecoderLayer。每个部分的功能、作用、参数配置和实际应用示例都被全面解析。这些组件不仅提供了构建高效、灵活的 NLP 模型的基础,还展示了如何通过自注意力和多头注意力机制来捕捉语言数据中的复杂模式和长期依赖关系。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/607473.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

普中STM32-PZ6806L开发板(有点悲伤的故事)

简介 关于我使用 普中STM32-PZ6806L做了做了一些实验, 不小心输入12V,导致核心板等被烧坏, 为了利用电路和资源, 搭建了STM32F103CBT6并使用普中STM32-PZ6806L上面没有烧坏的模块的故事。 普中STM32-PZ6806L开发板 这块的STM32F103ZET6部分算是Closed了, 不准备换核…

docker compose 搭建ES集群的详细步骤,并去掉验证,使用http访问

要使用 Docker Compose 搭建一个 Elasticsearch 集群,并配置为不使用验证,同时使用 HTTP 访问,你可以遵循以下步骤: 步骤 1: 安装 Docker 和 Docker Compose 确保你的系统上已安装 Docker 和 Docker Compose。如果尚未安装&…

freertos——任务通知知识总结与任务通知模拟及信号量实验、消息邮箱实验、事件标志组实验

1.任务通知概念 任务通知:用来通知任务的,任务控制块中的结构体成员变量 ulNotifiedValue就是这个通知值,不需要另外创建一个结构体可以直接接受别人发过来的通知 2.任务通知的优势及劣势 任务通知的优势: 效率更高 &#xff…

Hive基础知识(四):Hive 元数据配置到 MySQL

1. 拷贝驱动 将 MySQL 的 JDBC 驱动拷贝到 Hive 的 lib 目录下 [zzdqhadoop100 software]$ cp /home/atguigu/mysql-connector-java-5.1.37.jar $HIVE_HOME/lib 2. 配置 Metastore 到 MySQL 1)在$HIVE_HOME/conf 目录下新建 hive-site.xml 文件 [zzdqhadoop100 s…

DevExpress历史安装文件包集合

Components - DevExpress.NET组件安装包此安装程序包括所有 .NET Framework、.NET Core 3 和 .NET 5、ASP.NET Core 和 HTML/JavaScript 组件和库(Web和桌面应用程序开发只需要安装此文件即可)。 注意:自DevExpress21.1版本之后,该…

Python 动态变量名称的使用

Python 动态变量名称的使用 引言正文 引言 今天遇到了一个问题,那就是如何在 Python 中使用动态变量名称。那么为什么会想要使用动态变量名称呢?比如,我们有五个对象,分别为: rectangle1 rectangle2 rectangle3 rectangle4 rect…

Altium Designer实用系列(六)----AD问题整理,持续更新......

本篇博客会对粉丝提出的问题进行整理汇总,博客会持续更新… 一、 问题描述 问题1: 为什么我的ad点击设计之后没有update pcb document 原因:由于在工程中只有schDoc文件,没有PcbDoc,而导致没有“设计”–>update …

TypeScript 和 jsdom 库创建爬虫程序示例

TypeScript 简介 TypeScript 是一种由微软开发的自由和开源的编程语言。它是 JavaScript 的一个超集,可以编译生成纯 JavaScript 代码。TypeScript 增加了可选的静态类型和针对对象的编程功能,使得开发更加大规模的应用容易。 jsdom 简介 jsdom 是一个…

oracle 19c容器数据库数据加载和传输-----SQL*Loader(一)

目录 数据加载 (一)控制文件加载 1.创建用户执行sqlldr 2.创建文本文件和控制文件 3.查看表数据 4.查看log文件 (二)快捷方式加载 1.system用户执行 2.查看表数据 3.查看log文件 外部表 数据加载和传输的工具&#xff1…

【docker】centos7安装harbor

目录 零、前提一、下载离线包二、安装三、访问四、开机自启 零、前提 1.前提是已经安装了docker和docker-compose 一、下载离线包 1. csdn资源:harbor-offline-installer-v2.10.0.tgz 2. 百度云盘(提取码:ap3t):harbo…

彻底卸载Microsoft Edge的几种方法

在Windows 10或更高版本的操作系统中,由于Microsoft Edge浏览器与系统深度集成,常规的卸载方法有时可能无法完全移除。然而,你可以尝试以下步骤来尽可能彻底地卸载Microsoft Edge: 方法1:通过设置应用卸载Microsoft Ed…

2023三星齐发,博客之星、华为OD、Java学习星球

大家好,我是哪吒。 一、回顾2023 2023年,华为OD成了我的主旋律,一共发布了561篇文章,其中包含 368篇华为OD机试的文章;100篇Java基础的文章40多篇MongoDB、Redis的文章;30多篇数据库的文章;2…

Java多线程并发篇----第二篇

系列文章目录 文章目录 系列文章目录前言一、ExecutorService、 Callable、 Future 有返回值线程二、基于线程池的方式三、4 种线程池前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去…

VS2022 | 调整适配虚幻5的设置

VS2022 | 调整适配虚幻5的设置

系统存储架构升级分享 | 京东云技术团队

一、业务背景 系统业务功能:系统内部进行数据处理及整合, 对外部系统提供结果数据的初始化(写)及查询数据结果服务。 系统网络架构: 部署架构对切量上线的影响 - 内部管理系统上线对其他系统的读业务无影响分布式缓存可进行单独扩容, 与存储及查询功能升级无关通过…

1.2 Hadoop概述

小肥柴的Hadoop之旅 1.2 Hadoop概述 目录1.2 Hadoop概述1.2.1 回归问题1.2.2 Google的三篇论文1.2.3 Hadoop的诞生过程1.2.4 Hadoop特点简介 参考文献和资料 ) 目录 1.2 Hadoop概述 1.2.1 回归问题 通过前一篇帖子的介绍,特别是问题思考部分的说明,我…

GCN的使用和包的安装(超详细)

文章目录 工具包安装方法首先进入官网,找到安装包的地址进入后,找到自己的torch版本进入后,将每种对应的包都下载到本地,用本地命令安装然后就是本地安装了最后就是pip install pytorch_geometric 工具包安装方法 一定参考其GITH…

【ASP.NET Core 基础知识】--项目结构

一、ASP.NET Core项目的基本结构 ASP.NET Core项目的基本结构通常遵循一种标准的组织方式,这有助于提高项目的可维护性和可扩展性。以下是一个典型的ASP.NET Core项目的基本结构: 项目文件 (.csproj): 项目的主要配置文件,定义了项目的依…

idea创建javaweb项目步骤超详细(2022最新版本)

目录 前言必读 一、新建文件 1.在idea里面点击文件-新建-项目 2.新建项目-更改名称为自己想要的项目名称-创建 3.右键自己建立的项目-添加框架支持(英文版是Add Framework Support...) 4.勾选Web应用程序-确定 5.建立成功界面 二、配置tomcat 6.…

Java游戏开发 —— 坦克大战

引言: 坦克大战也是小时一个比较经典的游戏了,我在网上也是参考了韩顺平老师写的坦克大战,并做了一下完善,编写出来作为儿时的回忆吧! 思路: 创建主窗口,加载菜单及游戏面板。 在游戏面板中初始…