Implementing a Transformer from Scratch in PyTorch

Table of Contents

    • Self-Attention
    • Transformer Block
    • Encoder
    • Decoder Block
    • Decoder
    • The Full Transformer
    • References
    • Full Code (Runnable)

Self-Attention

Formula

Attention(Q, K, V) = softmax(QK^T / √d_k) V

where Q, K, and V are the query, key, and value matrices and d_k is the per-head dimension. (Note: the implementation below divides by √embed_size rather than √d_k, matching the reference implementation it is based on.)
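To make the formula concrete, here is a minimal single-head sketch in plain PyTorch with arbitrary toy shapes (this sketch is illustrative and not part of the original implementation):

import torch

# Toy shapes: batch of 1, sequence length 3, model dimension d_k = 4
Q = torch.randn(1, 3, 4)
K = torch.randn(1, 3, 4)
V = torch.randn(1, 3, 4)

d_k = Q.shape[-1]
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (1, 3, 3) similarity scores
weights = torch.softmax(scores, dim=-1)         # each row sums to 1
out = weights @ V                               # (1, 3, 4) weighted sum of values
print(weights.sum(dim=-1))                      # tensor of ones
print(out.shape)                                # torch.Size([1, 3, 4])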

Implementation


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Matrix multiplication written with Einstein summation notation (einsum)
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape:    (N, key_len, heads, head_dim)
        # energy shape:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # Positions where mask == 0 are filled with a large negative value,
            # so they receive (almost) zero weight after the softmax
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # attention shape: (N, heads, query_len, key_len)
        # values shape:    (N, value_len, heads, head_dim)
        # after einsum:    (N, query_len, heads, head_dim), then flatten the last two dimensions
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out
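A quick shape check of the module (a sketch assuming the SelfAttention class above; batch and sequence sizes are arbitrary):

import torch

embed_size, heads = 256, 8
attn = SelfAttention(embed_size, heads)
x = torch.randn(2, 10, embed_size)   # (N, seq_len, embed_size)
out = attn(x, x, x, mask=None)       # self-attention: values = keys = query = x
print(out.shape)                     # torch.Size([2, 10, 256]) -- the shape is preserved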

Transformer Block

We define the Transformer block as the structure shown below; this block appears in both the encoder and the decoder.
[Figure: Transformer block — multi-head attention, add & norm, feed-forward, add & norm]

Implementation

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
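A small sanity check (again a sketch with arbitrary sizes, assuming the classes above): the block keeps the embedding size unchanged, while its hidden feed-forward layer is forward_expansion times wider:

import torch

block = TransformerBlock(embed_size=256, heads=8, dropout=0.1, forward_expansion=4)
x = torch.randn(2, 10, 256)                # (N, seq_len, embed_size)
print(block(x, x, x, mask=None).shape)     # torch.Size([2, 10, 256]) -- residuals require matching shapes
print(block.feed_forward[0])               # Linear(in_features=256, out_features=1024, bias=True)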

Encoder

The encoder structure is shown below: the inputs pass through an input embedding and positional encoding, and then through a stack of Transformer blocks.

[Figure: encoder architecture]

Implementation

class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # In the encoder, query, key, and value are all the same tensor
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out
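A usage sketch (arbitrary sizes, padding index assumed to be 0) showing how a padding mask of shape (N, 1, 1, src_len) broadcasts over the heads and query positions inside SelfAttention:

import torch

device = torch.device("cpu")
enc = Encoder(src_vocab_size=10, embed_size=256, num_layers=2, heads=8,
              device=device, forward_expansion=4, dropout=0.1, max_length=100).to(device)

src = torch.tensor([[1, 5, 6, 4, 0, 0]])         # 0 is treated as the padding index here
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # (N, 1, 1, src_len), broadcast over heads and query positions
enc_out = enc(src, src_mask)
print(enc_out.shape)                             # torch.Size([1, 6, 256])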

Decoder Block

The decoder block structure is shown below: a masked self-attention layer followed by a Transformer block whose keys and values come from the encoder output.

[Figure: decoder block structure]

Implementation

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        # Masked self-attention over the target sequence
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        # Cross-attention: keys and values come from the encoder output
        out = self.transformer_block(value, key, query, src_mask)
        return out
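A sketch illustrating that in cross-attention the key/value length (source side) and the query length (target side) can differ; the encoder output here is just a random tensor standing in for real features:

import torch

device = torch.device("cpu")
dec_block = DecoderBlock(embed_size=256, heads=8, forward_expansion=4, dropout=0.1, device=device)

x = torch.randn(1, 5, 256)        # target-side activations (trg_len = 5)
enc_out = torch.randn(1, 7, 256)  # stand-in encoder output (src_len = 7)
trg_mask = torch.tril(torch.ones(5, 5)).expand(1, 1, 5, 5)   # causal mask over the target
out = dec_block(x, enc_out, enc_out, src_mask=None, trg_mask=trg_mask)
print(out.shape)                  # torch.Size([1, 5, 256]) -- the output length follows the query (target) side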

Decoder

Adding word embeddings and positional embeddings in front of a stack of decoder blocks gives the decoder.

[Figure: decoder architecture]

Implementation

class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out
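A sketch of what the decoder returns: per-position vocabulary logits, from which (for example) a greedy next token could be picked. The encoder output is again a random stand-in:

import torch

device = torch.device("cpu")
dec = Decoder(trg_vocab_size=10, embed_size=256, num_layers=2, heads=8,
              forward_expansion=4, dropout=0.1, device=device, max_length=100).to(device)

trg = torch.tensor([[1, 7, 4]])                  # tokens generated so far
enc_out = torch.randn(1, 6, 256)                 # stand-in encoder output for a length-6 source
trg_mask = torch.tril(torch.ones(3, 3)).expand(1, 1, 3, 3)
logits = dec(trg, enc_out, src_mask=None, trg_mask=trg_mask)
print(logits.shape)                              # torch.Size([1, 3, 10]) -- per-position vocabulary scores
next_token = logits[:, -1, :].argmax(dim=-1)     # greedy choice for the next target token
print(next_token.shape)                          # torch.Size([1])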

The Full Transformer

[Figure: full Transformer (encoder-decoder) architecture]

Implementation

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cuda",
        max_length=100,
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )
        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        # (N, 1, 1, src_len): 0 at padding positions, so they are masked out
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        # (N, 1, trg_len, trg_len): lower-triangular causal mask
        # (target padding is not masked here; this follows the reference implementation)
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
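To visualize the causal mask, a small sketch on CPU with a toy target of length 4 (the pad indices are chosen as 0 purely for illustration):

import torch

model = Transformer(src_vocab_size=10, trg_vocab_size=10,
                    src_pad_idx=0, trg_pad_idx=0, device="cpu")
trg = torch.tensor([[1, 7, 4, 3]])
print(model.make_trg_mask(trg)[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
# Row i keeps only positions 0..i, so each target token attends only to itself and earlier tokens.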

References

[1] https://www.youtube.com/watch?v=U0s0f995w14
[2] https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
[3] Vaswani et al., "Attention Is All You Need", https://arxiv.org/abs/1706.03762
[4] https://www.youtube.com/watch?v=pkVwUVEHmfI

Full Code (Runnable)

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # queries shape: (N, query_len, heads, head_dim)
        # keys shape:    (N, key_len, heads, head_dim)
        # energy shape:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # Positions where mask == 0 are filled with a large negative value,
            # so they receive (almost) zero weight after the softmax
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # attention shape: (N, heads, query_len, key_len)
        # values shape:    (N, value_len, heads, head_dim)
        # after einsum:    (N, query_len, heads, head_dim), then flatten the last two dimensions
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # In the encoder, query, key, and value are all the same tensor
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out


class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        # Masked self-attention over the target sequence
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        # Cross-attention: keys and values come from the encoder output
        out = self.transformer_block(value, key, query, src_mask)
        return out


class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out


class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cuda",
        max_length=100,
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )
        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        # (N, 1, 1, src_len): 0 at padding positions, so they are masked out
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        # (N, 1, trg_len, trg_len): lower-triangular causal mask
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10
    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
    out = model(x, trg[:, :-1])
    print(out.shape)
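As a follow-up, a hedged sketch of how this output could feed a standard teacher-forcing loss, reusing model, x, trg, trg_pad_idx, and trg_vocab_size from the __main__ block above (this training step is not part of the original code):

criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
output = model(x, trg[:, :-1])                        # (N, trg_len - 1, trg_vocab_size)
loss = criterion(output.reshape(-1, trg_vocab_size),  # flatten to (N * (trg_len - 1), vocab)
                 trg[:, 1:].reshape(-1))              # predict the next token at each position
loss.backward()
print(loss.item())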
