从代码学习深度学习 - Transformer PyTorch 版

文章目录

  • 前言
  • 1. 位置编码(Positional Encoding)
  • 2. 多头注意力机制(Multi-Head Attention)
  • 3. 前馈网络与残差连接(Position-Wise FFN & AddNorm)
    • 3.1 基于位置的前馈网络(PositionWiseFFN)
    • 3.2 残差连接和层规范化(AddNorm)
  • 4. 编码器(Encoder)
    • 4.1 编码器块(EncoderBlock)
    • 4.2 Transformer 编码器(TransformerEncoder)
  • 5. 解码器(Decoder)
    • 5.1 解码器块(DecoderBlock)
    • 5.2 Transformer 解码器(TransformerDecoder)
  • 6. 完整 Transformer 模型
    • 使用示例
  • 总结


前言

Transformer 模型自 2017 年在论文《Attention is All You Need》中提出以来,彻底改变了自然语言处理(NLP)领域,并在计算机视觉等其他领域展现了强大的潜力。与传统的 RNN 和 LSTM 相比,Transformer 通过自注意力机制(Self-Attention)实现了并行计算,极大地提高了训练效率和模型性能。本博客将通过 PyTorch 实现的 Transformer 模型代码,深入剖析其核心组件,包括多头注意力机制、位置编码、编码器和解码器等。我们将结合代码和文字说明,逐步拆解 Transformer 的实现逻辑,帮助读者从代码层面理解这一经典模型的精髓。
在这里插入图片描述

本文基于提供的代码文件(PE.pyEnDecoder.pyMHA.pyTransformer.ipynb),完整呈现 Transformer 的 PyTorch 实现,并通过清晰的目录结构和代码注释,带领大家从零开始学习 Transformer 的构建过程。关于训练和可视化部分,这里忽略掉,但是你仍然可以在下面的链接里找到所有的源代码,其中提供了丰富的注释。无论你是深度学习初学者还是希望深入理解 Transformer 的开发者,这篇博客都将为你提供一个清晰的学习路径。

完整代码:下载链接


1. 位置编码(Positional Encoding)

Transformer 的自注意力机制不包含序列的位置信息,因此需要通过位置编码(Positional Encoding)为每个词元添加位置信息。以下是 PE.py 中实现的位置编码类,它通过正弦和余弦函数生成固定位置编码。

import torch
import torch.nn as nnclass PositionalEncoding(nn.Module):"""位置编码在Transformer中,由于自注意力机制不含位置信息,需要额外添加位置编码在位置嵌入矩阵P中,行代表词元在序列中的位置,列代表位置编码的不同维度"""def __init__(self, num_hiddens, dropout, max_len=1000):"""初始化位置编码参数:num_hiddens (int): 隐藏层维度,即位置编码的维度dropout (float): dropout概率max_len (int, 可选): 最大序列长度,默认为1000"""super(PositionalEncoding, self).__init__()# 初始化丢弃层self.dropout = nn.Dropout(dropout)# 创建位置编码矩阵P,形状为(1, max_len, num_hiddens)self.P = torch.zeros((1, max_len, num_hiddens))# 计算位置编码的正弦和余弦函数输入# X形状: (max_len, num_hiddens/2)X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)# 偶数维度赋值正弦,奇数维度赋值余弦self.P[:, :, 0::2] = torch.sin(X)self.P[:, :, 1::2] = torch.cos(X)def forward(self, X):"""前向传播参数:X (torch.Tensor): 输入张量,形状为(batch_size, seq_len, embed_dim)返回:torch.Tensor: 添加位置编码后的张量,形状为(batch_size, seq_len, embed_dim)"""# 将位置编码加到输入X上,截取与X长度匹配的部分X = X + self.P[:, :X.shape[1], :].to(X.device)# 应用丢弃并返回结果return self.dropout(X)

代码解析

  • 初始化PositionalEncoding 类根据隐藏层维度(num_hiddens)和最大序列长度(max_len)生成一个位置编码矩阵 P。该矩阵的每一行表示一个位置,每一列对应一个编码维度。
  • 正弦和余弦编码:通过正弦(sin)和余弦(cos)函数为不同位置和维度生成编码值,公式为:
    P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d ) , P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d ) PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right) PE(pos,2i)=sin(100002i/dpos),PE(pos,2i+1)=cos(100002i/dpos)
    其中 pos 是位置索引,i 是维度索引,d 是隐藏层维度。
  • 前向传播:将输入张量 X 与位置编码矩阵 P 相加,并应用 dropout 以增强模型的鲁棒性。

位置编码的作用是将序列的位置信息嵌入到词嵌入中,使得 Transformer 能够区分相同词元在不同位置的语义。

2. 多头注意力机制(Multi-Head Attention)

多头注意力机制是 Transformer 的核心组件,允许模型并行计算多个注意力头,从而捕获序列中不同方面的依赖关系。以下是 MHA.py 中实现的多头注意力机制。

import math
import torch
from torch import nn
import torch.nn.functional as Fdef sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项,使超出有效长度的位置被设置为指定值"""maxlen = X.size(1)mask = torch.arange(maxlen, dtype=torch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/76112.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阅读分析Linux0.11 /boot/head.s

目录 初始化IDT、IDTR和GDT、GDTR检查协处理器并设置CR0寄存器初始化页表和CR3寄存器,开启分页 初始化IDT、IDTR和GDT、GDTR startup_32:movl $0x10,%eaxmov %ax,%dsmov %ax,%esmov %ax,%fsmov %ax,%gslss _stack_start,%espcall setup_idtcall setup_gdtmovl $0x1…

33、单元测试实战练习题

以下是三个练习题的具体实现方案,包含完整代码示例和详细说明: 练习题1:TDD实现博客评论功能 步骤1:编写失败测试 # tests/test_blog.py import unittest from blog import BlogPost, Comment, InvalidCommentErrorclass TestBl…

16-算法打卡-哈希表-两个数组的交集-leetcode(349)-第十六天

1 题目地址 349. 两个数组的交集 - 力扣(LeetCode)349. 两个数组的交集 - 给定两个数组 nums1 和 nums2 ,返回 它们的 交集 。输出结果中的每个元素一定是 唯一 的。我们可以 不考虑输出结果的顺序 。 示例 1:输入:nu…

SciPy库详解

SciPy 是一个用于数学、科学和工程计算的 Python 库,它建立在 NumPy 之上,提供了许多高效的算法和工具,用于解决各种科学计算问题。 CONTENT 1. 数值积分功能代码 2. 优化问题求解功能代码3. 线性代数运算功能代码 4. 信号处理功能代码 5. 插…

杰弗里·辛顿:深度学习教父

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 杰弗里辛顿:当坚持遇见突破,AI迎来新纪元 一、人物简介 杰弗…

BladeX单点登录与若依框架集成实现

1. 概述 本文档详细介绍了将BladeX认证系统与若依(RuoYi)框架集成的完整实现过程。集成采用OAuth2.0授权码流程,使用户能够通过BladeX账号直接登录若依系统,实现无缝单点登录体验。 2. 系统架构 2.1 总体架构 #mermaid-svg-YxdmBwBtzGqZHMme {font-fa…

初识Redis · set和zset

目录 前言: set 基本命令 交集并集差集 内部编码和应用场景 zset 基本命令 交集并集差集 内部编码和应用场景 应用场景(AI生成) 排行榜系统 应用背景 设计思路 热榜系统 应用背景 设计思路 热度计算方式 总结对比表 前言&a…

playwright 教程高级篇:掌握网页自动化与验证码处理等关键技术详解

Playwright 教程高级篇:掌握网页自动化与验证码处理等关键技术详解 本教程将带您一步步学习如何使用 Playwright——一个强大的浏览器自动化工具,来完成网页任务,例如提交链接并处理旋转验证码。我们将按照典型的自动化流程顺序,从启动浏览器到关闭浏览器,详细讲解每个步骤…

数据结构(完)

树 二叉树 构建二叉树 int value;Node left;Node right;public Node(int val) {valueval;} 节点的添加 Node rootnull;public void insert(int num) {Node nodenew Node(num);if(rootnull) {rootnode;return;}Node index root;while(true) {//插入的节点值小if(index.value&g…

FastAPI与SQLAlchemy数据库集成与CRUD操作

title: FastAPI与SQLAlchemy数据库集成与CRUD操作 date: 2025/04/16 09:50:57 updated: 2025/04/16 09:50:57 author: cmdragon excerpt: FastAPI与SQLAlchemy集成基础包括环境准备、数据库连接配置和模型定义。CRUD操作通过数据访问层封装和路由层实现,确保线程安全和事务…

一个基于Django的写字楼管理系统实现方案

一个基于Django的写字楼管理系统实现方案 用户现在需要我用Django来编写一个写字楼管理系统的Web版本,要求包括增删改查写字楼的HTML页面,视频管理功能,本地化部署,以及人员权限管理,包含完整的代码结构和功能实现&am…

mongodb在window10中创建副本集的方法,以及node.js连接副本集的方法

创建Mongodb的副本集最好是新建一个文件夹,如D:/data,不要在mongodb安装文件夹里面创建副本集,虽然这样也可以,但是容易造成误操作或路径混乱;在新建文件夹里与现有 MongoDB 数据隔离,避免误操作影响原有数…

Maven 多仓库与镜像配置全攻略:从原理到企业级实践

Maven 多仓库与镜像配置全攻略:从原理到企业级实践 一、核心概念:Repository 与 Mirror 的本质差异 在 Maven 依赖管理体系中,repository与mirror是构建可靠依赖解析链的两大核心组件,其核心区别如下: 1. Repositor…

STM32 四足机器人常见问题汇总

文章不介绍具体参数,有需求可去网上搜索。 特别声明:不论年龄,不看学历。既然你对这个领域的东西感兴趣,就应该不断培养自己提出问题、思考问题、探索答案的能力。 提出问题:提出问题时,应说明是哪款产品&a…

MySQL 中 `${}` 和 `#{}` 占位符详解及面试高频考点

文章目录 一、概述二、#{} 和 ${} 的核心区别1. 底层机制代码示例 2. 核心区别总结 三、为什么表名只能用 ${}?1. 预编译机制的限制2. 动态表名的实现 四、安全性注意事项1. ${} 的风险场景2. 安全实践 五、面试高频考点1. 基础原理类问题**问题 1**:**问…

C语言编译预处理2

#include <XXXX.h>和#include <XXXX.c> #include "XXXX.h" 是 C 语言中一条预处理指令 #include <XXXX.h>&#xff1a;这种形式用于包含系统标准库的头文件。预处理器会在系统默认的头文件搜索路径中查找XXXX.h 文件。例如在 Linux 系统中&#…

Elasticvue-轻量级Elasticsearch可视化管理工具

Elasticvue一个免费且开源的 Elasticsearch 在线可视化客户端&#xff0c;用于管理 Elasticsearch 集群中的数据&#xff0c;完全支持 Elasticsearch 版本 8.x 和 7.x. 功能特色&#xff1a; 集群概览索引和别名管理分片管理搜索和编辑文档REST 查询快照和存储库管理支持国际…

Git提交规范及最佳实践

Git 提交规范通常是为了提高代码提交的可读性、可维护性和自动化效率&#xff08;如生成 ChangeLog&#xff09;。以下是常见的 Conventional Commits 规范&#xff0c;结合社区最佳实践总结而成&#xff1a; 1. 提交格式 每次提交的 commit message 应包含三部分&#xff1a;…

Ubuntu中snap

通过Snap可以安装众多的软件包。需要注意的是&#xff0c;snap是一种全新的软件包管理方式&#xff0c;它类似一个容器拥有一个应用程序所有的文件和库&#xff0c;各个应用程序之间完全独立。所以使用snap包的好处就是它解决了应用程序之间的依赖问题&#xff0c;使应用程序之…

android studio 运行java main报错

运行某个带main函数的java文件报错 Could not create task :app:Test.main(). > SourceSet with name main not found. 解决办法&#xff1a;在工程的.idea/gradle.xml 文件下添加&#xff1a; <option name"delegatedBuild" value"false" /&g…