Transformer架构实现一

从0-1搭建Transformer架构
架构图
在这里插入图片描述

本文主要讲解
1)输入层的词嵌入
2)输入层的位置编码
3)编码层的多头注意力机制
4)编码层的前馈全连接

1)输入层的词嵌入

class Embeddings(nn.Module):"""构建embedding类实现文本嵌入"""def __init__(self, d_model, vocab):# d_model: 词嵌入维度# vocab: 词表的大小super(Embeddings, self).__init__()self.lut = nn.Embedding(vocab, d_model)self.d_model = d_modeldef forward(self, x):return self.lut(x) * math.sqrt(self.d_model)

2)输入层的位置编码

class PositionalEncoding(nn.Module):"""位置编码"""def __init__(self, d_model, pad_size=5000):# d_model 词嵌入维度# pad_size 默认词汇大小super(PositionalEncoding, self).__init__()self.d_model = d_modelself.pad_size = pad_sizepe = torch.zeros(pad_size, d_model)for t in range(pad_size):for i in range(d_model // 2):angle_rate = 1 / (10000 ** (2 * i / d_model))pe[t, 2 * i] = np.sin(t * angle_rate)pe[t, 2 * i + 1] = np.cos(t * angle_rate)# # 双层循环等价写法# pe = torch.tensor(#     [[pad / (10000.0 ** (i // 2 * 2.0 / d_model)) for i in range(d_model)] for pad in range(pad_size)])## pe[:, 0::2] = np.sin(pe[:, 0::2])# pe[:, 1::2] = np.cos(pe[:, 1::2])# 将位置编码扩展到三维pe = pe.unsqueeze(0)# 将位置编码矩阵注册成模型的buffer,buffer不是模型的参数,不跟随优化器更新# 注册成buffer后,在模型保存后重新加载模型的时候,将这个位置编码将和参数一起加载进来self.register_buffer('pe', pe)def forward(self, x):# 位置编码不需要反向更新x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)return x

3)编码层的多头注意力机制
三个辅助函数注意力机制、module拷贝函数、

def attention(q, k, v, dropout=None, mask=None):# 计算公式 AT(Q,K,V) = softmax(\frac{QK^{T}}{\sqrt{d_k}})V# 词嵌入维度d_k = q.shape[-1]score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)if mask is not None:score = score.masked_fill(mask == 0, -1e6)score = F.softmax(score, dim=-1)if dropout is not None:score = dropout(score)return torch.matmul(score, v), scoredef clones(module, N):""":param module: 需要复制的网络模块:param N: copy数量"""return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])class SublayerConnection(nn.Module):""" 子层连接结构,根据传入的sublayer(实例对象)处理在编码层sublayer可以是多头注意机制或者前馈全连接在解码层sublayer也可以是带有掩码的多头注意力机制SublayerConnection处理流程:规范化 -> 掩码多头/多头/前馈 -> 残差连接"""def __init__(self, d_k, dropout=0.1):super(SublayerConnection, self).__init__()self.norm = nn.LayerNorm(d_k)self.dropout = nn.Dropout(p=dropout)def forward(self, x, sublayer):# 先规范化处理在由具体子层函数处理out = sublayer(self.norm(x))out = self.dropout(out)# 残差连接return x + out

多头注意力机制

class MultiHeadAttention(nn.Module):"""多头注意力机制"""def __init__(self, d_k, head_num, dropout=0.0):super(MultiHeadAttention, self).__init__()self.d_k = d_kself.head_num = head_numassert d_k % head_num == 0self.head_dim = d_k // head_numself.dropout = nn.Dropout(p=dropout)# 深度copy4个线性层,3个用于Q、K、V矩阵,一个将用于指定维度转换self.linears = clones(nn.Linear(d_k, d_k), 4)self.attn = Nonedef forward(self, query, key, value, mask=None):if mask is not None:mask = mask.unsqueeze(0)batch_size = query.size(0)# 三个线性层对输入进行进行隐空间特征提取query, key, value = \[model(x).view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2) for model, x inzip(self.linears, (query, key, value))]score, self.attn = attention(query, key, value, dropout=self.dropout, mask=mask)score = score.transpose(1, 2).contiguous().view(batch_size, -1, self.head_dim * self.head_num)return self.linears[-1](score)# 多头注意力机制的另一种实现 建议理解这一个代码,比较好理解# def forward2(self, query, key, value, mask=None):#     if mask is not None:#         mask = mask.unsqueeze(0)#     batch_size = query.size(0)#     query, key, value = \#         [model(x).view(batch_size * self.head_num, -1, self.head_dim) for model, x in#          zip(self.linears, (query, key, value))]#     score, self.attn = attention(query, key, value, dropout=self.dropout, mask=mask)#     score = score.view(batch_size, -1, self.head_dim * self.head_num)#     return self.linears[-1](score)

前馈全连接

class PositionalWiseFeedForward(nn.Module):"""前馈全连接"""def __init__(self, d_k, hidden_size, dropout=0.1):super(PositionalWiseFeedForward, self).__init__()self.w1 = nn.Linear(d_k, hidden_size)self.w2 = nn.Linear(hidden_size, d_k)self.dropout = nn.Dropout(p=dropout)def forward(self, x):out = self.w1(x)out = F.relu(out)out = self.dropout(out)return self.w2(out)

编码层

class EncoderLayer(nn.Module):""" 子层连接结构,将多头注意力机制和前馈全连接组装"""def __init__(self, d_k, attn, feed_forward, dropout):"""attn 多头注意力实例feed_forward 前馈全连接实例dropout 置零率实例"""super(EncoderLayer, self).__init__()self.attn = attnself.feed_forward = feed_forward# 拷贝2个子层连接结构,具体处理方式(多头/前馈)调用时指定self.sublayer = clones(SublayerConnection(d_k, dropout), 2)# 保存词嵌入维度,方便后续使用self.size = d_kdef forward(self, x, mask):""" 先走多头注意力机制,在过前馈全连接。 Transformer编码顺序"""x = self.sublayer[0](x, lambda x: self.attn(x, x, x, mask))return self.sublayer[1](x, self.feed_forward)

编码器实现

class Encoder(nn.Module):""" 编码器实现,N个编码层EncoderLayer的堆叠"""def __init__(self, encoder_layer, N):super(Encoder, self).__init__()self.layers = clones(encoder_layer, N)# 使用自定义规范会层 encoder_layer.size 词嵌入维度self.norn = LayerNorm(encoder_layer.size)# torch中规范会层# self.norn = nn.LayerNorm(encoder_layer.size)def forward(self, x, mask=None):for layer in self.layers:x = layer(x, mask)return self.norn(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/820069.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

a == 1 a== 2 a== 3 返回 true ?

1. 前言 下面这道题是 阿里、百度、腾讯 三个大厂都出过的面试题,一个前端同事跳槽面试也被问了这道题 // ? 位置应该怎么写,才能输出 trueconst a ?console.log(a 1 && a 2 && a 3) 看了大厂的面试题会对面试官的精神…

git操作基本命令

Git命令操作: 1、服务器上面有新的修改,pull出现错误操作如下 git stash git pull origin master git stash pop 2、删除本地一个文件test.py,想重新download远程服务器最新的文件 #git checkout test.py 3、查看当前处于哪一个分支 #git …

数码相框-显示JPG图片

LCD控制器会将LCD上的屏幕数据映射在相应的显存位置上。 通过libjpeg把jpg图片解压出来RGB原始数据。 libjpeg是使用c语言实现的读写jpeg文件的库。 使用libjpeg的应用程序是以"scanline"为单位进行图像处理的。 libjpeg解压图片的步骤: libjpeg的使…

【御控物联】物联网平台设备接入-JSON数据格式转化(场景案例四)

文章目录 一、背景二、解决方案三、在线转换工具四、技术资料 一、背景 物联网平台是一种实现设备接入、设备监控、设备管理、数据存储、消息多源转发和数据分析等能力的一体化平台。南向支持连接海量异构(协议多样)设备,实现设备数据云端存…

前端开发攻略---在输入框中输入中文但是还没选中的时候,搜索事件依然存在;中文输入法导致的高频事件。

1、演示 解决前 解决后 2、输入框事件介绍 compositionstart事件在用户开始使用输入法输入时触发。这意味着用户正在进行组合输入,比如在中文输入法中,用户可能正在输入一个多个字符的词语。在这个阶段,输入框的内容可能还没有完全确定&#…

RocketMQ 10 面试题FAQ

RocketMQ 面试FAQ 说说你们公司线上生产环境用的是什么消息中间件? 为什么要使用MQ? 因为项目比较大,做了分布式系统,所有远程服务调用请求都是同步执行经常出问题,所以引入了mq 解耦 系统耦合度降低,没有强依赖…

Testng测试框架(2)-测试用例@Test

测试方法用 Test 进行注释,将类或方法标记为测试的一部分。 Test() public void aFastTest() {System.out.println("Fast test"); }import org.testng.annotations.Test;public class TestExample {Test(description "测试用例1")public void…

如何通过Python向PDF添加文本水印_python给pdf文件加文字水印

先自我介绍一下,小编浙江大学毕业,去过华为、字节跳动等大厂,目前阿里P7 深知大多数程序员,想要提升技能,往往是自己摸索成长,但自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞…

频率传感器信号采集隔离转换模拟信号0-1KHz/0-5KHz/0-10KH转0-2.5V/0-5V/0-10V/0-10mA/0-20mA/4-20mA

主要特性: >> 精度等级&#xff1a;0.2 级 >> 全量程内极高的线性度&#xff08;非线性度<0.1%&#xff09; >> 辅助电源/信号输入/信号输出&#xff1a; 2500VDC 三隔离 >> 辅助电源&#xff1a;5VDC&#xff0c;12VDC&#xff0c;24VDC 等单…

Redis Desktop Manager 中文--强大的Redis数据库管理工具

Redis Desktop Manager&#xff08;简称RDM&#xff09;是一款开源且功能强大的图形化Redis管理工具。它支持Windows、macOS和Linux等多平台&#xff0c;为Redis数据库提供了直观友好的管理界面。通过RDM&#xff0c;用户可以轻松连接多个Redis服务器&#xff0c;管理连接信息&…

【自媒体创作利器】AI白日梦+ChatGPT 三分钟生成爆款短视频

AI白日梦https://brmgo.com/signup?codey5no6idev 引言 随着人工智能&#xff08;AI&#xff09;技术的快速发展&#xff0c;AI在各个领域都展现出了强大的应用潜力。其中&#xff0c;自然语言处理技术的进步使得智能对话系统得以实现&#xff0c;而ChatGPT作为其中的代表之一…

MyBatis操作数据库(3)

其它查询操作 #{}和${} MyBatis参数赋值有两种方式, 咱们前面使用了#{}进行赋值, 接下来来看两者的区别: #{}和${}的使用 1.先看Integer类型的参数: Select("select username, password, age, gender, phone from userinfo where id #{id}") UserInfo queryByI…

攻防世界---easyRE1

1.下载附件&#xff0c;打开后有两个文件 2.对32查壳 3.对64查壳 4.IDA分析&#xff0c;这里打开之后找到main函数点击main函数后按f5 5.看到了flag----拿去提交发现是对的&#xff0c;这道题是逆向中最简单的一道了 flag{db2f62a36a018bce28e46d976e3f9864}

LeetCode501:二叉搜索树中的众数

给你一个含重复值的二叉搜索树&#xff08;BST&#xff09;的根节点 root &#xff0c;找出并返回 BST 中的所有 众数&#xff08;即&#xff0c;出现频率最高的元素&#xff09;。 如果树中有不止一个众数&#xff0c;可以按 任意顺序 返回。 假定 BST 满足如下定义&#xf…

STL —— priority_queue

博主首页&#xff1a; 有趣的中国人 专栏首页&#xff1a; C专栏 本篇文章主要讲解 priority_queue 的相关内容 目录 1. 优先级队列简介 基本操作 2. 模拟实现 2.1 入队操作 2.2 出队操作 2.3 访问队列顶部元素 2.4 判断优先队列是否为空 2.5 获取优先队列的大小 …

什么是One-Class SVM

1. 简介 单类支持向量机&#xff0c;简称One-Class SVM(One-Class Support Vector Machine)&#xff0c;用于异常检测和离群点检测(无监督学习&#xff0c;其他svm属于有监督的)&#xff0c;可以在没有大量异常样本的情况下有效地检测异常。其目标是通过仅使用正常数据来建模&a…

【力扣 Hot100 | 第四天】4.15(括号生成)

文章目录 4.括号生成4.1题目4.2解法&#xff1a;回溯4.2.1回溯思路&#xff08;1&#xff09;函数返回值以及参数&#xff08;2&#xff09;终止条件&#xff08;3&#xff09;遍历过程 4.2.2代码 4.括号生成 4.1题目 数字 n 代表生成括号的对数&#xff0c;请你设计一个函数…

三斜求积术 To 海伦公式 ← 三角形面积

【知识点&#xff1a;三斜求积术】 所谓秦九韶的三斜求积术&#xff0c;即如果已知三角形的边长a&#xff0c;b&#xff0c;c&#xff0c;可求得该三角形的面积为&#xff1a; 而由三斜求积术可推得海伦公式。过程如下&#xff1a; 其中&#xff0c; 上面推导公式的 Latex 代码…

​​​​网络编程探索系列之——广播原理剖析

hello &#xff01;大家好呀&#xff01; 欢迎大家来到我的网络编程系列之广播原理剖析&#xff0c;在这篇文章中&#xff0c; 你将会学习到如何在网络编程中利用广播来与局域网内加入某个特定广播组的主机&#xff01; 希望这篇文章能对你有所帮助&#xff0c;大家要是觉得我写…

从零开始写 Docker(十一)---实现 mydocker exec 进入容器内部

本文为从零开始写 Docker 系列第十一篇&#xff0c;实现类似 docker exec 的功能&#xff0c;使得我们能够进入到指定容器内部。 完整代码见&#xff1a;https://github.com/lixd/mydocker 欢迎 Star 推荐阅读以下文章对 docker 基本实现有一个大致认识&#xff1a; 核心原理&…