PyTorch从零开始实现Transformer

文章目录

    • 自注意力
    • Transformer块
    • 编码器
    • 解码器块
    • 解码器
    • 整个Transformer
    • 参考来源
    • 全部代码(可直接运行)

自注意力

计算公式

在这里插入图片描述

代码实现


class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size),  "Embed size needs  to  be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads*self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0] # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # 矩阵乘法,使用爱因斯坦标记法# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20")) #Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim) # 矩阵乘法,使用爱因斯坦标记法einsum# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return out

Transformer块

我们把Transfomer块定义为如下图所示的结构,这个Transformer块在编码器和解码器中都有出现过。
在这里插入图片描述

代码实现

class TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return out

编码器

编码器结构如下所示,Inputs经过Input Embedding 和Positional Encoding之后,通过多个Transformer块

在这里插入图片描述

代码实现

class Encoder(nn.Module):def __init__(self, src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return out

解码器块

解码器块结构如下图所示

在这里插入图片描述

代码实现

class DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return out

解码器

解码器块加上word embedding 和 positional embedding之后构成解码器

在这里插入图片描述

代码实现

class Decoder(nn.Module):def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_out, src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out = self.fc_out(x)return out

整个Transformer

在这里插入图片描述

代码实现

class Transformer(nn.Module):def __init__(self,src_vocab_size, trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size=256,num_layers=6,forward_expansion=4,heads=8,dropout=0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)#(N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask,  trg_mask)return out

参考来源

[1] https://www.youtube.com/watch?v=U0s0f995w14
[2] https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py

[3] https://arxiv.org/abs/1706.03762
[4] https://www.youtube.com/watch?v=pkVwUVEHmfI

全部代码(可直接运行)

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size),  "Embed size needs  to  be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads*self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0] # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20")) #Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return outclass TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return outclass Encoder(nn.Module):def __init__(self, src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return outclass DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return outclass Decoder(nn.Module):def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_out, src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out = self.fc_out(x)return outclass Transformer(nn.Module):def __init__(self,src_vocab_size, trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size=256,num_layers=6,forward_expansion=4,heads=8,dropout=0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)#(N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask,  trg_mask)return outif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)src_pad_idx = 0trg_pad_idx = 0src_vocab_size = 10trg_vocab_size = 10model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)out = model(x, trg[:, :-1])print(out.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/5726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ElasticSearch学习--RestClient及案例

目录 RestClient查询文档 快速入门 总结 全文检索(match)查询 精确查询 复合查询 查询总结 排序,分页 高亮 RestClient查询文档 快速入门 总结 全文检索(match)查询 多种查询的差异都在做类型和条件上&#x…

数据可视化——如何绘制地图

文章目录 前言如何绘制地图添加配置项 根据已有数据绘制地图整体代码展示 前言 前面我们学习了如何利用提供的数据来对数据进行处理,然后以折线图的形式展现出来,那么今天我将为大家分享如何将提数据以地图的形式展现。 如何绘制地图 前面我们绘制折线…

如何从gitee上下载项目并把它在本地运行起来

有时候我们会想到在gitee上下载下来项目,那么怎么把项目下载到本地并跑起来呢? 第一步:在git上找到你想要克隆下来的项目,按照如下操作复制项目地址连接,如下图: 以上可以选择HTTPS和SSH两种形式。 第二步…

REST和RPC的区别

1 REST REST 不是一种协议,它是一种架构。大部分REST的实现中使用了RPC的机制,大致由三部分组成: method:动词(GET、POST、PUT、DELETE之类的)Host:URI(统一资源标识)&…

jmeter压测过程中,ServerAgent响应异常:Cannot send data to network connection

ServerAgent异常信息: Cannot send data to network connection(无法将数据发送到网络连接) 原因: linux 防火墙 拦截了当前端口 解决方案: Linux 执行以下命令 /sbin/iptables -I INPUT -p tcp --dport 4445 -j ACC…

数学建模入门-如何从0开始,掌握数学建模的基本技能

一、前言 本文主要面向没有了解过数学建模的同学,帮助同学们如何快速地进行数学建模的入门并且尽快地在各类赛事中获奖,或者写出优秀的数学建模论文。 在本文中,我将从什么是数学建模、数学建模的应用领域、数学建模的基本步骤、数学建模的技…

【动手学深度学习】--12.深度卷积神经网络AlexNet

文章目录 深度卷积神经网络AlexNet1.AlexNet2.模型设计3.激活函数4.模型实现5.读取数据集6.训练AlexNet 深度卷积神经网络AlexNet 学习视频:深度卷积神经网络 AlexNet【动手学深度学习v2】 官方笔记:深度卷积神经网络(AlexNet) …

Android 中 app freezer 原理详解(一):R 版本

基于版本:Android R 0. 前言 在之前的两篇博文《Android 中app内存回收优化(一)》和 《Android 中app内存回收优化(二)》中详细剖析了 Android 中 app 内存优化的流程。这个机制的管理通过 CachedAppOptimizer 类管理,为什么叫这个名字,而不…

【Linux | Shell】结构化命令2 - test命令、方括号测试条件、case命令

目录 一、概述二、test 命令2.1 test 命令2.2 方括号测试条件2.3 test 命令和测试条件可以判断的 3 类条件2.3.1 数值比较2.3.2 字符串比较 三、复合条件测试四、if-then 的高级特性五、case 命令 一、概述 上篇文章介绍了 if 语句相关知识。但 if 语句只能执行命令&#xff0c…

Docker 的数据管理、容器互联、镜像创建

目录 一、数据管理 1.数据卷 2. 数据卷容器 二、容器互联(使用centos镜像) 三、Docker 镜像的创建 1.基于现有镜像创建 1.1首先启动一个镜像,在容器里修改 1.2将修改后的容器提交为新的镜像,需使用该容器的id号创建新镜像 …

JAVA SE -- 第十天

(全部来自“韩顺平教育”) 一、枚举(enumeration,简写enum) 枚举是一组常量的集合 1、实现方式 a.自定义类实现枚举 b.使用enum关键字实现枚举 二、自定义类实现枚举 1、注意事项 ①不需要提供setXxx方法&#xff…

HTTP、HTTPS协议详解

文章目录 HTTP是什么报文结构请求头部响应头部 工作原理用户点击一个URL链接后,浏览器和web服务器会执行什么http的版本持久连接和非持久连接无状态与有状态Cookie和Sessionhttp方法:get和post的区别 状态码 HTTPS是什么ssl如何搞到证书nginx中的部署 加…

【从删库到跑路】MySQL数据库的索引(一)——索引的结构(BTree B+Tree Hash),语法等

🎊专栏【MySQL】 🍔喜欢的诗句:更喜岷山千里雪 三军过后尽开颜。 🎆音乐分享【如愿】 🥰欢迎并且感谢大家指出小吉的问题 文章目录 🍔概述🍔索引结构⭐B-Tree多路平衡查找树🏳️‍&a…

【iOS】weak关键字的实现原理

前言 关于什么是weak关键字可以去看看我以前的一篇博客:【OC】 属性关键字 weak原理 1. SideTable SideTable 这个结构体,前辈给它总结了一个很形象的名字叫引用计数和弱引用依赖表,因为它主要用于管理对象的引用计数和 weak 表。在 NSOb…

Vite + Vue3 + Ts 【免key、免账号实战本地运行GPT】

🐔 前期回顾 Vue3 Ts Vite —— 封装庆祝彩屑纷飞 示例_彩色之外的博客-CSDN博客封装 彩屑纷飞 示例https://blog.csdn.net/m0_57904695/article/details/131718019?spm1001.2014.3001.5501 目录 🌍 公网 🛹 本地 🪂 源码 &…

【前端|CSS系列第4篇】CSS布局之网格布局

前言 最近在做的一个项目前台首页有一个展示词条的功能,每一个词条都以一个固定大小的词条卡片进行展示,要将所有的词条卡片展示出来,大概是下面这种布局 每一行的卡片数目会随着屏幕大小自动变化,并且希望整个卡片区域周围不要…

20230721 Essex UK, Dongbing Gu 公开讲座--机器人前沿

个人主页: https://www.essex.ac.uk/people/GUDON81301/dongbing-gu 机器人领域任务的特点:dull, dirty, dangerous tasks in remote spaces 机器鱼: 实时港口环境监测 机器鱼群探索算法 化学传感器 水面声呐定位系统/SLAM/通信问题 Robotic …

C—数据的储存(下)

文章目录 前言🌟一、练习一下🌏1.例一🌏2.例二🌏3.例三🌏4.例四 🌟二、浮点型在内存中的储存🌏1.浮点数🌏2.浮点数存储💫(1).二进制浮点数&#x…

QDialog的两种显示方式

QDialog的两种显示方式 模态显示非模态显示 QDialog不能嵌入到其他窗口中显示(无论继承与否) 模态显示 d->exec(); 阻塞程序的执行 非模态显示 d->show(); 不阻塞程序

OpenCV4图像处理-图像交互式分割-GrabCut

本文将实现一个与人(鼠标)交互从而分割背景的程序。 GrabCut 1.理论介绍2. 鼠标交互3. GrabCut 1.理论介绍 用户指定前景的大体区域,剩下为背景区域,还可以明确指出某些地方为前景或者背景,GrabCut算法采用分段迭代的…