【变形金刚03】使用 Pytorch 开始构建transformer

 

一、说明

        在本教程中,我们将使用 PyTorch 从头开始构建一个基本的转换器模型。Vaswani等人在论文“注意力是你所需要的一切”中引入的Transformer模型是一种深度学习架构,专为序列到序列任务而设计,例如机器翻译和文本摘要。它基于自我注意机制,已成为许多最先进的自然语言处理模型的基础,如GPT和BERT。

二、准备活动

        若要生成转换器模型,我们将按照以下步骤操作:

  1. 导入必要的库和模块
  2. 定义基本构建块:多头注意力、位置前馈网络、位置编码
  3. 构建编码器和解码器层
  4. 组合编码器和解码器层以创建完整的转换器模型
  5. 准备示例数据
  6. 训练模型

        让我们从导入必要的库和模块开始。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

现在,我们将定义转换器模型的基本构建基块。

三、多头注意力

图2.多头注意力(来源:作者创建的图像)

        多头注意力机制计算序列中每对位置之间的注意力。它由多个“注意头”组成,用于捕获输入序列的不同方面。

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.W_o = nn.Linear(d_model, d_model)def scaled_dot_product_attention(self, Q, K, V, mask=None):attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:attn_scores = attn_scores.masked_fill(mask == 0, -1e9)attn_probs = torch.softmax(attn_scores, dim=-1)output = torch.matmul(attn_probs, V)return outputdef split_heads(self, x):batch_size, seq_length, d_model = x.size()return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)def combine_heads(self, x):batch_size, _, seq_length, d_k = x.size()return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)def forward(self, Q, K, V, mask=None):Q = self.split_heads(self.W_q(Q))K = self.split_heads(self.W_k(K))V = self.split_heads(self.W_v(V))attn_output = self.scaled_dot_product_attention(Q, K, V, mask)output = self.W_o(self.combine_heads(attn_output))return output

        MultiHeadAttention 代码使用输入参数和线性变换层初始化模块。它计算注意力分数,将输入张量重塑为多个头部,并将所有头部的注意力输出组合在一起。前向方法计算多头自我注意,允许模型专注于输入序列的某些不同方面。

四、位置前馈网络

class PositionWiseFeedForward(nn.Module):def __init__(self, d_model, d_ff):super(PositionWiseFeedForward, self).__init__()self.fc1 = nn.Linear(d_model, d_ff)self.fc2 = nn.Linear(d_ff, d_model)self.relu = nn.ReLU()def forward(self, x):return self.fc2(self.relu(self.fc1(x)))

PositionWiseFeedForward 类扩展了 PyTorch 的 nn。模块并实现按位置的前馈网络。该类使用两个线性转换层和一个 ReLU 激活函数进行初始化。forward 方法按顺序应用这些转换和激活函数来计算输出。此过程使模型能够在进行预测时考虑输入元素的位置。

五、位置编码

        位置编码用于注入输入序列中每个令牌的位置信息。它使用不同频率的正弦和余弦函数来生成位置编码。

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_seq_length):super(PositionalEncoding, self).__init__()pe = torch.zeros(max_seq_length, d_model)position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):return x + self.pe[:, :x.size(1)]

PositionalEncoding 类使用 d_model 和 max_seq_length 输入参数进行初始化,从而创建一个张量来存储位置编码值。该类根据比例因子div_term分别计算偶数和奇数指数的正弦和余弦值。前向方法通过将存储的位置编码值添加到输入张量中来计算位置编码,从而使模型能够捕获输入序列的位置信息。

现在,我们将构建编码器层和解码器层。

六、编码器层

图3.变压器网络的编码器部分(来源:图片来自原文)

编码器层由多头注意力层、位置前馈层和两个层归一化层组成。

class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout):super(EncoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = PositionWiseFeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask):attn_output = self.self_attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return x

类使用输入参数和组件进行初始化,包括一个多头注意模块、一个 PositionWiseFeedForward 模块、两个层规范化模块和一个 dropout 层。前向方法通过应用自注意、将注意力输出添加到输入张量并规范化结果来计算编码器层输出。然后,它计算按位置的前馈输出,将其与归一化的自我注意输出相结合,并在返回处理后的张量之前对最终结果进行归一化。

七、解码器层

图4.变压器网络的解码器部分(Souce:图片来自原始论文)

解码器层由两个多头注意力层、一个位置前馈层和三个层归一化层组成。

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.cross_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = PositionWiseFeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, src_mask, tgt_mask):attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(attn_output))ff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return x

解码器层使用输入参数和组件进行初始化,例如用于屏蔽自我注意和交叉注意力的多头注意模块、PositionWiseFeedForward 模块、三层归一化模块和辍学层。

转发方法通过执行以下步骤来计算解码器层输出:

  1. 计算掩蔽的自我注意输出并将其添加到输入张量中,然后进行 dropout 和层归一化。
  2. 计算解码器和编码器输出之间的交叉注意力输出,并将其添加到规范化的掩码自注意力输出中,然后进行 dropout 和层规范化。
  3. 计算按位置的前馈输出,并将其与归一化交叉注意力输出相结合,然后是压差和层归一化。
  4. 返回已处理的张量。

这些操作使解码器能够根据输入和编码器输出生成目标序列。

现在,让我们组合编码器和解码器层来创建完整的转换器模型。

八、变压器型号

图5.The Transformer Network(来源:图片来源于原文)

将它们全部合并在一起:

class Transformer(nn.Module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):super(Transformer, self).__init__()self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model, max_seq_length)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])self.fc = nn.Linear(d_model, tgt_vocab_size)self.dropout = nn.Dropout(dropout)def generate_mask(self, src, tgt):src_mask = (src != 0).unsqueeze(1).unsqueeze(2)tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)seq_length = tgt.size(1)nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()tgt_mask = tgt_mask & nopeak_maskreturn src_mask, tgt_maskdef forward(self, src, tgt):src_mask, tgt_mask = self.generate_mask(src, tgt)src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))enc_output = src_embeddedfor enc_layer in self.encoder_layers:enc_output = enc_layer(enc_output, src_mask)dec_output = tgt_embeddedfor dec_layer in self.decoder_layers:dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)output = self.fc(dec_output)return output

类组合了以前定义的模块以创建完整的转换器模型。在初始化期间,Transformer 模块设置输入参数并初始化各种组件,包括源序列和目标序列的嵌入层、位置编码模块、用于创建堆叠层的编码器层和解码器层模块、用于投影解码器输出的线性层以及 dropout 层。

generate_mask 方法为源序列和目标序列创建二进制掩码,以忽略填充标记并防止解码器处理将来的令牌。前向方法通过以下步骤计算转换器模型的输出:

  1. 使用 generate_mask 方法生成源掩码和目标掩码。
  2. 计算源和目标嵌入,并应用位置编码和丢弃。
  3. 通过编码器层处理源序列,更新enc_output张量。
  4. 通过解码器层处理目标序列,使用enc_output和掩码,并更新dec_output张量。
  5. 将线性投影层应用于解码器输出,获取输出对数。

这些步骤使转换器模型能够处理输入序列,并根据其组件的组合功能生成输出序列。

九、准备样本数据

        在此示例中,我们将创建一个用于演示目的的玩具数据集。实际上,您将使用更大的数据集,预处理文本,并为源语言和目标语言创建词汇映射。

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

十、训练模型

        现在,我们将使用示例数据训练模型。在实践中,您将使用更大的数据集并将其拆分为训练集和验证集。

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)transformer.train()for epoch in range(100):optimizer.zero_grad()output = transformer(src_data, tgt_data[:, :-1])loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))loss.backward()optimizer.step()print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

        我们可以使用这种方式在 Pytorch 中从头开始构建一个简单的转换器。所有大型语言模型都使用这些转换器编码器或解码器块进行训练。因此,了解启动这一切的网络非常重要。希望本文能帮助所有希望深入了解LLM的人。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/39350.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

iOS Epub阅读器改造记录

六个月前在这个YHEpubDemo阅读器的基础上做了一些优化,这里做一下记录。 1.首行缩进修复 由于分页的存在,新的一页的首行可能是新的一行,则应该缩进;也可能是前面一页段落的延续,这时候不应该缩进。YHEpubDemo基于XDS…

pycharm,VSCode 几个好用的插件

pycharm Tabnine AI Code 可以在编写程序的时候为你提供一些快捷方式,增加编程速度 Chinese 对英文不好的程序员来说是个不错的选择,可以将英文状态下的pycharm变为中文版的 ChatGPT 可以跟ai聊天,ai可以解决你80%的问题 ,也可以帮…

变形金刚:从零开始【01/2】

一、说明 在我们的日常生活中,无论你是否是数据科学家,你都在单向地使用变压器模型。例如。如果您使用的是 ChatGPT 或 GPT-4 或任何 GPT,那么在为您回答问题的框中是变压器的一部分。如果您是数据科学家或数据分析师,则可能正在使…

【BASH】回顾与知识点梳理(二十九)

【BASH】回顾与知识点梳理 二十九 二十九. 进程和工作管理29.1 什么是进程 (process)进程与程序 (process & program)子进程与父进程:fork and exec:进程呼叫的流程系统或网络服务:常驻在内存的进程 29.2 Linux 的多人多任务环境多人环境…

SAP MM学习笔记23-购买发注的账户分配类型(勘定Category)

SAP中控制财务凭证过账科目的是 账号分配类型(勘定Category)栏目。 ・账号分配类型(勘定Category)有: 1,K 原价Center(成本中心。用于消耗物料采购 的过账) 2,E 得意先…

LabVIEW对并行机器人结构进行建模仿真

LabVIEW对并行机器人结构进行建模仿真 为了对复杂机器人结构的数学模型进行建模、搜索、动画和验证,在工业机器人动态行为实验室中,设计并实现了具有五个自由度的单臂型机器人。在研究台上可以区分以下元素:带有直流电机和编码器的机器人;稳…

nvm管理node版本

nvm是什么? NVM全名叫做 nodejs version manage,即Node的版本管理工具。 使用NVM,可以通过命令很方便地在多个NodeJS版本之间进行切换。 nvm的下载与安装 下载地址:Releases coreybutler/nvm-windows (github.com) windows系统下载nvm-setup…

Arcgis中直接通过sde更新sqlserver空间数据库失败

问题 背景 不知道有没有人经历过这样一个情况,我们直接在Arcgis中通过sde更新serserver数据库会失败,就是虽然在sde更新sqlserver数据库,但是在Navicat中通过sql语句来查询,发现数据并没有更新,如:上图中,更新数据库后,第一张图是sde打开的sqlserver数据库,它的数据库…

项目管理工具和方法有哪些:了解项目管理的必备工具和有效方法

先谈谈什么是项目管理,简单直白,就是对项目进行管理。项目管理涉及有效的计划和对工作的系统管理,但很多工具可以使项目管理更有效、更高效。比如,Zoho Projects项目管理工具。 1.项目合理拆解 当确定了项目目标后,无疑…

我国农机自动驾驶系统需求日益增长,北斗系统赋能精准农业

中国现代农业的发展,离不开智能化、自动化设备,迫切需要自动驾驶系统与农用机械的密切结合。自动驾驶农机不仅能够缓解劳动力短缺问题,提升劳作生产效率,同时还能对农业进行智慧化升级,成为解决当下农业痛点的有效手段…

Pycharm社区版连接WSL2中的Mysql8.*

当前时间2023.08.13,Windows11中默认的WSL版本已经是2了,在WSL2中默认的Ubuntu版本已经是22.04,而Ubuntu22.04中默认的Mysql版本已经是8.*。 Wsl 2 中安装mysql WSL2中安装Mysql的方法参考自微软官方文档【开始使用适用于 Linux 的 Windows …

vector【2】模拟实现(超详解哦)

vector 引言(实现概述)接口实现详解默认成员函数构造函数析构函数赋值重载 迭代器容量size与capacityreserveresizeempty 元素访问数据修改inserterasepush_back与pop_backswap 模拟实现源码概览总结 引言(实现概述) 在前面&…

分布式定时任务系列5:XXL-job中blockingQueue的应用

传送门 分布式定时任务系列1:XXL-job安装 分布式定时任务系列2:XXL-job使用 分布式定时任务系列3:任务执行引擎设计 分布式定时任务系列4:任务执行引擎设计续 Java并发编程实战1:java中的阻塞队列 引子 这篇文章的…

MATLAB计算一组坐标点的相互距离(pdist、squareform、pdist2函数)

如果有一组坐标P(X,Y),包含多个点的X和Y坐标,计算其坐标点之间的相互距离 一、坐标点 P[1 1;5 2;3 6;8 8;4 5;5 1; 6 9];二、pdist函数 输出的结果是一维数组,获得任意两个坐标之间的距离,但没有对应关系 Dpdist(P)三、square…

JavaWeb-Servlet服务连接器(二)

目录 Request(获取请求信息) 1.获取请求行内容 2.解决乱码问题 3.获取请求头部分 4.获取请求体 5.其他功能 Request(获取请求信息) 工作流程: 1.通过请求的url的资源路径,tomcat会生成相应的Servlet实…

【单片机】DS2431,STM32,EEPROM读取与写入

芯片介绍: https://qq742971636.blog.csdn.net/article/details/132164189 接线 串口结果: 部分代码: #include "sys.h" #include "DS2431.h"unsigned char serialNb[8]; unsigned char write_data[128]; unsigned cha…

STM32入门学习之定时器输入捕获

1.定时器的输入捕获可以用来测量脉冲宽度或者测量频率。输入捕获的原理图如下: 假设定时器是向上计数。在图中,t1~t2之间的便是我们要测量的高电平的时间(脉冲宽度)。首先,设置定时器为上升沿捕获,如此一来,在t1时刻可…

AI 绘画Stable Diffusion 研究(九)sd图生图功能详解-老照片高清修复放大

大家好,我是风雨无阻。 通过前面几篇文章的介绍,相信各位小伙伴,对 Stable Diffusion 这款强大的AI 绘图系统有了全新的认知。我们见识到了借助 Stable Diffusion的文生图功能,利用简单的几个单词,就可以生成完美的图片…

阿里云OSS对象存储的核心概念与购买应用

文章目录 1.OSS对象存储基本介绍1.1.OSS对象存储概念1.2.NAS与OSS存储的不同1.3.OSS的应用场景1.4.OSS术语对应表 2.购买OSS存储资源包3.KodCloud云盘接入OSS对象存储3.1.创建Bucket存储空间3.2.创建子用户用于管理Bucket3.3.获取用户的AccessKey3.3.为用户设置权限3.4.将Bucke…

MySQL和Redis如何保证数据一致性

MySQL与Redis都是常用的数据存储和缓存系统。为了提高应用程序的性能和可伸缩性,很多应用程序将MySQL和Redis一起使用,其中MySQL作为主要的持久存储,而Redis作为主要的缓存。在这种情况下,应用程序需要确保MySQL和Redis中的数据是…