用 Pytorch 训练一个 Transformer模型

昨天说了一下Transformer架构,今天我们来看看怎么 Pytorch 训练一个Transormer模型,真实训练一个模型是个庞大工程,准备数据、准备硬件等等,我只是做一个简单的实现。因为只是做实验,本地用 CPU 也可以运行。
本文包含以下几部分:

  1. 准备环境。
  2. 然后就是跟据架构来定义每一层,包括Embedding、Position Encoding、多头注意力、 网络层。
  3. 准备Encoder。
  4. 准备Decoder。
  5. 运行 Transformer,包括训练和评估。

安装Pytorch 环境

!pip3 install torch torchvision torchaudio

引入所需工具类库

引入需要的类库,pytorch 是强大的训练框架,深度学习中需要的一些函数和基本功能都已经实现。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

Position Embedding

生成位置信息。

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_seq_length):super(PositionalEncoding, self).__init__()pe = torch.zeros(max_seq_length, d_model)position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):return x + self.pe[:, :x.size(1)]

多头注意力

  1. 初始化model 维度,头数,每个头的维度。
  2. 计算是在 forward 这个方法,主要看这个方法。
class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()# Ensure that the model dimension (d_model) is divisible by the number of headsassert d_model % num_heads == 0, "d_model must be divisible by num_heads"# Initialize dimensionsself.d_model = d_model # Model's dimensionself.num_heads = num_heads # Number of attention headsself.d_k = d_model // num_heads # Dimension of each head's key, query, and value# Linear layers for transforming inputsself.W_q = nn.Linear(d_model, d_model) # Query transformationself.W_k = nn.Linear(d_model, d_model) # Key transformationself.W_v = nn.Linear(d_model, d_model) # Value transformationself.W_o = nn.Linear(d_model, d_model) # Output transformationdef scaled_dot_product_attention(self, Q, K, V, mask=None):# Calculate attention scoresattn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)# Apply mask if provided (useful for preventing attention to certain parts like padding)if mask is not None:attn_scores = attn_scores.masked_fill(mask == 0, -1e9)# Softmax is applied to obtain attention probabilitiesattn_probs = torch.softmax(attn_scores, dim=-1)# Multiply by values to obtain the final outputoutput = torch.matmul(attn_probs, V)return outputdef split_heads(self, x):# 转换,每一个 head 独立处理batch_size, seq_length, d_model = x.size()return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)def combine_heads(self, x):# Combine the multiple heads back to original shapebatch_size, _, seq_length, d_k = x.size()return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)def forward(self, Q, K, V, mask=None):# 线性转换并切分Q = self.split_heads(self.W_q(Q))K = self.split_heads(self.W_k(K))V = self.split_heads(self.W_v(V))# 运行计算公式attn_output = self.scaled_dot_product_attention(Q, K, V, mask)# 合并并返回output = self.W_o(self.combine_heads(attn_output))return output

网络定义

class PositionWiseFeedForward(nn.Module):def __init__(self, d_model, d_ff):super(PositionWiseFeedForward, self).__init__()self.fc1 = nn.Linear(d_model, d_ff)self.fc2 = nn.Linear(d_ff, d_model)self.relu = nn.ReLU()def forward(self, x):return self.fc2(self.relu(self.fc1(x)))

Encoder

在这里插入图片描述
跟据这张图看下面的实现比较直观,初始化了MultiHeadAttention、PositionWiseFeedForward、两个LayerNorm。 forward 方法中 x 是 Encoder 的输入。

class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout):super(EncoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = PositionWiseFeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask):attn_output = self.self_attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return x

Decoder

在这里插入图片描述
看代码的方式和 Encoder 类似,比较好理解,2 个MultiHeadAttention、3个 Norm,forward 中cross_attn 把 enc_output作为传入的参数。

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.cross_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = PositionWiseFeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, src_mask, tgt_mask):attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return x

Transformer

Transformer主类,包括初始化 embedding、position embedding、encoder 和 decoder。forward 方法进行计算。

class Transformer(nn.Module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):super(Transformer, self).__init__()self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model, max_seq_length)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.fc = nn.Linear(d_model, tgt_vocab_size)self.dropout = nn.Dropout(dropout)def generate_mask(self, src, tgt):src_mask = (src != 0).unsqueeze(1).unsqueeze(2)tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)seq_length = tgt.size(1)nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()tgt_mask = tgt_mask & nopeak_maskreturn src_mask, tgt_maskdef forward(self, src, tgt):src_mask, tgt_mask = self.generate_mask(src, tgt)src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))enc_output = src_embeddedfor enc_layer in self.encoder_layers:enc_output = enc_layer(enc_output, src_mask)dec_output = tgt_embeddedfor dec_layer in self.decoder_layers:dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)output = self.fc(dec_output)return output

训练

首先准备数据,这里的数据是随机生成,只是做演示。

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

开始训练

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)transformer.train()for epoch in range(100):optimizer.zero_grad()output = transformer(src_data, tgt_data[:, :-1])loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))loss.backward()optimizer.step()print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

评估

将模型运行在验证集或者测试集上,这里数据也是随机生成的,只为体验一下完整流程。

transformer.eval()# Generate random sample validation data
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)with torch.no_grad():val_output = transformer(val_src_data, val_tgt_data[:, :-1])val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data[:, 1:].contiguous().view(-1))print(f"Validation Loss: {val_loss.item()}")

如果你对 Pytorch 和神经网络比较熟悉,Transformer整体实现起来并不复杂,如果想我一样对深度学习不太熟悉,理解起来还是有些困难,这里只是大概跑了一下流程,对Transformer训练有一个概念。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/827347.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue阶段练习:tab栏、进度条、

阶段练习旨在学习完Vue 指令、计算属性、侦听器-CSDN博客后,进行自我检测,每个练习分为效果显示、需求分析、静态代码、完整代码、总结 四个部分,效果显示和准备代码已给出,我们需要完成“完整代码”部分。 练习1:tab栏…

LSB隐写是什么?

LSB隐写是什么? 所需知识二进制位LSB的概念LSB在数值中的作用LSB在量化中的应用小结 LSB隐写原理应用威胁与挑战改进补充资料 所需知识 二进制数 位(bit) LSB概念 二进制 在计算机科学中,二进制数是一种数制,使用两…

开源大数据集群部署(二十一)Spark on yarn 部署

作者:櫰木 1 spark on yarn安装(每个节点) cd /root/bigdata/ tar -xzvf spark-3.3.1-bin-hadoop3.tgz -C /opt/ ln -s /opt/spark-3.3.1-bin-hadoop3 /opt/spark chown -R spark:spark /opt/spark-3.3.1-bin-hadoop32 配置环境变量及修改配…

攻防世界---misc---再见李华

1.下载附件是解压之后得到一张图片 2.使用常规方法后没有得到什么信息,接着用winhex分析,发现有压缩包 ,里面还有个key.txt 3.接着用kali使用命名foremost进行分离,得到压缩包,里面的key.txt需要密码 4.接着给压缩包暴…

IDEA代码重构

重构 重构的目的: 提高代码的可读性、可维护性、可扩展性和性能。 重命名元素 重命名类 当我们进行重命名操作的时候可以看到第六行存在一个R(rename),点击后就会弹出所偶有引用,这样可以避免我们在修改后存在遗漏引用处未修改。 我们可以通过…

生成计算机注册信息:硬盘,主板和CPU组合 根据计算机硬盘,主板,CPU生成注册信息

目录 一.总体说明 二.完整代码 三.逐行分析 一.总体说明 注册信息是用于识别和验证计算机的唯一标识符。在生成注册信息时,通常会包含计算机的硬盘、主板和CPU的相关信息。这些信息可以用于授权软件、管理许可证、防止盗版以及进行系统配置和维护等方面。 具体而言,注册…

管理集群工具之LVS

管理集群工具之LVS 集群概念 将很多机器组织在一起,作为一个整体对外提供服务集群在扩展性、性能方面都可以做到很灵活集群分类 负载均衡集群:Load Balance高可用集群:High Availability高性能计算:High Performance Computing …

模拟网关是什么?

模拟网关是一种网络设备,用于在模拟电话系统和数字网络之间进行信号转换。它的主要作用是将模拟语音信号转换为数字格式,使得这些信号能够通过基于IP(互联网协议)的网络进行传输,从而实现语音通信。这种设备是将传统的…

Python环境找不到解决方法

Python环境找不到 打开设置:Ctrl Alt S 添加Local Interpreter... 打开System Interpreter,找到本地安装的Python.exe路径,然后一路点OK Trust Project 如果打开工程时,出现如下对话框,请勾选 Trust projects in ...&…

项目管理中,项目团队如何高效的协作与沟通?

目 录 一、项目团队高效的协作与沟通,可以通过以下几个方面来实现: 二、如何在项目团队中明确和共享愿景以提高协作效率? 三、有效的沟通策略在项目管理中的应用案例有哪些? 四、建立哪些具体的沟通机制可以提升团队协作效率…

matlab学习003-绘制由差分方程表示的离散系统图像

目录 1,题目 2,使用函数求解差分方程 1)基础知识 ①filter函数和impz函数 ②zeros函数 ☀ 2)绘制图像 ​☀ 3)对应代码 如果连简单的信号都不会的,建议先看如下文章👇,之…

互联网大厂ssp面经,数据结构part2

1. 什么是堆和优先队列?它们的特点和应用场景是什么? a. 堆是一种特殊的树形数据结构,具有以下特点:i. 堆是一个完全二叉树,即除了最后一层外,其他层都是满的,并且最后一层的节点都靠左对齐。i…

【设计模式】11、flyweight 享元模式

文章目录 十一、flyweight11.1 pool 连接池11.1.1 pool_test.go11.1.2 pool.go11.1.3 conn.go 11.2 chess_board11.2.1 chess_test.go11.2.2 chess.go 十一、flyweight https://refactoringguru.cn/design-patterns/flyweight 大量重复的对象, 如果很消耗资源, 没必要每次都初…

IP/网关流量控制:Linux上设置多个路由表 (多号段IP配置)

文章目录 引言I 使用特定的路由策略和规则1.1 实现思路1.2 配置路由策略II Mac/window的IP多号段配置2.1 Mac电脑2.2 windows系统2.3 linux系统III 预备知识3.1 配置接口 IP 地址和网关3.2 启用IP转发功能3.3 修改IP地址3.4 添加一个新的路由策略3.5 路由表配置引言

常见的经典目标检测

常见的经典目标检测算法主要包括: R-CNN 系列57:R-CNN(Region-based Convolutional Neural Networks)是一种经典的目标检测算法,它通过卷积神经网络(CNN)提取候选区域特征,然后使用支…

SEGGER Embedded Studio IDE移植FreeRTOS

SEGGER Embedded Studio IDE移植FreeRTOS 一、简介二、技术路线2.1 获取FreeRTOS源码2.2 将必要的文件复制到工程中2.2.1 移植C文件2.2.2 移植portable文件2.2.3 移植头文件 2.3 创建FreeRTOSConfig.h并进行配置2.3.1 处理中断优先级2.3.2 configASSERT( x )的处理2.3.3 关于系…

【npm】常用的NPM命令及在开发过程中的应用

常用的NPM命令及在开发过程中的应用 NPM(Node Package Manager)是JavaScript的包管理工具,也是世界上最大的软件注册表。它允许开发者共享和重用代码,并便于管理各种Node.js的包依赖。本文将介绍一些常用的NPM命令,并…

linq select 和selectMany的区别

Select 和 SelectMany 都是 LINQ 查询方法&#xff0c;但它们之间有一些区别。 Select 方法用于从集合中选择特定的属性或对集合中的元素进行转换&#xff0c;并返回一个新的集合。例如&#xff1a; var numbers new List<int> { 1, 2, 3, 4, 5 }; var squaredNumbers…

SRS WebRTC Whip 和 Whep 部署体验问题

whip 報錯 404 webrtc推流 小窗口一闪而过&#xff0c;然后查看f12回复404的报错信息 chrome版本&#xff1a; 正在检查更新 版本 123.0.6312.123&#xff08;正式版本&#xff09; &#xff08;64 位&#xff09; centos 7.9 源码安装部署&#xff0c; 代码分支5.0 完全按…

【设计模式】中介模式

目录 什么是中介模式 中介模式的组成 使用场景&#xff1a; 优点&#xff1a; 缺点&#xff1a; Java 示例代码&#xff1a; 什么是中介模式 Java 中的中介模式&#xff08;Mediator Pattern&#xff09;是一种行为型设计模式&#xff0c;旨在降低多个对象和类之间的通信…