用 Pytorch 训练一个 Transformer模型

昨天说了一下Transformer架构,今天我们来看看怎么 Pytorch 训练一个Transormer模型,真实训练一个模型是个庞大工程,准备数据、准备硬件等等,我只是做一个简单的实现。因为只是做实验,本地用 CPU 也可以运行。
本文包含以下几部分:

  1. 准备环境。
  2. 然后就是跟据架构来定义每一层,包括Embedding、Position Encoding、多头注意力、 网络层。
  3. 准备Encoder。
  4. 准备Decoder。
  5. 运行 Transformer,包括训练和评估。

安装Pytorch 环境

!pip3 install torch torchvision torchaudio

引入所需工具类库

引入需要的类库,pytorch 是强大的训练框架,深度学习中需要的一些函数和基本功能都已经实现。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

Position Embedding

生成位置信息。

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_seq_length):super(PositionalEncoding, self).__init__()pe = torch.zeros(max_seq_length, d_model)position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):return x + self.pe[:, :x.size(1)]

多头注意力

  1. 初始化model 维度,头数,每个头的维度。
  2. 计算是在 forward 这个方法,主要看这个方法。
class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()# Ensure that the model dimension (d_model) is divisible by the number of headsassert d_model % num_heads == 0, "d_model must be divisible by num_heads"# Initialize dimensionsself.d_model = d_model # Model's dimensionself.num_heads = num_heads # Number of attention headsself.d_k = d_model // num_heads # Dimension of each head's key, query, and value# Linear layers for transforming inputsself.W_q = nn.Linear(d_model, d_model) # Query transformationself.W_k = nn.Linear(d_model, d_model) # Key transformationself.W_v = nn.Linear(d_model, d_model) # Value transformationself.W_o = nn.Linear(d_model, d_model) # Output transformationdef scaled_dot_product_attention(self, Q, K, V, mask=None):# Calculate attention scoresattn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)# Apply mask if provided (useful for preventing attention to certain parts like padding)if mask is not None:attn_scores = attn_scores.masked_fill(mask == 0, -1e9)# Softmax is applied to obtain attention probabilitiesattn_probs = torch.softmax(attn_scores, dim=-1)# Multiply by values to obtain the final outputoutput = torch.matmul(attn_probs, V)return outputdef split_heads(self, x):# 转换,每一个 head 独立处理batch_size, seq_length, d_model = x.size()return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)def combine_heads(self, x):# Combine the multiple heads back to original shapebatch_size, _, seq_length, d_k = x.size()return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)def forward(self, Q, K, V, mask=None):# 线性转换并切分Q = self.split_heads(self.W_q(Q))K = self.split_heads(self.W_k(K))V = self.split_heads(self.W_v(V))# 运行计算公式attn_output = self.scaled_dot_product_attention(Q, K, V, mask)# 合并并返回output = self.W_o(self.combine_heads(attn_output))return output

网络定义

class PositionWiseFeedForward(nn.Module):def __init__(self, d_model, d_ff):super(PositionWiseFeedForward, self).__init__()self.fc1 = nn.Linear(d_model, d_ff)self.fc2 = nn.Linear(d_ff, d_model)self.relu = nn.ReLU()def forward(self, x):return self.fc2(self.relu(self.fc1(x)))

Encoder

在这里插入图片描述
跟据这张图看下面的实现比较直观,初始化了MultiHeadAttention、PositionWiseFeedForward、两个LayerNorm。 forward 方法中 x 是 Encoder 的输入。

class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout):super(EncoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = PositionWiseFeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask):attn_output = self.self_attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return x

Decoder

在这里插入图片描述
看代码的方式和 Encoder 类似,比较好理解,2 个MultiHeadAttention、3个 Norm,forward 中cross_attn 把 enc_output作为传入的参数。

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.cross_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = PositionWiseFeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, src_mask, tgt_mask):attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return x

Transformer

Transformer主类,包括初始化 embedding、position embedding、encoder 和 decoder。forward 方法进行计算。

class Transformer(nn.Module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):super(Transformer, self).__init__()self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model, max_seq_length)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.fc = nn.Linear(d_model, tgt_vocab_size)self.dropout = nn.Dropout(dropout)def generate_mask(self, src, tgt):src_mask = (src != 0).unsqueeze(1).unsqueeze(2)tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)seq_length = tgt.size(1)nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()tgt_mask = tgt_mask & nopeak_maskreturn src_mask, tgt_maskdef forward(self, src, tgt):src_mask, tgt_mask = self.generate_mask(src, tgt)src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))enc_output = src_embeddedfor enc_layer in self.encoder_layers:enc_output = enc_layer(enc_output, src_mask)dec_output = tgt_embeddedfor dec_layer in self.decoder_layers:dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)output = self.fc(dec_output)return output

训练

首先准备数据,这里的数据是随机生成,只是做演示。

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

开始训练

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)transformer.train()for epoch in range(100):optimizer.zero_grad()output = transformer(src_data, tgt_data[:, :-1])loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))loss.backward()optimizer.step()print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

评估

将模型运行在验证集或者测试集上,这里数据也是随机生成的,只为体验一下完整流程。

transformer.eval()# Generate random sample validation data
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)with torch.no_grad():val_output = transformer(val_src_data, val_tgt_data[:, :-1])val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data[:, 1:].contiguous().view(-1))print(f"Validation Loss: {val_loss.item()}")

如果你对 Pytorch 和神经网络比较熟悉,Transformer整体实现起来并不复杂,如果想我一样对深度学习不太熟悉,理解起来还是有些困难,这里只是大概跑了一下流程,对Transformer训练有一个概念。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/827347.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue阶段练习:tab栏、进度条、

阶段练习旨在学习完Vue 指令、计算属性、侦听器-CSDN博客后,进行自我检测,每个练习分为效果显示、需求分析、静态代码、完整代码、总结 四个部分,效果显示和准备代码已给出,我们需要完成“完整代码”部分。 练习1:tab栏…

开源大数据集群部署(二十一)Spark on yarn 部署

作者:櫰木 1 spark on yarn安装(每个节点) cd /root/bigdata/ tar -xzvf spark-3.3.1-bin-hadoop3.tgz -C /opt/ ln -s /opt/spark-3.3.1-bin-hadoop3 /opt/spark chown -R spark:spark /opt/spark-3.3.1-bin-hadoop32 配置环境变量及修改配…

攻防世界---misc---再见李华

1.下载附件是解压之后得到一张图片 2.使用常规方法后没有得到什么信息,接着用winhex分析,发现有压缩包 ,里面还有个key.txt 3.接着用kali使用命名foremost进行分离,得到压缩包,里面的key.txt需要密码 4.接着给压缩包暴…

IDEA代码重构

重构 重构的目的: 提高代码的可读性、可维护性、可扩展性和性能。 重命名元素 重命名类 当我们进行重命名操作的时候可以看到第六行存在一个R(rename),点击后就会弹出所偶有引用,这样可以避免我们在修改后存在遗漏引用处未修改。 我们可以通过…

管理集群工具之LVS

管理集群工具之LVS 集群概念 将很多机器组织在一起,作为一个整体对外提供服务集群在扩展性、性能方面都可以做到很灵活集群分类 负载均衡集群:Load Balance高可用集群:High Availability高性能计算:High Performance Computing …

模拟网关是什么?

模拟网关是一种网络设备,用于在模拟电话系统和数字网络之间进行信号转换。它的主要作用是将模拟语音信号转换为数字格式,使得这些信号能够通过基于IP(互联网协议)的网络进行传输,从而实现语音通信。这种设备是将传统的…

Python环境找不到解决方法

Python环境找不到 打开设置:Ctrl Alt S 添加Local Interpreter... 打开System Interpreter,找到本地安装的Python.exe路径,然后一路点OK Trust Project 如果打开工程时,出现如下对话框,请勾选 Trust projects in ...&…

项目管理中,项目团队如何高效的协作与沟通?

目 录 一、项目团队高效的协作与沟通,可以通过以下几个方面来实现: 二、如何在项目团队中明确和共享愿景以提高协作效率? 三、有效的沟通策略在项目管理中的应用案例有哪些? 四、建立哪些具体的沟通机制可以提升团队协作效率…

matlab学习003-绘制由差分方程表示的离散系统图像

目录 1,题目 2,使用函数求解差分方程 1)基础知识 ①filter函数和impz函数 ②zeros函数 ☀ 2)绘制图像 ​☀ 3)对应代码 如果连简单的信号都不会的,建议先看如下文章👇,之…

互联网大厂ssp面经,数据结构part2

1. 什么是堆和优先队列?它们的特点和应用场景是什么? a. 堆是一种特殊的树形数据结构,具有以下特点:i. 堆是一个完全二叉树,即除了最后一层外,其他层都是满的,并且最后一层的节点都靠左对齐。i…

SEGGER Embedded Studio IDE移植FreeRTOS

SEGGER Embedded Studio IDE移植FreeRTOS 一、简介二、技术路线2.1 获取FreeRTOS源码2.2 将必要的文件复制到工程中2.2.1 移植C文件2.2.2 移植portable文件2.2.3 移植头文件 2.3 创建FreeRTOSConfig.h并进行配置2.3.1 处理中断优先级2.3.2 configASSERT( x )的处理2.3.3 关于系…

linq select 和selectMany的区别

Select 和 SelectMany 都是 LINQ 查询方法&#xff0c;但它们之间有一些区别。 Select 方法用于从集合中选择特定的属性或对集合中的元素进行转换&#xff0c;并返回一个新的集合。例如&#xff1a; var numbers new List<int> { 1, 2, 3, 4, 5 }; var squaredNumbers…

SRS WebRTC Whip 和 Whep 部署体验问题

whip 報錯 404 webrtc推流 小窗口一闪而过&#xff0c;然后查看f12回复404的报错信息 chrome版本&#xff1a; 正在检查更新 版本 123.0.6312.123&#xff08;正式版本&#xff09; &#xff08;64 位&#xff09; centos 7.9 源码安装部署&#xff0c; 代码分支5.0 完全按…

socket通信基础讲解及示例-C

socket通信之C篇 服务端与客户端简介 socket通信服务端与客户端通信模型通信实战server&#xff08;服务端&#xff09;创建client&#xff08;客户端&#xff09;创建 函数详解创建套接字 socket绑定端口bind进入监听状态listen获取客户端连接请求accept接收网络数据read发送数…

每日一题---移除链表元素

文章目录 前言1.题目2.分析思路3.参考代码 前言 Leetcode–-移除链表元素 1.题目 2.分析思路 首先要创建一个新的链表&#xff0c;在定义三个指针&#xff0c;newHead&#xff0c;newTail和pcur&#xff0c;分别代表新链表头&#xff0c;新链表尾以及用于遍历原链表。 其次是…

Rust入门-所有权

一、为什么、是什么、怎么用 1、为什么Rust要提出一个所有权和借用的概念 所有的程序都必须和计算机内存打交道&#xff0c;如何从内存中申请空间来存放程序的运行内容&#xff0c;如何在不需要的时候释放这些空间&#xff0c;成为所有编程语言设计的难点之一。 主要分为三种…

git merge 和 git rebese的区别

git merge 和 git rebese的区别 拉取分支和合并代码会涉及两种选择&#xff0c;git merge 和 git rebase&#xff1a; rebase&#xff1a;变基&#xff0c;会有一个干净的分支&#xff0c;但是对于记录来源不够清楚merge&#xff1a;合并&#xff0c;git 分支看起来比较混乱&…

CentOS 7虚拟机配置静态IP地址(一)

IP地址的配置 以下几个地址需要记住&#xff0c;在配置中使用 &#xff08;1&#xff09;查看MAC地址&#xff08;点击菜单虚拟机-设置-网络适配器-高级-记住MAC地址&#xff09; &#xff08;2&#xff09;查看子网掩码和网关IP&#xff08;点击菜单编辑-虚拟网络编辑器-选择…

机器学习-10-神经网络python实现-从零开始

文章目录 总结参考本门课程的目标机器学习定义从零构建神经网络手写数据集MNIST介绍代码读取数据集MNIST神经网络实现测试手写的图片 带有反向查询的神经网络实现 总结 本系列是机器学习课程的系列课程&#xff0c;主要介绍基于python实现神经网络。 参考 BP神经网络及pytho…

Reactor 模式

目录 1. 实现代码 2. Reactor 模式 3. 分析服务器的实现具体细节 3.1. Connection 结构 3.2. 服务器的成员属性 3.2. 服务器的构造 3.3. 事件轮询 3.4. 事件派发 3.5. 连接事件 3.6. 读事件 3.7. 写事件 3.8. 异常事件 4. 服务器上层的处理 5. Reactor 总结 1…