PyTorch 实战:Transformer 模型搭建全解析

Transformer 作为一种强大的序列到序列模型,凭借自注意力机制在诸多领域大放异彩。它能并行处理序列,有效捕捉上下文关系,其架构包含编码器与解码器,各由多层组件构成,涉及自注意力、前馈神经网络、归一化和 Dropout 等关键环节 。下面我们深入探讨其核心要点,并结合代码实现进行详细解读。

一、Transformer 核心公式与机制

(一)自注意力计算

自注意力机制是 Transformer 的核心,其计算基于公式\(Attention(Q, K, V)=softmax\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V\) 。其中,Q、K、V分别是查询、键和值矩阵,由输入X分别乘以对应的权重矩阵\(W_Q\)、\(W_K\)、\(W_V\)得到 。\(d_{k}\)表示键的维度,除以\(\sqrt{d_{k}}\) ,一方面可防止\(QK^{T}\)过大导致 softmax 计算溢出,另一方面能让\(QK^{T}\)结果满足均值为 0、方差为 1 的分布 。\(QK^{T}\)本质上是计算向量间的余弦相似度,反映向量方向上的相似程度。

(二)多头注意力机制

多头注意力机制将输入x拆分为h份,独立计算h组不同的线性投影得到各自的Q、K、V ,然后并行计算注意力,最后拼接h个注意力池化结果,并通过可学习的线性投影产生最终输出。这种设计使每个头能关注输入的不同部分,增强了模型对复杂函数的表示能力。

(三)位置编码

由于 Transformer 没有循环结构,位置编码用于保留序列中的位置信息,确保模型在处理序列时能感知元素的位置。

二、自注意力与多头注意力的实现

(一)自注意力实现

在 PyTorch 中,自注意力模块Self_Attention的实现如下:

python

import numpy as np
import torch
from torch import nnclass Self_Attention(nn.Module):def __init__(self, input_dim, dim_k, dim_v):super(Self_Attention, self).__init__()self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self._norm_fact = 1 / np.sqrt(dim_k)def forward(self, x):Q = self.q(x) K = self.k(x) V = self.v(x) atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1))) * self._norm_factoutput = torch.bmm(atten, V)return outputX = torch.randn(4, 3, 2)
self_atten = Self_Attention(2, 4, 5) 
res = self_atten(X)
print(res.shape) 

在这段代码中,Self_Attention类继承自nn.Module 。__init__方法初始化了线性层qkv ,并计算了归一化因子_norm_fact 。forward方法实现了自注意力的计算过程,先通过线性层得到Q、K、V ,然后计算注意力权重atten ,最后得到输出output 。

(二)多头注意力实现

多头注意力模块Self_Attention_Muti_Head的实现如下:

python

import torch
import torch.nn as nnclass Self_Attention_Muti_Head(nn.Module):def __init__(self,input_dim,dim_k,dim_v,nums_head):super(Self_Attention_Muti_Head,self).__init__()assert dim_k % nums_head == 0assert dim_v % nums_head == 0self.q = nn.Linear(input_dim,dim_k)self.k = nn.Linear(input_dim,dim_k)self.v = nn.Linear(input_dim,dim_v)self.nums_head = nums_headself.dim_k = dim_kself.dim_v = dim_vself._norm_fact = 1 / np.sqrt(dim_k)def forward(self,x):Q = self.q(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.nums_head) K = self.k(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.nums_head) V = self.v(x).reshape(-1,x.shape[0],x.shape[1],self.dim_v // self.nums_head)atten = nn.Softmax(dim=-1)(torch.matmul(Q,K.permute(0,1,3,2))) output = torch.matmul(atten,V).reshape(x.shape[0],x.shape[1],-1) return outputx=torch.rand(1,3,4)
atten=Self_Attention_Muti_Head(4,4,4,2)
y=atten(x)
print(y.shape) 

在这个类中,__init__方法进行了参数校验和模块初始化 。forward方法将输入x经过线性层变换后,重塑形状以适应多头计算,接着计算注意力权重并得到输出,最后将多头的结果拼接起来。

三、注意力机制的拓展

(一)视觉注意力机制

视觉注意力机制主要包括空间域、通道域和混合域三种。空间域注意力通过对图片空间域信息进行变换,生成掩码并打分来提取关键信息;通道域注意力为每个通道分配权重,代表模块有 SENet,通过全局池化提取通道权重,进而调整特征图;混合域注意力则结合了空间域和通道域的信息 。

(二)通道域注意力(SENet)实现

SENet 的实现代码如下:

python

class SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x)

SELayer类中,__init__方法初始化了平均池化层avg_pool和全连接层序列fc 。forward方法实现了 SENet 的核心操作,先通过平均池化进行 Squeeze 操作,再经过全连接层进行 Excitation 操作,最后将生成的权重与原特征图相乘,实现对特征图的增强 。

(三)门控注意力机制(GCT)

GCT 是一种能提升卷积网络泛化能力的通道间建模结构。它包含全局上下文嵌入、通道规范化和门控适应三个部分。全局上下文嵌入模块汇聚每个通道的全局上下文信息;通道规范化构建神经元竞争关系;门控适应加入门限机制,促进神经元的协同或竞争关系 。

其实现代码如下:

python

class GCT(nn.Module):def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):super(GCT, self).__init__()self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))self.epsilon = epsilonself.mode = modeself.after_relu = after_reludef forward(self, x):if self.mode == 'l2':embedding = (x.pow(2).sum((2, 3), keepdim=True) + self.epsilon).pow(0.5) * self.alphanorm = self.gamma / ((embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5))elif self.mode == 'l1':if not self.after_relu:_x = torch.abs(x)else:_x = xembedding = _x.sum((2, 3), keepdim=True) * self.alphanorm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)gate = 1. + torch.tanh(embedding * norm + self.beta)return x * gate

GCT类中,__init__方法初始化了可训练参数alphagammabeta以及其他超参数 。forward方法根据不同的模式(l2l1)计算嵌入和归一化结果,最后通过门控机制得到输出 。GCT 通常添加在 Conv 层前,训练时可先冻结原模型训练 GCT,再解冻微调 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/76986.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网页不同渲染方式的应对与反爬机制的处理——python爬虫

文章目录 写在前面爬虫习惯web 网页渲染方式服务器渲染客户端渲染 反爬机制使用session对象使用cookie让请求头信息更丰富使用代理和随机延迟 写在前面 本文是对前两篇文章所介绍的内容的补充,在了解前两篇文章——《爬虫入门与requests库的使用》和《BeautifulSou…

RK3588平台用v4l工具调试USB摄像头实践(亮度,饱和度,对比度,色相等)

目录 前言:v4l-utils简介 一:查找当前的摄像头设备 二:查看当前摄像头支持的v4l2-ctl调试参数 三根据提示设置对应参数,在提示范围内设置 四:常用调试命令 五:应用内执行命令方法 前言:v4l-utils简介 v4l-utils工具是由Linu…

Spring Security基础入门

本入门案例主要演示Spring Security在Spring Boot中的安全管理效果。为了更好地使用Spring Boot整合实现Spring Security安全管理功能,体现案例中Authentication(认证)和Authorization(授权)功能的实现,本案…

Trae+DeepSeek学习Python开发MVC框架程序笔记(二):使用4个文件实现MVC框架

修改上节文件,将test2.py拆分为4个文件,目录结构如下: mvctest/ │── model.py # 数据模型 │── view.py # 视图界面 │── controller.py # 控制器 │── main.py # 程序入口其中model.py代码如下&#xff…

从认证到透传:用 Nginx 为 EasySearch 构建一体化认证网关

在构建本地或云端搜索引擎系统时,EasySearch 凭借其轻量、高性能、易部署等优势,逐渐成为众多开发者和技术爱好者的首选。但在实际部署过程中,如何借助 Nginx 为 EasySearch 提供高效、稳定且安全的访问入口,尤其是在身份认证方面…

CPU 虚拟化机制——受限直接执行 (LDE)

1. 引言:CPU虚拟化的核心问题 让多个进程看似同时运行在一个物理CPU上。核心思想是时分共享 (time sharing) CPU。为了实现高效且可控的时分共享,本章介绍了一种关键机制,称为受限直接执行 (Limited Direct Execution, LDE)。 1.1 LDE的基本…

linux 中断子系统链式中断编程

直接贴代码了&#xff1a; 虚拟中断控制器代码&#xff0c;chained_virt.c #include<linux/kernel.h> #include<linux/module.h> #include<linux/clk.h> #include<linux/err.h> #include<linux/init.h> #include<linux/interrupt.h> #inc…

容器修仙传 我的灵根是Pod 第10章 心魔大劫(RBAC与SecurityContext)

第四卷&#xff1a;飞升之劫化神篇 第10章 心魔大劫&#xff08;RBAC与SecurityContext&#xff09; 血月当空&#xff0c;林衍的混沌灵根正在异变。 每道经脉都爬满黑色纹路&#xff0c;神识海中回荡着蛊惑之音&#xff1a;"破开藏经阁第九层禁制…夺取《太古弑仙诀》……

基于c#,wpf,ef框架,sql server数据库,音乐播放器

详细视频: 【基于c#,wpf,ef框架,sql server数据库&#xff0c;音乐播放器。-哔哩哔哩】 https://b23.tv/ZqmOKJ5

精益数据分析(21/126):剖析创业增长引擎与精益画布指标

精益数据分析&#xff08;21/126&#xff09;&#xff1a;剖析创业增长引擎与精益画布指标 大家好&#xff01;在创业和数据分析的探索道路上&#xff0c;我一直希望能和大家携手共进&#xff0c;共同学习。今天&#xff0c;我们继续深入研读《精益数据分析》&#xff0c;剖析…

Spark-streaming核心编程

1.导入依赖‌&#xff1a; <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming-kafka-0-10_2.12</artifactId> <version>3.0.0</version> </dependency> 2.编写代码‌&#xff1a; 创建Sp…

Kafka的ISR机制是什么?如何保证数据一致性?

一、Kafka ISR机制深度解析 1. ISR机制定义 ISR&#xff08;In-Sync Replicas&#xff09;是Kafka保证数据一致性的核心机制&#xff0c;由Leader副本&#xff08;复杂读写&#xff09;和Follower副本(负责备份)组成。当Follower副本的延迟超过replica.lag.time.max.ms&#…

Docker 基本概念与安装指南

Docker 基本概念与安装指南 一、Docker 核心概念 1. 容器&#xff08;Container&#xff09; 容器是 Docker 的核心运行单元&#xff0c;本质是一个轻量级的沙盒环境。它基于镜像创建&#xff0c;包含应用程序及其运行所需的依赖&#xff08;如代码、库、环境变量等&#xf…

数据库监控 | MongoDB监控全解析

PART 01 MongoDB&#xff1a;灵活、可扩展的文档数据库 MongoDB作为一款开源的NoSQL数据库&#xff0c;凭借其灵活的数据模型&#xff08;基于BSON的文档存储&#xff09;、水平扩展能力&#xff08;分片集群&#xff09;和高可用性&#xff08;副本集架构&#xff09;&#x…

OpenFeign和Gateway

OpenFeign和Gateway 一.OpenFeign介绍二.快速上手1.引入依赖2.开启openfeign的功能3.编写客户端4.修改远程调用代码5.测试 三.OpenFeign参数传递1.传递单个参数2.多个参数、传递对象和传递JSON字符串3.最佳方式写代码继承的方式抽取的方式 四.部署OpenFeign五.统一服务入口-Gat…

spark-streaming(二)

DStream创建&#xff08;kafka数据源&#xff09; 1.在idea中的 pom.xml 中添加依赖 <dependency><groupId>org.apache.spark</groupId><artifactId>spark-streaming-kafka-0-10_2.12</artifactId><version>3.0.0</version> </…

JAVA聚焦OutOfMemoryError 异常

个人主页 文章专栏 在正文开始前&#xff0c;我想多说几句&#xff0c;也就是吐苦水吧…最近这段时间一直想写点东西&#xff0c;停下来反思思考一下。 心中万言&#xff0c;真正执笔时又不知先写些什么。通常这个时候&#xff0c;我都会随便写写&#xff0c;文风极像散文&…

如何在Spring Boot中配置自定义端口运行应用程序

Spring Boot 应用程序默认在端口 8080 上运行嵌入式 Web 服务器&#xff08;如 Tomcat、Jetty 或 Undertow&#xff09;。然而&#xff0c;在开发、测试或生产环境中&#xff0c;开发者可能需要将应用程序配置为在自定义端口上运行&#xff0c;例如避免端口冲突、适配微服务架构…

linux嵌入式(进程与线程1)

Linux进程 进程介绍 1. 进程的基本概念 定义&#xff1a;进程是程序的一次执行过程&#xff0c;拥有独立的地址空间、资源&#xff08;如内存、文件描述符&#xff09;和唯一的进程 ID&#xff08;PID&#xff09;。 组成&#xff1a; 代码段&#xff1a;程序的指令。 数据…

智驭未来:NVIDIA自动驾驶安全白皮书与实验室创新实践深度解析

一、引言&#xff1a;自动驾驶安全的范式革新 在当今数字化浪潮的推动下&#xff0c;全球自动驾驶技术正大步迈入商业化的深水区。随着越来越多的自动驾驶车辆走上道路&#xff0c;其安全性已成为整个行业乃至社会关注的核心命题。在这个关键的转折点上&#xff0c;NVIDIA 凭借…